function [D,E,R,EXTRA]=nt_cca_mm(x,y,ssize,ldaflag,nccs)
%[D,E,R,EXTRA]=nt_cca_mm(x,y,ssize,ldaflag,nccs) - match-mismatch test based on cross-validated CCA
%
%  D: mean(mismatch-match)/std(mismatch-match) over all segments of all trials
%  E: fraction of segments for which the mismatch distance is smaller than
%     the match distance
%  R: correlation between paired x and y CCs, averaged over trials (1 x nccs)
%  EXTRA: struct with fields
%     DD_match, DD_mismatch: per-segment combined distances (trials concatenated)
%     rms_eeg, rms_stim: per-segment RMS of the x (resp. y) CCs, averaged over
%        CCs, computed before z-scoring
%
%  x, y: cell arrays of trials (x{k} and y{k} must have the same number of rows)
%  ssize: segment size in samples [default: length of the shortest trial]
%  ldaflag: how per-CC distances are combined into a single score [default: 1]
%     0: first CC only
%     1: linear discriminant learned (via nt_dss0) from the other trials
%     2: mean over CCs
%     0<p<1: weighted mean over CCs with weights p.^(0:nccs-1)
%  nccs: number of CCs to keep [default: all]
%
% All trials are truncated to the length of the shortest trial, rounded down
% to a multiple of ssize, demeaned, and cut into segments of ssize samples.
% The CCA solution applied to each trial is obtained from the other trials
% (nt_cca_crossvalidate). For each segment, the "match" distance is the RMS
% difference between its x and y CCs; the "mismatch" distance is the mean RMS
% difference to the CCs of the other segments of the same trial.
%
% Note: with a single segment per trial (the default when ssize is not
% given) there are no other segments, so mismatch distances are NaN.

do_zscore=1; % z-score each segment's CCs before computing distances

if nargin<2; error('nt_cca_mm: x and y are required'); end
if nargin<3; ssize=[]; end
if nargin<4||isempty(ldaflag); ldaflag=1; end
if nargin<5; nccs=[]; end

if ~iscell(x) || ~iscell(y) || numel(x)~=numel(y)
    error('nt_cca_mm: x and y must be cell arrays with the same number of trials');
end
if isempty(x); error('nt_cca_mm: no trials'); end
if ~isempty(ssize) && (~isscalar(ssize) || ssize<1 || ssize~=round(ssize))
    error('nt_cca_mm: ssize must be a positive integer');
end

% A fractional ldaflag encodes the decay factor of a weighted mean over CCs
% (internally mode 3). Resolve and validate it once, before any heavy work.
p=[];
if isscalar(ldaflag) && ldaflag>0 && ldaflag<1
    p=ldaflag;
    ldaflag=3;
elseif ~isscalar(ldaflag) || ~any(ldaflag==[0 1 2])
    error('nt_cca_mm: ldaflag must be 0, 1, 2, or a value between 0 and 1');
end

% clip all trials to the same length, a multiple of ssize
ntrials=numel(x);
nsamples=inf;
for iTrial=1:ntrials
    % x and y may differ in number of columns, but not in number of rows
    if size(x{iTrial},1) ~= size(y{iTrial},1)
        error('nt_cca_mm: x{%d} and y{%d} differ in number of rows', iTrial, iTrial);
    end
    nsamples=min(nsamples,size(x{iTrial},1));
end
if nsamples<1; error('nt_cca_mm: empty trial'); end
if isempty(ssize); ssize=nsamples; end
nsamples=ssize*floor(nsamples/ssize);
if nsamples<1; error('nt_cca_mm: ssize is larger than the shortest trial'); end
for iTrial=1:ntrials
    x{iTrial}=nt_demean(x{iTrial}(1:nsamples,:));
    y{iTrial}=nt_demean(y{iTrial}(1:nsamples,:));
end
nsegments=nsamples/ssize;
if nsegments<2
    warning('nt_cca_mm: only one segment per trial, mismatch distances are undefined (NaN)');
end

% cross-validated CCA: AA{k}, BB{k} are computed from trials other than k
shifts=0;
[AA,BB,~]=nt_cca_crossvalidate(x,y,shifts);

if isempty(nccs)
    nccs=size(AA{1},2);
else
    nccs=min(nccs,size(AA{1},2));
    for iTrial=1:ntrials
        AA{iTrial}=AA{iTrial}(:,1:nccs);
        BB{iTrial}=BB{iTrial}(:,1:nccs);
    end
end

DD_match=[];
DD_mismatch=[];
RR=[];
rms_eeg=[];
rms_stim=[];

