Home > NoiseTools > nt_cca_crossvalidate.m

nt_cca_crossvalidate

PURPOSE ^

[AA,BB,RR,iBest]=nt_cca_crossvalidate(xx,yy,shifts,ncomp,A0,B0) - CCA with cross-validation

SYNOPSIS ^

function [AA,BB,RR,iBest]=nt_cca_crossvalidate(xx,yy,shifts,ncomp,A0,B0)

DESCRIPTION ^

[AA,BB,RR,iBest]=nt_cca_crossvalidate(xx,yy,shifts,ncomp,A0,B0) - CCA with cross-validation

  AA, BB: cell arrays of transform matrices, one cell per trial
  RR: r scores (3D: components x shifts x trials)
  iBest: index into shifts of the shift with the highest mean r score

  xx,yy: cell arrays of column matrices
  shifts: array of shifts to apply to y relative to x (can be negative)
  ncomp: number of components to consider for iBest [default: all]
  A0,B0: if present, apply these CCA transform matrices to every trial instead of computing leave-one-out CCAs

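  Plot correlation vs. shift for matching trials
    plot(shifts, mean(RR,3)');
  Plot mean correlation vs. shift for mismatched trials
    plot(shifts, mean(mean(RR,4),3)');
  (Note: RR as computed here is 3D and holds matched-trial scores only, so
  mean(RR,4) equals RR and this second plot is the same as the first.)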

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [AA,BB,RR,iBest]=nt_cca_crossvalidate(xx,yy,shifts,ncomp,A0,B0)
0002 %[AA,BB,RR,iBest]=nt_cca_crossvalidate(xx,yy,shifts,ncomp,A0,B0) - CCA with cross-validation
0003 %
0004 %  AA, BB: cell arrays of transform matrices, one cell per trial
0005 %  RR: r scores (3D: components x shifts x trials)
0006 %  iBest: index into shifts of the shift with the highest mean r score
0007 %
0008 %  xx,yy: cell arrays of column matrices (one cell per trial)
0009 %  shifts: array of shifts to apply to y relative to x (can be negative)
0010 %  ncomp: number of components to consider for iBest [default: all]
0011 %  A0,B0: if present, apply these CCA transform matrices to every trial
0012 %         instead of computing leave-one-out CCAs (give both or neither)
0013 %
0014 %  Each pair xx{k},yy{k} is first trimmed to the shorter of its two lengths.
0015 %  Without A0,B0 the transform applied to trial k is trained on all the
0016 %  other trials, so at least 2 trials are needed.
0017 %
0018 %  Plot correlation vs. shift for matching trials
0019 %    plot(shifts, mean(RR,3)');
0020 %  Plot mean correlation vs. shift for mismatched trials
0021 %    plot(shifts, mean(mean(RR,4),3)');
0022 %  (Note: RR as computed here is 3D and holds matched-trial scores only,
0023 %  so mean(RR,4) equals RR and this second plot is the same as the first.)
0024
0025 if nargin<2; error('nt_cca_crossvalidate: xx and yy are required'); end
0026 if nargin<3 || isempty(shifts); shifts=0; end
0027 if nargin<4; ncomp=[]; end
0028 if nargin<5; A0=[]; end
0029 if nargin<6; B0=[]; end  % previously left undefined when only A0 was passed
0030 if ~iscell(xx) || ~iscell(yy); error('nt_cca_crossvalidate: xx and yy must be cell arrays'); end
0031 if numel(xx) ~= numel(yy); error('nt_cca_crossvalidate: xx and yy must hold the same number of trials'); end
0032 if isempty(xx); error('nt_cca_crossvalidate: no trials'); end
0033 if isempty(A0) ~= isempty(B0); error('nt_cca_crossvalidate: A0 and B0 must be given together'); end
0034
0035 nTrials=numel(xx);
0036
0037 % Trim EVERY trial pair to a common number of samples, not only when the
0038 % first pair happens to differ (later mismatches would otherwise slip through).
0039 for iTrial=1:nTrials
0040     nSamples=min(size(xx{iTrial},1),size(yy{iTrial},1));
0041     xx{iTrial}=xx{iTrial}(1:nSamples,:);
0042     yy{iTrial}=yy{iTrial}(1:nSamples,:);
0043 end
0044
0045 AA=cell(1,nTrials);
0046 BB=cell(1,nTrials);
0047 if isempty(A0)
0048     % leave-one-out training needs at least one other trial to train on
0049     if nTrials<2; error('nt_cca_crossvalidate: leave-one-out CCA needs at least 2 trials'); end
0050
0051     % covariance matrices of [x,y] for each shift and trial
0052     nx=size(xx{1},2);
0053     n=nx+size(yy{1},2);
0054     C=zeros(n,n,numel(shifts),nTrials);
0055     disp('Calculate all covariances...'); tic;
0056     for iTrial=1:nTrials
0057         C(:,:,:,iTrial)=nt_cov_lags(xx{iTrial},yy{iTrial},shifts);
0058     end
0059     toc;
0060
0061     % leave-one-out CCAs: the transform applied to a trial is trained on the others
0062     disp('Calculate CCAs...'); tic;
0063     for iTrial=1:nTrials
0064         CC=sum(C(:,:,:,setdiff(1:nTrials,iTrial)),4); % covariance of all trials except iTrial
0065         [AA{iTrial},BB{iTrial},~]=nt_cca([],[],[],CC,nx);
0066     end
0067     clear C CC
0068     toc;
0069 else
0070     % same user-supplied transforms for every trial
0071     AA(:)={A0};
0072     BB(:)={B0};
0073 end
0074
0075 % correlation of each CC pair, for each shift and trial
0076 disp('Calculate cross-correlations...'); tic;
0077 for iShift=1:numel(shifts)
0078     for iTrial=1:nTrials
0079         % shift, trim to same length, convert to CCs, normalize
0080         [x,y]=nt_relshift(xx{iTrial},yy{iTrial},shifts(iShift));
0081         x=nt_normcol(nt_demean(nt_mmat(x,AA{iTrial}(:,:,iShift))));
0082         y=nt_normcol(nt_demean(nt_mmat(y,BB{iTrial}(:,:,iShift))));
0083         RR(:,iShift,iTrial)=diag(x'*y)/size(x,1);
0084     end
0085 end
0086 toc;
0087
0088 % best shift: highest r averaged over trials and over the first ncomp components
0089 if isempty(ncomp); ncomp=size(RR,1); end
0090 ncomp=min(ncomp,size(RR,1)); % cannot use more components than RR holds
0091 [~,iBest]=max(mean(mean(RR(1:ncomp,:,:),3),1));
0092
0093 disp('done');
0094

Generated on Sat 29-Apr-2023 17:15:46 by m2html © 2005