backwards_select_analysis.m 8.87 KB
Newer Older
Michaela Olson's avatar
Michaela Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
%% backwards select analysis
% All of the parts of backwards select without the prompts to load or save.
%
% This setup section reads the following variables, which must already
% exist in the workspace:
%   pct_correct, numvar, final_data_table, remove_after_TVN,
%   chosen_workspaces, d_metric, large_group, with_randomized_labels,
%   do_joint_profile
% Optional flags, defaulted to false below when absent: apply, no_TVN

% default the optional flags as a safety
if ~exist('apply','var')
    apply = false;
end 

if ~exist('no_TVN','var')
    no_TVN = false; 
end 

%% Start backwards select 
% start off by reporting the original percent
disp(strcat(num2str(pct_correct), "% for ", num2str(numvar), " variables"))

% this is what we are looking to beat
baseline_score = pct_correct;

% minimum variable number to run until 
minvar = 22; 

% set keep_going to true to get the while loop started
keep_going = true;

% bookkeeping struct: every variable set visited plus run metadata
overall_vars = struct; 

% Use the external getGitInfo function and save the hash in overall_vars.
% The call itself sits inside the try so that a missing/failing getGitInfo
% (e.g. repository wasn't cloned) is handled the same way as a result
% without a .hash field, instead of aborting the whole run.
git_info = [];
try
    git_info = getGitInfo();
    overall_vars.hash =  git_info.hash; 
catch
    disp("unable to to save git hash information")
end 

% store the drugs used to create this in overall_vars 
overall_vars.DRUGS = unique(final_data_table.DRUG)';
% hard coding this for now b/c we're having problems 
% remove_after_TVN = true;
if remove_after_TVN
    % not actually using untreated, want to remove from .DRUG for clarity
    all_drugs_from_final = unique(final_data_table.DRUG)';
    unt_loca = find(strcmp(all_drugs_from_final,"Untreated"));
    all_drugs_from_final(unt_loca) = [];
    
    overall_vars.DRUGS = all_drugs_from_final;
end 
% save the workspace used
overall_vars.WORKSPACE = chosen_workspaces;

disp(strcat("distance metric is ", d_metric))

if large_group
    disp("using broader groups")
else
    disp("using finer groups")
end


% when labels are to be randomized, fix one random ordering up front
if with_randomized_labels
    
    % one permutation index per row of final_data_table
    rand_order = randperm(length(final_data_table.DRUG));

    disp("random!")
    % note: the RANDOM field is only created for randomized runs
    overall_vars.RANDOM = true;
end 

% store in the overall_vars if joint profile or not 
if do_joint_profile
    overall_vars.JOINT_PROFILE = true;
else
    overall_vars.JOINT_PROFILE = false;
end 

% announce the start of the elimination loop and echo the flag it uses
disp("Beginning backwards select....")
disp("value of remove_after_TVN")
disp(remove_after_TVN)

% Backwards-select loop.
% Each pass records the current kept_vars in overall_vars, scores every
% "leave one variable out" subset of kept_vars, and then permanently drops
% the variable whose removal scored highest (ties go to the last one).
% The loop stops once the candidate subsets are no larger than minvar.
%
% Reads from the workspace: kept_vars, baseline_score, minvar,
% overall_vars, final_data_table, choice_extension and the flags no_TVN,
% apply (with addedDrug), remove_after_TVN, with_randomized_labels,
% plot_before_tvn, after_pca, confuse_controls.
% Runs the external scripts TVN_transform and knn_helper; pct_correct is
% read right after knn_helper (presumably set by it - confirm).
while keep_going
    
    % one score per candidate removal; sized from kept_vars (the list the
    % inner loop walks) so no stray zero entries can take part in the max
    all_pcts = zeros(length(kept_vars),1);
    
    % reset numvar 
    numvar = length(kept_vars);
    
    % get title for structure
    field_title = strcat("vars_",num2str(numvar),"_pct_",num2str(floor(baseline_score)));
    
    % save current set of variables in structure 
    overall_vars.(field_title) = kept_vars;
    
    %%% might also want to create a confusion matrix and save its handle in
    %%% overall_vars? 
    
    %% Remove one variable at a time from kept_vars and score the rest
    for p = 1:length(kept_vars)
        % reset to the full list of variables
        starting_vars = kept_vars;
        % get the name of the var removed
        var_removed = starting_vars{p};

        % remove a variable from the list
        starting_vars(p) = [];
        
        % perform TVN with newly selected features
        % (TVN_transform is an external script; presumably it builds
        % pca_whitened_table from starting_vars - confirm)
        if no_TVN
            pca_whitened_table = final_data_table; 
        else
            TVN_transform
        end 

        % if a drug was applied, drop its rows here
        if apply
            for qi = addedDrug
                added_drug_1 = qi{1};
                applied_inds = find(strcmp(pca_whitened_table.(choice_extension),added_drug_1));
                pca_whitened_table(applied_inds,:) = [];
            end 
        end 
        
        %% Remove after TVN
        % the untreated rows have to be dropped on every candidate because
        % pca_whitened_table is rebuilt each time; a logical mask makes
        % this a no-op when no "Untreated" rows are present
        if remove_after_TVN
            pca_whitened_table(strcmp(pca_whitened_table.DRUG,"Untreated"),:) = [];
        end 

        %% with randomized labels 
        if with_randomized_labels
            % rand_order must be a permutation of the rows that are being
            % relabelled. A pre-existing rand_order of a different length
            % (e.g. sized before the row removals above) would index past
            % the end of the table, so regenerate it once and reuse it for
            % the rest of the run.
            n_rows = height(pca_whitened_table);
            if ~exist('rand_order','var') || numel(rand_order) ~= n_rows
                rand_order = randperm(n_rows);
            end

            % reorder the drug labels in the random order defined
            randomized_drugs = pca_whitened_table.DRUG(rand_order(:));
            pca_whitened_table.DRUG = randomized_drugs;
        end 

        %% now time for the analysis 
        if plot_before_tvn
            % work from the raw table instead: keep its non-numeric
            % columns plus the currently selected variables, then
            % normalize / zero-mean that
            numeric_final_data_cols = varfun(@isnumeric,final_data_table,'OutputFormat', 'uniform');

            % taken from final_data_table directly rather than relying on
            % a non_numeric left over from an earlier candidate or script
            non_numeric = final_data_table(:,~numeric_final_data_cols);

            table_to_zero_mean = final_data_table(:,starting_vars);
            table_to_zero_mean = horzcat(non_numeric,table_to_zero_mean);
            
            % find numeric columns and their names in the reduced table
            numeric_final_data_cols = varfun(@isnumeric,table_to_zero_mean,'OutputFormat', 'uniform');
            numeric_final_col_names = table_to_zero_mean.Properties.VariableNames(numeric_final_data_cols);

            zero_mean_table = normalize_and_zero_mean(table_to_zero_mean,numeric_final_data_cols);
            
            % NOTE(review): this replaces the table built above, so the
            % row removals / label shuffle do not carry over - confirm
            % that is intended
            pca_whitened_table = zero_mean_table;
        else
            % find numeric columns of the table and their names
            numeric_final_data_cols = varfun(@isnumeric,pca_whitened_table,'OutputFormat', 'uniform');
            numeric_final_col_names = pca_whitened_table.Properties.VariableNames(numeric_final_data_cols);
        end 

        % get numeric data
        ndata = pca_whitened_table{:,numeric_final_data_cols};
        
        % non numeric data
        non_numeric = pca_whitened_table(:,~numeric_final_data_cols);
        
        %% perform PCA calculation 
        [coeff,scores,pcvars] = pca(ndata);

        % put data in a table for better organization
        scores_table = horzcat(non_numeric,array2table(scores)); 

        drugs = unique(scores_table.DRUG)';

        % row indices of every drug, as a struct with one field per drug
        drug_indexes = cell2struct(cell(1,length(drugs)), drugs, 2);

        for i = drugs
            drug = i{1};
            drug_indexes.(drug) = find(strcmp(scores_table.DRUG, drug)).';
        end

        %% perform knn 
        if after_pca
            % if after pca, want to use scores table
            [drug_medians,drugs] = table_median(scores_table,choice_extension);
            overall_vars.AFTER_PCA = true;
        else
            % if not, want to use pca_whitened_table
            [drug_medians,drugs] = table_median(pca_whitened_table,choice_extension);
            overall_vars.AFTER_PCA = false;
        end
        
        % use medians as knn_data
        knn_data = drug_medians;
        knn_drugs = drugs';
        
        % if we're doing control confusion, use every row's scores instead
        if confuse_controls
            knn_data = scores;
            knn_drugs = scores_table.(choice_extension);    
        end 

        % run helper script to create confusion matrix (make sure things are
        % set to not show or else every candidate will pop up a figure)
        knn_helper

        % score obtained with variable p left out
        all_pcts(p) = pct_correct; 
    end 
    
    % best score of this pass; if several removals tie, take the last one
    max_value = max(all_pcts);
    max_idx = find(all_pcts == max_value, 1, 'last');
    
    % Keep eliminating while the candidate subsets (kept_vars minus one)
    % are still larger than minvar, whether or not the score improved
    % (the "max_value >= baseline_score" test is intentionally disabled).
    % max_idx is empty when no score could be compared (e.g. all NaN, or
    % kept_vars is empty); nothing could be removed then, so stop rather
    % than loop forever.
    if ~isempty(max_idx) && length(kept_vars) - 1 > minvar 
        
        % throw this variable out of kept_vars and go again
        kept_vars(max_idx) = []; 
        baseline_score = max_value;
        disp(strcat(num2str(baseline_score), "% and ", ...
            num2str(length(kept_vars)), " variables"))
        
    else
        if isempty(max_idx)
            warning('backwards_select:noBestVariable', ...
                'No variable could be selected for removal; stopping.')
        end

        keep_going = false;

        numvar = length(kept_vars);
        
    end 
end