function [D,E,R,EXTRA]=nt_cca_mm(x,y,ssize,ldaflag,nccs)
%[D,E,R,EXTRA]=nt_cca_mm(x,y,ssize,ldaflag,nccs) - CCA match-mismatch metrics
%
%  D: mean(mismatch-match)/std(mismatch-match) over all segments
%  E: proportion of segments for which the mismatch distance is smaller
%     than the match distance (error rate)
%  R: correlation between paired CCs over whole trials, averaged over trials
%  EXTRA: struct with fields
%    .DD_mismatch: mismatch distance per segment (trials concatenated)
%    .DD_match: match distance per segment (trials concatenated)
%    .rms_eeg: RMS of x-side CCs per segment, averaged over CCs
%    .rms_stim: RMS of y-side CCs per segment, averaged over CCs
%
%  x,y: cell arrays of trials (samples x channels); x{k} and y{k} must have
%       the same number of samples (their column counts may differ)
%  ssize: segment size in samples [default: length of shortest trial]
%  ldaflag: how per-CC distances are combined into one score [default: 1]
%       0: first CC only
%       1: LDA, computed for each trial from the other trials
%       2: mean over CCs
%       p with 0<p<1: mean over CCs weighted by p.^(0:nccs-1)
%  nccs: number of CCs to keep [default: all]
%
%  Two passes are made. The first pass scores every segment. Segments
%  whose (mismatch - match) margin falls below the PROP quantile of all
%  margins are removed from the data used to train the CCA, and the second
%  pass (whose results are returned) applies that CCA to all segments.
%
%  Mismatch distances need at least 2 segments per trial; with a single
%  segment per trial (e.g. the default ssize) they are NaN.
%
%  Uses nt_demean, nt_normcol, nt_mmat, nt_cov, nt_dss0,
%  nt_cca_crossvalidate, and corr.

DOZSCORE=1;  % z-score each segment (per CC) before computing distances
PROP=0.1;    % quantile of first-pass margins below which segments are excluded

if nargin<2; error('!'); end
if nargin<3; ssize=[]; end
if nargin<4||isempty(ldaflag); ldaflag=1; end
if nargin<5; nccs=[]; end

if ~iscell(x) || ~iscell(y) || numel(x)~=numel(y); error('!'); end
ntrials=numel(x);

% x{k} and y{k} must be time-aligned; only their row counts need to agree
nsamples=size(x{1},1);
for iTrial=1:ntrials
    if size(x{iTrial},1) ~= size(y{iTrial},1); error('!'); end
    nsamples=min(nsamples,size(x{iTrial},1));
end

% segment size: whole (shortest) trial by default; must be a positive integer
if isempty(ssize); ssize=nsamples; end
if ~isscalar(ssize) || ssize<1 || ssize~=round(ssize); error('!'); end

% trim all trials to a whole number of segments of the common length
nsamples=ssize*floor(nsamples/ssize);
if nsamples<1; error('!'); end
nsegments=nsamples/ssize;
for iTrial=1:ntrials
    x{iTrial}=nt_demean(x{iTrial}(1:nsamples,:));
    y{iTrial}=nt_demean(y{iTrial}(1:nsamples,:));
end

% a fractional ldaflag p in (0,1) selects geometrically weighted CCs (mode 3)
p=[];
if ldaflag>0 && ldaflag<1
    p=ldaflag;
    ldaflag=3;
elseif ldaflag==3
    error('ldaflag: for weighted CCs pass the weight p (0<p<1), not 3');
elseif ~any(ldaflag==[0 1 2])
    error('!');
end

% first pass: CCA trained on all segments
[~,~,~,EXTRA1]=mm_pass(x,y,x,y,ssize,nsegments,ldaflag,p,nccs,DOZSCORE);

% threshold on the (mismatch - match) margin: PROP quantile over all segments
margin=EXTRA1.DD_mismatch-EXTRA1.DD_match;
sorted_margin=sort(margin);
thresh=sorted_margin(max(1,round(PROP*numel(sorted_margin))));
margin=reshape(margin,[nsegments,ntrials]);

