Home > NoiseTools > nt_cca.m

nt_cca

PURPOSE ^

[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation

SYNOPSIS ^

function [A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag)

DESCRIPTION ^

[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation

  A, B: transform matrices
  R: r scores

  x,y: column matrices
  shifts: positive lag means y delayed relative to x
  C: covariance matrix of [x, y]
  m: number of columns of x
  thresh: discard PCs whose eigenvalue is below thresh times the largest eigenvalue [default: 10^-12]
  demeanflag: if true remove means [default: true]

  Usage 1:
   [A,B,R]=nt_cca(x,y); % CCA of x, y

  Usage 2: 
   [A,B,R]=nt_cca(x,y,shifts); % CCA of x, y for each value of shifts.
   A positive shift indicates that y is delayed relative to x.

  Usage 3:
   C=[x,y]'*[x,y]; % covariance
   [A,B,R]=nt_cca([],[],[],C,size(x,2)); % CCA of x,y

 Use the third form to handle multiple files or large data
 (covariance C can be calculated chunk-by-chunk). 

 C can be 3-D, in which case CCA is derived independently from each page.

 Note: means of x and y are removed unless demeanflag is false (in usage 3, C is used as given).
 Warning: A, B scaled so that (x*A)'*(x*A) and (y*B)'*(y*B) are identity matrices (differs from canoncorr).

 See nt_cov_lags, nt_relshift, nt_cov, nt_pca.

 NoiseTools.

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag)
0002 %[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation
0003 %
0004 %  A, B: transform matrices
0005 %  R: r scores (canonical correlations)
0006 %
0007 %  x,y: column matrices (or cell arrays of column matrices, one per trial)
0008 %  shifts: positive lag means y delayed relative to x
0009 %  C: covariance matrix of [x, y]
0010 %  m: number of columns of x
0011 %  thresh: discard PCs with eigenvalue below thresh*(largest eigenvalue) [default: 10^-12]
0012 %  demeanflag: if true remove means [default: true]
0013 %
0014 %  Usage 1:
0015 %   [A,B,R]=nt_cca(x,y); % CCA of x, y
0016 %
0017 %  Usage 2:
0018 %   [A,B,R]=nt_cca(x,y,shifts); % CCA of x, y for each value of shifts.
0019 %   A positive shift indicates that y is delayed relative to x.
0020 %
0021 %  Usage 3:
0022 %   C=[x,y]'*[x,y]; % covariance
0023 %   [A,B,R]=nt_cca([],[],[],C,size(x,2)); % CCA of x,y
0024 %
0025 % Use the third form to handle multiple files or large data
0026 % (covariance C can be calculated chunk-by-chunk).
0027 %
0028 % C can be 3-D, in which case CCA is derived independently from each page.
0029 %
0030 % Note: means of x and y are removed unless demeanflag is false (in usage 3, C is used as given).
0031 % Warning: A, B scaled so that (x*A)'*(x*A) and (y*B)'*(y*B) are identity matrices (differs from canoncorr).
0032 %
0033 % See nt_cov_lags, nt_relshift, nt_cov, nt_pca.
0034 %
0035 % NoiseTools.
0036 
0037 nt_greetings; 
0038 
0039 if nargin<7||isempty(demeanflag); demeanflag=1; end % default: remove means
0040 
0041 if nargin<6||isempty(thresh) % default relative eigenvalue threshold
0042     thresh=10.^-12; 
0043 end
0044 
0045 if exist('x','var') && ~isempty(x) % usages 1 & 2: compute C from the data
0046     
0047     % check parameters
0048     if ~exist('y','var'); error('!'); end
0049     if ~exist('shifts','var')||isempty(shifts); shifts=0; end
0050     
0051     % trim x and y to same number of samples (trial by trial for cell arrays)
0052     if iscell(x) % check every trial, not only the first
0053         for iTrial=1:numel(x)
0054             tmp=min(size(x{iTrial},1),size(y{iTrial},1));
0055             x{iTrial}=x{iTrial}(1:tmp,:);
0056             y{iTrial}=y{iTrial}(1:tmp,:);
0057         end
0058     else
0059         tmp=min(size(x,1),size(y,1));
0060         x=x(1:tmp,:,:);
0061         y=y(1:tmp,:,:);
0062     end
0063 
0064     % Calculate covariance of [x,y]
0065     if numel(shifts)==1 && shifts==0 && isnumeric(x) && ndims(x)==2; 
0066         if demeanflag
0067             x=nt_demean(x);
0068             y=nt_demean(y);
0069         end
0070         C=[x,y]'*[x,y]; % simple case
0071         m=size(x,2); 
0072     else        
0073         [C,~,m]=nt_cov_lags(x,y,shifts,demeanflag); % lags, multiple trials, etc.
0074     end
0075         
0076     [A,B,R]=nt_cca([],[],[],C,m,thresh);
0077     
0078     if nargout==0 
0079         % plot something nice
0080         if length(shifts)>1;
0081             figure(1); clf;
0082             plot(R'); title('correlation for each CC'); xlabel('lag'); ylabel('correlation');
0083         end
0084      end
0085     return
0086 end % else keep going
0087 
0088 % check that we are called as nt_cca([],[],[],C,m,thresh)
0089 if nargin<5 || ~isempty(x) || ~isempty(y) || ~isempty(shifts); error('!'); end
0090 if ~exist('C','var') || isempty(C) ; error('!'); end
0091 if ~exist('m','var'); error('!'); end
0092 if size(C,1)~=size(C,2); error('!'); end
0093 if ndims(C)>3; error('!'); end
0094 
0095 if ndims(C) == 3
0096     % covariance is 3D: do a separate CCA for each page
0097     N=min(m,size(C,1)-m); % note that for some pages there may be fewer than N CCs
0098     A=zeros(m,N,size(C,3));
0099     B=zeros(size(C,1)-m,N,size(C,3));
0100     R=zeros(N,size(C,3));
0101     for k=1:size(C,3);
0102         [AA,BB,RR]=nt_cca([],[],[],C(:,:,k),m,thresh); % same threshold for every page
0103         A(1:size(AA,1),1:size(AA,2),k)=AA;
0104         B(1:size(BB,1),1:size(BB,2),k)=BB;
0105         R(1:size(RR,2),k)=RR;
0106     end
0107     return;
0108 end % else keep going
0109 
0110 
0111 %%
0112 % Calculate CCA given C=[x,y]'*[x,y] and m=size(x,2);
0113 
0114 % sphere x
0115 Cx=C(1:m,1:m);
0116 [V, S] = eig(Cx) ;  
0117 V=real(V); S=real(S);
0118 [E, idx] = sort(diag(S)', 'descend') ;
0119 keep=find(E/max(E)>thresh); % drop dimensions with negligible relative variance
0120 topcs = V(:,idx(keep));
0121 E = E (keep);
0122 EXP=1-10^-12; 
0123 E=E.^EXP; % break symmetry when x and y perfectly correlated (otherwise cols of x*A and y*B are not orthogonal)
0124 A1=topcs*diag(sqrt((1./E))); % A1'*Cx*A1 is (approximately) identity
0125 
0126 % sphere y
0127 Cy=C(m+1:end,m+1:end);
0128 [V, S] = eig(Cy) ;  
0129 V=real(V); S=real(S);
0130 [E, idx] = sort(diag(S)', 'descend') ;
0131 keep=find(E/max(E)>thresh);
0132 topcs = V(:,idx(keep));
0133 E = E (keep);
0134 E=E.^EXP; % same symmetry-breaking tweak as for x
0135 A2=topcs*diag(sqrt((1./E)));
0136 
0137 % apply sphering matrices to C
0138 AA=zeros( size(A1,1)+size(A2,1), size(A1,2)+size(A2,2) );
0139 AA( 1:size(A1,1), 1:size(A1,2) )=A1;
0140 AA( size(A1,1)+1:end, size(A1,2)+1:end )=A2;
0141 C= AA' * C * AA; % now [I K; K' I], K = cross-correlation of sphered x and y
0142 
0143 N=min(size(A1,2),size(A2,2)); % number of canonical components
0144 
0145 % PCA
0146 [V, S] = eig(C) ;
0147 %[V, S] = eigs(C,N) ; % not faster
0148 V=real(V); S=real(S);
0149 [E, idx] = sort(diag(S)', 'descend') ;
0150 topcs = V(:,idx);
0151 
0152 A=A1*topcs(1:size(A1,2),1:N)*sqrt(2);  % eigenvectors of [I K; K' I] are [u;v]/sqrt(2): restore unit norm
0153 B=A2*topcs(size(A1,2)+1:end,1:N)*sqrt(2);
0154 R=E(1:N)-1; % eigenvalues are 1+r, r = canonical correlation
0155 
0156 %{
0157 Why does it work?
0158 If x and y were uncorrelated, eigenvalues E would be all ones. 
0159 Correlated dimensions (the canonical correlates) should give values E>1, 
0160 i.e. they should map to the first PCs. 
0161 To obtain CCs we just select the first N PCs. 
0162 %}
0163 
0164 %%
0165 
0166 %%
0167 % test code
0168 if 0
0169     % basic
0170     clear
0171     x=randn(10000,20);
0172     y=randn(10000,8);
0173     y(:,1:2)=x(:,1:2); % perfectly correlated
0174     y(:,3:4)=x(:,3:4)+randn(10000,2); % 1/2 correlated
0175     y(:,5:6)=x(:,5:6)+randn(10000,2)*3; % 1/4 correlated
0176     y(:,7:8)=randn(10000,2); % uncorrelated
0177     [A,B,R]=nt_cca(x,y);
0178     figure(1); clf
0179     subplot 321; imagesc(A); title('A');
0180     subplot 322; imagesc(B); title('B');
0181     subplot 323; plot(R, '.-'); title('R')
0182     subplot 324; nt_imagescc((x*A)'*(x*A)); title ('covariance of x*A');
0183     subplot 325; nt_imagescc((y*B)'*(y*B)); title ('covariance of y*B');
0184     subplot 326; nt_imagescc([x*A,y*B]'*[x*A,y*B]); title ('covariance of [x*A,y*B]');
0185 end
0186 
0187 if 0 
0188     % compare with canoncorr
0189     clear
0190     x=randn(1000,11);
0191     y=randn(1000,9);
0192     x=x-repmat(mean(x),size(x,1),1); % center, otherwise result may differ slightly from canoncorr
0193     y=y-repmat(mean(y),size(y,1),1);
0194     [A1,B1,R1]=canoncorr(x,y);
0195     [A2,B2,R2]=nt_cca(x,y);   
0196     A2=A2*sqrt(size(x,1)); % scale like canoncorr
0197     B2=B2*sqrt(size(y,1));
0198     figure(1); clf; 
0199     subplot 211; 
0200     plot([R1' R2']); title('R'); legend({'canoncorr', 'nt_cca'}, 'Interpreter','none'); 
0201     if mean(A1(:,1).*A2(:,1))<0; A2=-A2; end
0202     subplot 212; 
0203     plot(([x*A1(:,1),x*A2(:,1)])); title('first component'); legend({'canoncorr', 'nt_cca'}, 'Interpreter','none'); 
0204     figure(2); clf;set(gcf,'defaulttextinterpreter','none')
0205     subplot 121; 
0206     nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]); title('canoncorr'); 
0207     subplot 122; 
0208     nt_imagescc([x*A2,y*B2]'*[x*A2,y*B2]); title('nt_cca');
0209 end
0210 
0211 if 0
0212     % time
0213     x=randn(100000,100); 
0214     tic; 
0215     [A,B,R]=nt_cca(x,x); 
0216     disp('nt_cca time: ');
0217     toc    
0218     tic; [A,B,R]=canoncorr(x,x); % restart timer so nt_cca time is not included
0219     disp('canoncorr time: ');
0220     toc
0221 %     [A,B,R]=cca(x,x);
0222 %     disp('cca time: ');
0223 %     toc
0224 end
0225 
0226 if 0
0227     % shifts
0228     x=randn(1000,10);
0229     y=randn(1000,10);
0230     y(:,1:3)=x(:,1:3);
0231     shifts=-10:10;
0232     [A1,B1,R1]=nt_cca(x,y,shifts);
0233     figure(1); clf
0234     plot(shifts,R1'); xlabel('lag'); ylabel('R');
0235 end
0236 
0237 if 0
0238     % what happens if x & y perfectly correlated?
0239     x=randn(1000,10);
0240     y=randn(1000,10); y=x(:,randperm(10)); %+0.000001*y;
0241     [A1,B1,R1]=nt_cca(x,y);
0242     figure(1); clf
0243     nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]);
0244 end    
0245 
0246 if 0
0247     % x and y are cell arrays
0248     x=randn(1000,10); 
0249     y=randn(1000,10);
0250     xx={x,x,x};  yy={x,y,y};
0251     [A,B,R]=nt_cca(xx,yy);
0252     disp('seems to work...');
0253 end
0254 
0255     
0256

Generated on Sat 29-Apr-2023 17:15:46 by m2html © 2005