Home > NoiseTools > nt_cca_crossvalidate.m

nt_cca_crossvalidate

PURPOSE ^

[AA,BB,RR,SD]=nt_cca_crossvalidate(xx,yy,shifts,doSurrogate) - CCA with cross-validation

SYNOPSIS ^

function [AA,BB,RR,SD]=nt_cca_crossvalidate(xx,yy,shifts,doSurrogate)

DESCRIPTION ^

[AA,BB,RR,SD]=nt_cca_crossvalidate(xx,yy,shifts,doSurrogate) - CCA with cross-validation

  AA, BB: cell arrays of transform matrices
  RR: r scores (3D: component x shift x trial)
  SD: standard deviation of correlation over non-matching pairs (2D)

  xx,yy: cell arrays of column matrices
  shifts: array of shifts to apply to y relative to x (can be negative)
  doSurrogate: if true estimate sd of correlation over non-matching pairs

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [AA,BB,RR,SD]=nt_cca_crossvalidate(xx,yy,shifts,doSurrogate)
%[AA,BB,RR,SD]=nt_cca_crossvalidate(xx,yy,shifts,doSurrogate) - CCA with cross-validation
%
%  AA, BB: cell arrays of transform matrices (one cell per left-out trial,
%          third dimension of each matrix indexes the shift)
%  RR: r scores (3D: component x shift x trial)
%  SD: standard deviation of correlation over non-matching pairs (2D: component x shift)
%
%  xx,yy: cell arrays of column matrices (one cell per trial, at least 2 trials)
%  shifts: array of shifts to apply to y relative to x (can be negative) [default: 0]
%  doSurrogate: if true estimate sd of correlation over non-matching pairs
%               (forced on when called with 0 or 4 output arguments)
%
%  For each trial, a CCA model is fit on the summed covariance of all the
%  other trials and then evaluated on the left-out trial.
%
%  If called with no output arguments, plots the trial-averaged correlations
%  together with SD (figure 1) and the first few component pairs of trial 1
%  (figure 2).

if nargin<4; doSurrogate=[]; end
if nargin<3; shifts=[0]; end
if nargin<2; error('!'); end
if ~iscell(xx) || ~iscell(yy); error('!'); end
if length(xx) ~= length (yy); error('!'); end
% leave-one-out needs at least one trial left to train on; with a single
% trial the training covariance would be all zeros and SD would be 0/0
if length(xx)<2; error('need at least 2 trials for leave-one-out cross-validation'); end
if size(xx{1},1) ~= size(yy{1},1); error('!'); end

% SD is needed both for the 4th output and for the plots
if nargout==0 || nargout==4; doSurrogate=1; end

%%
% calculate covariance matrices
nTrials=length(xx);
nx=size(xx{1},2);
n=nx+size(yy{1},2);
C=zeros(n,n,length(shifts),nTrials);
disp('Calculate all covariances...');
nt_whoss;
for iTrial=1:nTrials
    C(:,:,:,iTrial)=nt_cov_lags(xx{iTrial}, yy{iTrial},shifts);
end

%%
% calculate leave-one-out CCAs
disp('Calculate CCAs...');
AA=cell(1,nTrials);
BB=cell(1,nTrials);
for iOut=1:nTrials
    CC=sum(C(:,:,:,setdiff(1:nTrials,iOut)),4); % covariance of all trials except iOut
    [A,B]=nt_cca([],[],[],CC,nx);               % corresponding CCA
    AA{iOut}=A;
    BB{iOut}=B;
end
clear C CC

%%
% calculate leave-one-out correlation coefficients
disp('Calculate cross-correlations...');
for iOut=1:nTrials
    iNext=mod(iOut,nTrials)+1; % surrogate partner: next trial in list (wraps around)
    A=AA{iOut};
    B=BB{iOut};
    for iShift=1:length(shifts)
        [x,y]=nt_relshift(xx{iOut},yy{iOut},shifts(iShift));
        a=A(:,:,iShift);
        b=B(:,:,iShift);
        % diagonal = correlation between matching component pairs
        r(:,iShift)=diag( nt_normcol(x*a)' * nt_normcol(y*b )) / size(x,1);
    end
    RR(:,:,iOut)=r;
    if doSurrogate
        % same model, but x of this trial paired with y of a different trial
        for iShift=1:length(shifts)
            [x,y]=nt_relshift(xx{iOut},yy{iNext},shifts(iShift));
            a=A(:,:,iShift);
            b=B(:,:,iShift);
            mn=min(size(x,1),size(y,1)); % trials may differ in length: use common part
            s(:,iShift)=diag( nt_normcol(x(1:mn,:)*a)' * nt_normcol(y(1:mn,:)*b )) / mn;
        end
        ss(:,:,iOut)=s;
    end
end
if doSurrogate
    % standard deviation over trials (N-1 normalization); std() avoids the
    % cancellation of the sum-of-squares formula, which could yield a small
    % negative variance and hence a complex result
    SD=std(ss,0,3);
end
disp('done');

%%
% If no output arguments, plot something informative

if nargout==0
    figure(1); clf;
    if length(shifts)>1
        hR=plot(mean(RR,3)'); title('correlation for each CC'); xlabel('shift'); ylabel('correlation');
        hold on;
        hS=plot(SD', ':r');
        % pass handles so that the two entries label one line of each kind
        legend([hR(1) hS(1)],'correlation','standard error'); legend boxoff
    else
        plot(squeeze(mean(RR,3))); title ('correlation for each CC'); xlabel('CC'); ylabel('correlation');
        hold on; % keep the correlation curve when overlaying SD
        plot(SD', ':r');
    end
    figure(2); clf;
    % trial 1 is displayed, so use the model that was trained without trial 1
    A=AA{1};
    B=BB{1};
    for k=1:min(4,size(RR,1))
        subplot(2,2,k);
        [~,idx]=max(mean(RR(k,:,:),3)); % shift with best average correlation for this CC
        [x,y]=nt_relshift(xx{1},yy{1},shifts(idx));
        z=[x*A(:,k,idx), y*B(:,k,idx)];
        plot(z);
        disp(corr(nt_normcol(z)));
        title(['CC ',num2str(k)]); xlabel('sample');
    end
end

Generated on Mon 27-Feb-2017 15:36:07 by m2html © 2005