% remove poorly-matching segments from the CCA training data
x2=cell(size(x));
y2=cell(size(y));
for iTrial=1:ntrials
    iBad=find(margin(:,iTrial)<thresh);
    disp(iBad);  % report excluded segment indices for this trial
    x2{iTrial}=drop_segments(x{iTrial},ssize,iBad);
    y2{iTrial}=drop_segments(y{iTrial},ssize,iBad);
end

% second pass: CCA trained on retained segments, evaluated on all segments
[D,E,R,EXTRA]=mm_pass(x2,y2,x,y,ssize,nsegments,ldaflag,p,nccs,DOZSCORE);
end


function [D,E,R,EXTRA]=mm_pass(xtrain,ytrain,x,y,ssize,nsegments,ldaflag,p,nccs,dozscore)
% One match-mismatch evaluation: cross-validated CCA fit on xtrain/ytrain,
% applied to the full segmented trials x/y (each nsegments*ssize samples).
ntrials=numel(x);

% AA{k}, BB{k}: CCA solution calculated on the basis of trials other than k
shifts=0;
[AA,BB,~]=nt_cca_crossvalidate(xtrain,ytrain,shifts);

if isempty(nccs)
    nccs=size(AA{1},2);
else
    nccs=min(nccs,size(AA{1},2));
    for iTrial=1:ntrials
        AA{iTrial}=AA{iTrial}(:,1:nccs);
        BB{iTrial}=BB{iTrial}(:,1:nccs);
    end
end

DD_match=zeros(nsegments,ntrials);
DD_mismatch=zeros(nsegments,ntrials);
rms_eeg=zeros(nsegments,ntrials);
rms_stim=zeros(nsegments,ntrials);
RR=zeros(ntrials,nccs);

for iTrial=1:ntrials
    % The CCA solution (AA, BB) was calculated on the basis of other
    % trials; apply it to segments of this trial.
    cc_x=nt_mmat(x{iTrial},AA{iTrial});
    cc_y=nt_mmat(y{iTrial},BB{iTrial});
    S_x=to_segments(cc_x,ssize,nsegments);  % ssize x nccs x nsegments
    S_y=to_segments(cc_y,ssize,nsegments);

    % per-segment RMS averaged over CCs, before normalization; the explicit
    % dim=1 keeps one value per segment even when nccs==1
    rms_eeg(:,iTrial)=mean(permute(sqrt(mean(S_x.^2,1)),[2 3 1]),1)';
    rms_stim(:,iTrial)=mean(permute(sqrt(mean(S_y.^2,1)),[2 3 1]),1)';

    if dozscore
        S_x=norm_segments(S_x);
        S_y=norm_segments(S_y);
    end

    % per-CC distances (nsegments x nccs): match pairs x_i with y_i,
    % mismatch averages x_i against every other y_j of this trial
    [D_match,D_mismatch]=seg_distances(S_x,S_y);

    if ldaflag==1
        % Transform the per-CC distance scores using LDA. To get the LDA
        % matrix for this trial, apply this trial's CCA solution (computed
        % without it) to the other trials and derive the LDA from those.
        lda_xform=lda_weights(x,y,iTrial,AA{iTrial},BB{iTrial},ssize,nsegments,nccs,dozscore);
    end

    switch ldaflag
        case 0 % first CC only
            D_match=D_match(:,1);
            D_mismatch=D_mismatch(:,1);
        case 1 % LDA projection
            D_match=D_match*lda_xform(:,1);
            D_mismatch=D_mismatch*lda_xform(:,1);
        case 2 % plain mean over CCs
            D_match=mean(D_match,2);
            D_mismatch=mean(D_mismatch,2);
        case 3 % mean over CCs weighted by p.^(0:nccs-1)
            pp=p.^(0:size(D_match,2)-1);
            D_match=mean(bsxfun(@times,D_match,pp),2);
            D_mismatch=mean(bsxfun(@times,D_mismatch,pp),2);
        otherwise
            error('!');
    end

    DD_match(:,iTrial)=D_match;
    DD_mismatch(:,iTrial)=D_mismatch;
    RR(iTrial,:)=diag(corr(cc_x,cc_y));
