function crossvalidation_forward_svm_performance_with_component_main
% Main function
% Multivariate Pattern Recognition for Neuroimaging Toolbox (mPReNT) is a
% MATLAB toolbox based on multivariate pattern recognition techniques
% for the discriminative analysis of neuroimaging data. To identify the
% informative patterns that distinguish group differences at a large-scale
% level, we parameterize a series of support vector machine (SVM) models
% in conjunction with a (simplified) forward component selection technique
% within a nested cross-validation procedure.
% In mPReNT, brain networks/patterns jointly at an individual level are
% used as bases for a linear subspace on the Grassmann manifold to
% facilitate a comprehensive characterization of neuroimaging data, and SVM
% classification models can be used to discriminate between groups of
% subjects with status-specific classification scores.
%
% Procedure implemented here:
%   outer loop : validation_num randomly repeated hold-out splits
%                (data_id from Segment_set);
%   inner loop : leave-one-out over the training part of each split. For
%                each left-out subject a forward component selection + SVM
%                run is fit on the remaining subjects, the selection step
%                with the highest tmaroc is chosen (ties broken by the
%                highest trate, then the earliest step), and that model
%                predicts both the left-out subject and the held-out
%                validation set of the split.
%   Per-repetition results are collected in the T* cell arrays, and
%   validation-set predictions/accuracies in pred_label, pred_value,
%   ind_value and acc.
%
% Written by Yong Fan, Rixing Jing.
clc; clear;

%% Configuration
% Forward-selection routine: forward_svm_simple is the simplified variant
% intended for a large number of components; forward_svm is the full one.
% (Previously both were called back to back and the forward_svm result was
% immediately overwritten, so only forward_svm_simple's result was ever
% used; true keeps that outcome without running the discarded full search.)
use_simple_forward = true;

%% read data
addpath(genpath('*/Matlab_software/libsvm-3.11')); % location of libsvm toolbox
icadata_root = '/*/';      % location of the processed data,
                           % i.e. independent components of the subjects
comp_to_be_removed = [];   % remove some components before classification
starting_comp_list = [];   % identify some components with prior knowledge
% Get the component matrix of all subjects and their labels
[comp, label] = get_comp(icadata_root, comp_to_be_removed);
% component matrix = sub_num x comp_num x voxel_num
[sub_num, comp_num, voxel_num] = size(comp);
validation_num = 100;
test_num = 10;
% Randomly repeated hold-out cross-validation procedure.
% data_id is a struct holding the split of every repetition:
%   data_id.train = validation_num x train_index
%   data_id.test  = validation_num x test_index
data_id = Segment_set(label, test_num, validation_num);

%% Build classification models
for i = 1:validation_num
    train_id = data_id.train(i,:);
    valid_id = data_id.test(i,:);
    train_com = comp(train_id,:,:);
    train_label = label(train_id);
    valid_com = comp(valid_id,:,:);
    valid_label = label(valid_id);
    disp(['Loop ' num2str(i)]);

    % training a model
    nn = length(train_label);
    tmodel_id = zeros(1,nn);
    % Reset the per-repetition containers so entries from an earlier
    % repetition cannot leak into this one's T* results (e.g. when this
    % split's training set is smaller than a previous one).
    trate = cell(1,nn);
    tmaroc = cell(1,nn);
    tsigma = cell(1,nn);
    tgama = cell(1,nn);
    tc = cell(1,nn);
    tin_list = cell(1,nn);
    tmodel = cell(1,nn);
    tin_list_selected = cell(1,nn);
    tind_value = cell(1,nn);
    tpred_label = [];
    tpred_value = [];

    for j = 1:nn
        % Leave-one-out cross-validation within the training set
        ttrain_id = [1:j-1 j+1:nn];
        ttest_id = j;
        ttrain_com = train_com(ttrain_id,:,:);
        ttrain_label = train_label(ttrain_id);
        ttest_com = train_com(ttest_id,:,:);
        ttest_label = train_label(ttest_id);

        %%%%%%%%%%%%%%%%%%%%
        % By default, AUROC is the main measurement used to estimate the
        % performance of the classifiers. If you want the classification
        % rate as the main measurement, svm_manifold.m should be used in
        % forward_svm.
        if use_simple_forward
            % Simplified forward selection, for a large number of components.
            [trate{j}, tmaroc{j}, tsigma{j}, tgama{j}, tc{j}, tin_list{j}, tmodel{j}] = ...
                forward_svm_simple(ttrain_com, ttrain_label, starting_comp_list);
        else
            [trate{j}, tmaroc{j}, tsigma{j}, tgama{j}, tc{j}, tin_list{j}, tmodel{j}] = ...
                forward_svm(ttrain_com, ttrain_label, starting_comp_list);
        end
        %%%%%%%%%%%%%%%%%%%%

        % Pick the selection step with the highest tmaroc; break ties by the
        % highest trate, and remaining ties by the earliest step.
        tidx = find(tmaroc{j} == max(tmaroc{j}));
        if length(tidx) > 1
            ttrate = trate{j}(tidx);
            idx1 = find(ttrate == max(ttrate));
            tidx = tidx(idx1(1));
        end

        % tin_list{j} is indexed one ahead of the per-step metrics/models.
        % NOTE(review): presumably entry 1 is the starting list before any
        % component is added -- confirm against forward_svm*.
        selected_list = tin_list{j}{tidx+1};
        tin_list_selected{j} = selected_list;
        tmodel_id(j) = tidx;
        disp(['AUC is ' num2str(tmaroc{j}(tidx)) ' Acc is ' num2str(trate{j}(tidx)) ...
            ' Optimal sigma is ' num2str(tsigma{j}(tidx)) ' c is ' num2str(tc{j}(tidx)) ]);
        disp(' list of selection is ');
        disp(tin_list_selected{j});

        % Predict the left-out training subject with the selected model
        [tpred_label(j), tpred_value(j), tind_value{j}] = svm_manifold_predict(ttest_com, ...
            ttrain_com, ttrain_label, selected_list, tsigma{j}(tidx), tgama{j}(tidx), tmodel{j}{tidx});
        if tpred_label(j) == ttest_label
            disp(' Test result is right');
        else
            disp(' Test result is Wrong');
        end

        % Apply the same model to this repetition's held-out validation set
        [pred_label(i,j,:), pred_value(i,j,:), ind_value{i}{j}] = svm_manifold_predict(valid_com, ...
            ttrain_com, ttrain_label, selected_list, tsigma{j}(tidx), tgama{j}(tidx), tmodel{j}{tidx});
        vpred = squeeze(pred_label(i,j,:));
        % Fraction of validation subjects whose predicted label matches
        acc(i,j) = mean(valid_label(:) == vpred(:));
        disp(['ACC = ' num2str(acc(i,j))]);
        disp('====================================================================');
    end

    % Keep everything produced in this repetition
    Tpred_label{i} = tpred_label;
    Tpred_value{i} = tpred_value;
    Tind_value{i} = tind_value;
    Tin_list_selected{i} = tin_list_selected;
    Tin_list{i} = tin_list;
    Trate{i} = trate;
    Tmaroc{i} = tmaroc;
    Tsigma{i} = tsigma;
    Tgama{i} = tgama;
    Tc{i} = tc;
    Tmodel{i} = tmodel;
    Tlabel{i} = train_label;
    Tmodel_id{i} = tmodel_id;
end