for iTrial=1:ntrials

    % The CCA solution (AA, BB) was calculated on the basis of other trials.
    % We apply it to segments of this trial.
    cc_x=nt_mmat(x{iTrial},AA{iTrial});
    cc_y=nt_mmat(y{iTrial},BB{iTrial});

    S_x=segment_stack(cc_x,ssize,nsegments); % ssize x nccs x nsegments
    S_y=segment_stack(cc_y,ssize,nsegments);

    % per-segment RMS averaged over CCs, before z-scoring
    rms_eeg=[rms_eeg; segment_rms(S_x)];
    rms_stim=[rms_stim; segment_rms(S_y)];

    if do_zscore
        S_x=zscore_segments(S_x);
        S_y=zscore_segments(S_y);
    end

    % distances per segment and CC (nsegments x nccs); mismatch compares
    % x of each segment with y of the other segments
    [D_match,D_mismatch]=match_mismatch(S_x,S_y);

    if ldaflag==1
        % We want to transform distance scores (one per CC) using LDA.
        % To get the LDA matrix to apply to this trial, we apply this
        % trial's CCA solution (itself computed from the other trials) to
        % the other trials, and calculate the LDA solution from that.
        other_trials=setdiff(1:ntrials,iTrial);
        cc_x2=nt_mmat(x(other_trials),AA{iTrial});
        cc_y2=nt_mmat(y(other_trials),BB{iTrial});

        % stack segments of all other trials: trial index varies fastest
        S_x2=zeros(ssize,nccs,ntrials-1,nsegments);
        S_y2=zeros(ssize,nccs,ntrials-1,nsegments);
        for iTrial2=1:ntrials-1
            S_x2(:,:,iTrial2,:)=reshape(segment_stack(cc_x2{iTrial2},ssize,nsegments),[ssize,nccs,1,nsegments]);
            S_y2(:,:,iTrial2,:)=reshape(segment_stack(cc_y2{iTrial2},ssize,nsegments),[ssize,nccs,1,nsegments]);
        end
        S_x2=S_x2(:,:,:);
        S_y2=S_y2(:,:,:);
        if do_zscore
            S_x2=zscore_segments(S_x2);
            S_y2=zscore_segments(S_y2);
        end

        % here mismatch compares y of each segment with x of the other segments
        [D_match2,D_mismatch2]=match_mismatch(S_y2,S_x2);

        % discriminate match from mismatch distances
        c0=nt_cov(D_match2)/size(D_match2,1);
        c1=nt_cov(D_mismatch2)/size(D_mismatch2,1);
        [todss,~,~]=nt_dss0(c0,c1);
        % fix the sign so that projected match distances are non-negative on average
        if mean(D_match2*todss(:,1), 1)<0; todss=-todss; end
        lda_xform=todss;
    end

    % combine per-CC distances into one score per segment
    switch ldaflag
        case 0
            D_match=D_match(:,1);
            D_mismatch=D_mismatch(:,1);
        case 1
            D_match=D_match*lda_xform(:,1);
            D_mismatch=D_mismatch*lda_xform(:,1);
        case 2
            D_match=mean(D_match,2);
            D_mismatch=mean(D_mismatch,2);
        case 3
            pp=p.^(0:size(D_match,2)-1);
            D_match=mean(bsxfun(@times,D_match,pp),2);
            D_mismatch=mean(bsxfun(@times,D_mismatch,pp),2);
        otherwise
            error('!');
    end

    DD_match=[DD_match; D_match(:)];
    DD_mismatch=[DD_mismatch; D_mismatch(:)];

    RR(iTrial,:)=diag(corr(cc_x,cc_y));
end

D=mean(DD_mismatch-DD_match, 1)/std(DD_mismatch-DD_match);
E=mean(DD_mismatch-DD_match < 0, 1);
R=mean(RR, 1);
EXTRA.DD_mismatch=DD_mismatch;
EXTRA.DD_match=DD_match;
EXTRA.rms_eeg=rms_eeg;
EXTRA.rms_stim=rms_stim;


function S=segment_stack(cc,ssize,nsegments)
% Cut the first ssize*nsegments rows of cc into consecutive segments:
% returns an ssize x size(cc,2) x nsegments array.
S=reshape(cc(1:ssize*nsegments,:),ssize,nsegments,size(cc,2));
S=permute(S,[1 3 2]);


function r=segment_rms(S)
% RMS over samples of each column in each segment, averaged over columns.
% S is ssize x ncols x nsegments; returns an nsegments x 1 column.
r=sqrt(mean(S.^2,1));                      % 1 x ncols x nsegments
r=mean(reshape(r,size(S,2),size(S,3)),1);  % explicit dim: correct for ncols==1
r=r(:);


function S=zscore_segments(S)
% Demean and normalize the columns of each segment (3rd dimension) independently.
for iSegment=1:size(S,3)
    S(:,:,iSegment)=nt_normcol(nt_demean(S(:,:,iSegment)));
end


function [D_match,D_mismatch]=match_mismatch(S_probe,S_target)
% Per-segment, per-column RMS distances between two ssize x ncols x nseg stacks.
%  D_match(i,c): RMS of S_probe(:,c,i)-S_target(:,c,i)
%  D_mismatch(i,c): mean over j~=i of RMS of S_target(:,c,j)-S_probe(:,c,i)
% Both outputs are nseg x ncols (D_mismatch is NaN when nseg==1).
ncols=size(S_probe,2);
nseg=size(S_probe,3);
D_match=sqrt(mean((S_probe-S_target).^2, 1));  % 1 x ncols x nseg
D_match=reshape(D_match,ncols,nseg)';
D_mismatch=zeros(nseg,nseg-1,ncols);
for iSegment=1:nseg
    other_segments=setdiff(1:nseg,iSegment);
    tmp=bsxfun(@minus,S_target(:,:,other_segments),S_probe(:,:,iSegment));
    d=sqrt(mean(tmp.^2, 1));                     % 1 x ncols x (nseg-1)
    D_mismatch(iSegment,:,:)=permute(d,[1 3 2]);
end
D_mismatch=permute(mean(D_mismatch,2),[1 3 2]);