end

% one column, trials concatenated
DD_match=DD_match(:);
DD_mismatch=DD_mismatch(:);
delta=DD_mismatch-DD_match;
D=mean(delta,1)/std(delta);
E=mean(delta<0,1);
R=mean(RR,1);
EXTRA.DD_mismatch=DD_mismatch;
EXTRA.DD_match=DD_match;
EXTRA.rms_eeg=rms_eeg(:);
EXTRA.rms_stim=rms_stim(:);
end


function lda_xform=lda_weights(x,y,iTrial,A,B,ssize,nsegments,nccs,dozscore)
% Transform for combining per-CC distances, computed from all trials except
% iTrial using that trial's CCA solution (A,B).
ntrials=numel(x);
other_trials=[1:iTrial-1, iTrial+1:ntrials];
nother=numel(other_trials);
cc_x2=nt_mmat(x(other_trials),A);
cc_y2=nt_mmat(y(other_trials),B);

S_x2=zeros(ssize,nccs,nother,nsegments);
S_y2=zeros(ssize,nccs,nother,nsegments);
for k=1:nother
    sx=to_segments(cc_x2{k},ssize,nsegments);
    sy=to_segments(cc_y2{k},ssize,nsegments);
    if dozscore
        sx=norm_segments(sx);
        sy=norm_segments(sy);
    end
    S_x2(:,:,k,:)=reshape(sx,[ssize,nccs,1,nsegments]);
    S_y2(:,:,k,:)=reshape(sy,[ssize,nccs,1,nsegments]);
end
% pool the segments of all other trials: ssize x nccs x (nother*nsegments)
S_x2=S_x2(:,:,:);
S_y2=S_y2(:,:,:);

% here each y segment is compared against all other x segments (the reverse
% of the pairing used when scoring the held-out trial)
[D_match2,D_mismatch2]=seg_distances(S_y2,S_x2);

% covariances of match (c0) and mismatch (c1) distance vectors -> nt_dss0
c0=nt_cov(D_match2)/size(D_match2,1);
c1=nt_cov(D_mismatch2)/size(D_mismatch2,1);
[todss,~,~]=nt_dss0(c0,c1);
% orient so that the mean projected match distance is non-negative
if mean(D_match2*todss(:,1),1)<0; todss=-todss; end
lda_xform=todss;
end


function [Dm,Dmm]=seg_distances(S_a,S_b)
% Per-CC RMS distances between segments; S_a, S_b are ssize x nccs x N.
%  Dm(i,c):  distance between S_a(:,c,i) and S_b(:,c,i)
%  Dmm(i,c): mean over j~=i of the distance between S_a(:,c,i) and
%            S_b(:,c,j) (NaN when N==1)
N=size(S_a,3);
Dm=permute(sqrt(mean((S_a-S_b).^2,1)),[3 2 1]);  % N x nccs
Dmm=zeros(N,size(S_a,2));
for i=1:N
    others=[1:i-1, i+1:N];
    d=sqrt(mean(bsxfun(@minus,S_b(:,:,others),S_a(:,:,i)).^2,1)); % 1 x nccs x (N-1)
    Dmm(i,:)=mean(d,3);
end
end


function S=to_segments(z,ssize,nsegments)
% Cut the first ssize*nsegments rows of z into ssize x ncols x nsegments.
S=permute(reshape(z(1:ssize*nsegments,:),[ssize,nsegments,size(z,2)]),[1 3 2]);
end


function S=norm_segments(S)
% Demean and normalize the columns of each segment (slice along dim 3).
for k=1:size(S,3)
    S(:,:,k)=nt_normcol(nt_demean(S(:,:,k)));
end
end


function z=drop_segments(z,ssize,iBad)
% Remove segments iBad (consecutive blocks of ssize rows) from z.
nchans=size(z,2);
z=reshape(z,[ssize,size(z,1)/ssize,nchans]);
z(:,iBad,:)=[];
z=reshape(z,[size(z,1)*size(z,2),nchans]);
end