%%%% Dennis Hernaus, James Gold, James Waltz & Michael Frank
%%%% MPRC, University of Maryland
%%%% 2018/2
%%%% Script for model fitting - hybrid, Q and A/C
%%%% Used in paper: Hernaus et al 2018. Biol Psych CNNI.
%%%% Based on: Gold et al 2012. Arch General Psychiatry (original script by
%%%% A. Collins)
%%%% Edited by D. Hernaus
%%%% 2018-02-08
%%%% TWBB - Context-dependent RL
%%%% group: 1-HV, 2-SZ
%%%% model: 1-Q+E, 2-AC+E, 3-Mix+E, 4-Mix Block-wise (context)+E
function fittingTest_320_DH_TWBB(mo,groupe)
% FITTINGTEST_320_DH_TWBB  Fit one RL model to every subject of one group.
%
%   fittingTest_320_DH_TWBB(mo,groupe)
%
%   mo     - model index: 1 = Q+E, 2 = AC+E, 3 = Mix+E, 4 = Mix block-wise+E
%   groupe - group index: 1 = HV, 2 = SZ
%
%   Reads the group's CSV file from the current folder, fits the selected
%   model to each subject (see fittingSujet) and saves three .mat files:
%     Fit_<model>_TWBB_320_<group>    FitVals, one row per subject:
%                                     [subject, parameters, pseudoR2, LLH,
%                                      LLH0, BIC, BIC0, AIC, AIC0, nTrials]
%     Qvals_<model>_TWBB_320_<group>  Qvals, subject x 8 pairs x 2 actions
%     Wvals_<model>_TWBB_320_<group>  Wvals, subject x 8 pairs x 2 actions
%
%   Columns of the input file (320 trials per subject):
%      1: group (1 or 2)              2: local subject number
%      3: analysis subject number     4: run (block 1 or 2)
%      5: response (1 left, 2 right)  6: optimal resp (1 left, 2 right)
%      7: response for reward (1 left, 2 right)
%      8: valid trial (0=no, 1=yes)   9: reaction time
%     10: pair no. (1-8)             11: trial no. (1-320)
%     12: correct (0=no, 1=yes)      13: rewarded (0=no, 1=yes)
%     14: repeat (1)                 15: probability level (1-4)
%     16: Block (Probabilistic(2) or deterministic(1))

% State shared with fittingSujet / modele through globals.
global model;
global DonneesGroupe;
global Q W Qvals Wvals sub_count;

noms       = {'TWBB_092017_acq_input_HV.csv','TWBB_092017_acq_input_SZ.csv'}; % input data files
nom4       = {'HV','SZ'};                                                     % group labels
model_name = {'Q_E', 'AC_E', 'Mix_E','Mixblock_E'};

% Fail early with a readable message instead of an index error later on.
if ~isscalar(mo) || ~ismember(mo,1:numel(model_name))
    error('fittingTest:badModel','mo must be an integer between 1 and %d.',numel(model_name));
end
if ~isscalar(groupe) || ~ismember(groupe,1:numel(noms))
    error('fittingTest:badGroup','groupe must be an integer between 1 and %d.',numel(noms));
end
model = mo;

% load the group data
data          = csvread(noms{groupe});
DonneesGroupe = data(:,1:16);

% analysis subject numbers (column 3), as a row vector so that the for
% loop below visits one subject per iteration
sujets   = unique(data(:,3))';
nSujets  = numel(sujets);

% Fitted values, one row per subject (row width depends on the model).
FitVals = [];

% Final Q and W tables of each subject's best fit: subject x pair x action.
Qvals = zeros(nSujets,8,2);
Wvals = zeros(nSujets,8,2);

sub_count = 1;
for si = sujets % loop through subjects
    FitVals = [FitVals; fittingSujet(si)]; %#ok<AGROW>
    % Q and W are the globals left behind by the last call to modele,
    % which fittingSujet makes with the best-fitting parameters.
    Qvals(sub_count,1:8,1:2) = Q;
    Wvals(sub_count,1:8,1:2) = W;
    sub_count = sub_count+1;
end

% save fitting values and value tables to mat files
save(strcat('Fit_',model_name{model},'_TWBB_320_',nom4{groupe}),'FitVals');   % model parameters
save(strcat('Qvals_',model_name{model},'_TWBB_320_',nom4{groupe}),'Qvals');   % Q vals
save(strcat('Wvals_',model_name{model},'_TWBB_320_',nom4{groupe}),'Wvals');   % W vals
end

function Results=fittingSujet(sujet)
% FITTINGSUJET  Fit the currently selected model to one subject.
%
%   sujet   - analysis subject number (value in column 3 of the group data)
%   Results - row vector [sujet, best parameters, pseudoR2, LLH, LLH0,
%             BIC, BIC0, AIC, AIC0, nTrials]
%
%   Runs fmincon from NStartingPoints random starting points inside the
%   parameter bounds and keeps the solution with the highest pseudo-R2.
global DonneesGroupe;
global DonneesS;
global model;
global option;

% extract subject data
DonneesS = DonneesGroupe(DonneesGroupe(:,3)==sujet,:);

% option = 0 makes modele return the scalar objective (1 - pseudoR2)
option = 0;

% set up optimization options
options = optimset('MaxFunEvals',100000,'Display','off','Algorithm','active-set');

% Number of free parameters per model; every parameter is bounded to [0,1].
%   model 1: alpha, beta, epsilon
%   model 2: alphaC, alphaA, beta, epsilon
%   model 3: alphaC, alphaA, OFC-Qlearn, mix, beta, epsilon
%   model 4: alphaC, alphaA, OFC-Qlearn, mix, beta, epsilon (block-wise)
nParamsParModele = [3 4 6 6];
nParams = nParamsParModele(model);
pMin = zeros(1,nParams);
pMax = ones(1,nParams);

% Select starting points, the more points the longer it takes to find
% solutions
NStartingPoints = 100;
Results = zeros(NStartingPoints,nParams+1);
for i = 1:NStartingPoints
    % random starting point within the constraints
    par0 = pMin+rand(1,nParams).*(pMax-pMin);
    % optimization; fval is the minimised objective, i.e. 1 - pseudoR2
    [param, fval] = fmincon(@modele,par0,[],[],[],[],pMin,pMax,[],options);
    Results(i,:) = [param 1-fval]; % store pseudoR2 in the last column
end

% find the optimized solution and the parameters (first one on ties)
no  = find(Results(:,end)==max(Results(:,end)),1);
par = Results(no,1:end-1);

% re-run the model with the best parameters to get the full statistics;
% this call also leaves the final Q / W tables in the globals
option = 1;
output = modele(par); % output = pseudoR2 LLH LLH0 BIC BIC0 AIC AIC0 trials
Results = [sujet par output];
end

function [output]=modele(pa)
% MODELE  Likelihood of one subject's choices under the selected model.
%
%   pa     - parameter vector, each entry in [0,1] (layout depends on the
%            global "model", see below)
%   output - if global option == 1:
%               [pseudoR2 LLH LLH0 BIC BIC0 AIC AIC0 nTrials]
%            otherwise the scalar 1 - pseudoR2 (the quantity to minimise)
%
%   Reads the subject's trials from the global DonneesS and leaves the
%   final value tables in the globals Q, W and V.
%
%   Trial types (2 pairs per type):
%     1. = 90% rew, 10% neutral
%     2. = 75% rew, 25% neutral
%     3. = 75% rew, 25% neutral
%     4. = 60% rew, 40% neutral
global DonneesS;
global option;
global model;
global Q W V;

% Starting values of the value tables (8 stimulus pairs, 2 actions).
% NOTE(review): a comment in the original script suggested that starting
% values could be changed to reflect an early optimism/pessimism bias.
V = zeros(8,1);      % Critic
W = .01*ones(8,2);   % Actor
Q = zeros(8,2);      % Q-value

% set up Params
switch model
    case 1
        alpha = pa(1);       % OFC-Qlearn
        beta  = 100*pa(2);   % inverse temperature
        e     = pa(3);       % undirected noise
        % The actor is not part of this model. W is still assigned so that
        % the caller, which stores W for every model, never reads an
        % empty or stale global.
        W = zeros(8,2);
    case 2
        alphaC = pa(1);      % Critic learning rate
        alphaA = pa(2);      % Actor learning rate
        beta   = 100*pa(3);  % inverse temperature
        e      = pa(4);      % undirected noise
    case {3,4}
        alphaC = pa(1);      % Critic learning rate
        alphaA = pa(2);      % Actor learning rate
        alphaO = pa(3);      % OFC- Q learning rate
        mix    = pa(4);      % mixing parameter
        beta   = 100*pa(5);  % inverse temperature
        e      = pa(6);      % undirected noise
    otherwise
        error('fittingTest:badModel','Unknown model index.');
end

% Extract Stimulus and Outcomes
Input     = DonneesS(:,10); % stimulus pair 1-8
Cor       = DonneesS(:,12); % correct choice 1/0
Feedback1 = DonneesS(:,13); % whether you got reinforced
Block     = DonneesS(:,16); % Block (D or P) (this is not "run"!)
% NOTE(review): column 8 (valid trial) and column 9 (reaction time) are
% not used here, so every trial of the subject enters the likelihood.

% Reward magnitude: 0 when not reinforced, .05 otherwise. The same mapping
% applies to all 8 pairs.
Feedback = .05*(Feedback1~=0);

% initialize log likelihood
LLH = 0;

% start fitting
taille = length(Input); % total trial number
for t = 1:taille
    st  = Input(t);   % stimulus pair
    act = Cor(t)+1;   % choice: 1 = incorrect, 2 = correct
    r   = Feedback(t);

    % values entering the softmax
    switch model
        case 1   % Q-values only
            Qst = beta*Q(st,:);
        case 2   % actor weights only
            Qst = beta*W(st,:);
        otherwise % mixture of actor weights and Q-values
            Qst = ((1-mix)*W(st,:)+mix*Q(st,:))*beta;
    end

    % softmax probability of the action taken, plus undirected noise
    pChoix = (1-e)*(exp(Qst(act))./sum(exp(Qst)))+(e*.5);
    % add to log likelihood
    LLH = LLH+log(pChoix);

    % learning
    switch model
        case 1 % Standard Q model
            Q(st,act) = Q(st,act)+alpha*(r-Q(st,act));

        case 2 % Standard AC model
            deltaV    = r-V(st);                     % RPE
            V(st)     = V(st)+alphaC*deltaV;         % update state
            W(st,act) = W(st,act)+alphaA*deltaV;     % update choice
            W(st,:)   = W(st,:)/sum(abs(W(st,:)));   % normalize weights

        case 3 % Standard Mix model
            deltaV    = r-V(st);                     % critic update
            V(st)     = V(st)+alphaC*deltaV;
            W(st,act) = W(st,act)+alphaA*deltaV;     % actor update
            W(st,:)   = W(st,:)/sum(abs(W(st,:)));   % normalize weights
            Q(st,act) = Q(st,act)+alphaO*(r-Q(st,act)); % Q-learning update

        case 4 % Mix model with context (block-wise) dependent state
               % values (NOT pairwise)
            bl        = Block(t);                    % D or P
            deltaV    = r-V(bl);
            V(bl)     = V(bl)+alphaC*deltaV;         % block-wise critic update (note "bl", not "st")
            W(st,act) = W(st,act)+alphaA*deltaV;     % actor update, weights are NOT normalized here
            Q(st,act) = Q(st,act)+alphaO*(r-Q(st,act)); % Q-learning update
    end
end

% likelihood of random model
LLH0 = sum(log((1/2)*ones(1,taille)));

% pseudoR2 is normalized log-likelihood.
pseudoR2 = -(LLH-LLH0)/LLH0;
% to minimize
objectif = 1-pseudoR2;

% Bayesian information criterion
BIC  = -2*LLH+length(pa)*log(taille);
BIC0 = -2*LLH0;

% Akaike information criterion
AIC  = -2*LLH+length(pa)*2;
AIC0 = -2*LLH0;

if option == 1
    output = [pseudoR2 LLH LLH0 BIC BIC0 AIC AIC0 taille];
else
    output = objectif;
end
end