function [A,B,R]=nt_cca(x,y,lags,C,m,thresh)
%[A,B,R]=nt_cca(x,y,lags,C,m,thresh) - canonical correlation analysis
%
%  A, B: transform matrices to apply to x and y, respectively
%  R: score of each canonical component pair (eigenvalue of the whitened
%     joint covariance, minus one)
%
%  x, y: data; for numeric matrices, columns are the variables
%  lags: lag(s) forwarded to nt_cov_lags [default: 0]
%        (see nt_cov_lags for the sign convention)
%  C: covariance matrix of [x, y]; may be 3-D, one page per lag
%  m: number of columns of x, i.e. C(1:m,1:m) is the covariance of x
%  thresh: principal components of x (resp. y) whose eigenvalue, relative
%        to the largest one, does not exceed thresh are discarded
%        [default: 10^-12]
%
%  Usage 1:
%    [A,B,R]=nt_cca(x,y);             % CCA of x and y
%
%  Usage 2:
%    [A,B,R]=nt_cca(x,y,lags);        % CCA of x and y for each lag
%
%  Usage 3:
%    C=[x,y]'*[x,y];
%    [A,B,R]=nt_cca([],[],[],C,size(x,2));
%
%  When C is 3-D (e.g. several lags), the analysis is run independently on
%  each page: A and B are 3-D (zero-padded where fewer components were
%  retained) and R has one column per page.
%
%  Called with several lags and no output argument, R is plotted in figure 1.
%
%  Notes:
%  - For a single zero lag and a numeric 2-D x, C is computed here as
%    [x,y]'*[x,y]: means are NOT removed and C is not divided by the
%    number of samples. Any other input is delegated to nt_cov_lags.
%  - A and B are scaled relative to C so that A'*Cx*A and B'*Cy*B are
%    intended to be (close to) identity; this differs from the scaling
%    used by canoncorr.

nt_greetings;

% An omitted or empty threshold falls back to the default.
if nargin<6 || isempty(thresh)
    thresh=10.^-12;
end

if nargin>=1 && ~isempty(x)
    % Data were supplied: build the covariance of [x,y], then call
    % nt_cca again in covariance mode.
    if nargin<2
        error('nt_cca:missingY','nt_cca: y must be provided together with x');
    end
    if nargin<3 || isempty(lags)
        lags=0;
    end
    if numel(lags)==1 && lags==0 && isnumeric(x) && ndims(x)==2
        % Fast path: plain matrices, no lag.
        C=[x,y]'*[x,y];
        m=size(x,2);
    else
        % Lagged and/or non-matrix data are handled by nt_cov_lags.
        [C,~,m]=nt_cov_lags(x,y,lags);
    end
    [A,B,R]=nt_cca([],[],[],C,m,thresh);

    if nargout==0 && length(lags)>1
        % No output requested: show how the scores vary across lags.
        figure(1); clf;
        plot(R'); title('correlation for each CC'); xlabel('lag'); ylabel('correlation');
    end
    return
end

% From here on we work from the covariance matrix only.
if nargin<4 || isempty(C)
    error('nt_cca:missingC','nt_cca: covariance matrix C is required when x is empty');
end
if nargin<5
    error('nt_cca:missingM','nt_cca: m (number of columns of x) is required together with C');
end
if size(C,1)~=size(C,2)
    error('nt_cca:badC','nt_cca: each page of C must be square');
end
if ~isempty(y) || ~isempty(lags)
    error('nt_cca:badArgs','nt_cca: x, y and lags must all be empty when C is provided');
end
if ndims(C)>3
    error('nt_cca:badC','nt_cca: C must be 2-D or 3-D');
end

if ndims(C)==3
    % One covariance matrix per page: solve each page independently.
    % Outputs are preallocated with zeros, so pages that retain fewer
    % components are zero-padded.
    nPages=size(C,3);
    N=min(m,size(C,1)-m);
    A=zeros(m,N,nPages);
    B=zeros(size(C,1)-m,N,nPages);
    R=zeros(N,nPages);
    for k=1:nPages
        % thresh is forwarded so that every page honours the caller's value
        [AA,BB,RR]=nt_cca([],[],[],C(:,:,k),m,thresh);
        A(1:size(AA,1),1:size(AA,2),k)=AA;
        B(1:size(BB,1),1:size(BB,2),k)=BB;
        R(1:numel(RR),k)=RR(:);
    end
    return
end

%%
% Core computation on a single 2-D covariance matrix.

% Exponent marginally below 1 applied to the eigenvalues before whitening.
% NOTE(review): presumably this breaks ties when x and y are perfectly
% correlated - confirm against the original documentation.
EXP=1-10^-12;

% Sphere x: PCA of its covariance, drop weak PCs, normalize the others.
Cx=C(1:m,1:m);
[V, S] = eig(Cx);
V=real(V); S=real(S);
[E, idx] = sort(diag(S)', 'descend');
keep=find(E/max(E)>thresh);
topcs = V(:,idx(keep));
E = E(keep);
E=E.^EXP;
A1=topcs*diag(sqrt((1./E)));

% Sphere y in the same way.
Cy=C(m+1:end,m+1:end);
[V, S] = eig(Cy);
V=real(V); S=real(S);
[E, idx] = sort(diag(S)', 'descend');
keep=find(E/max(E)>thresh);
topcs = V(:,idx(keep));
E = E(keep);
E=E.^EXP;
A2=topcs*diag(sqrt((1./E)));

% Apply both sphering matrices to the joint covariance (block-diagonal
% transform: A1 acts on the x part, A2 on the y part).
AA=zeros( size(A1,1)+size(A2,1), size(A1,2)+size(A2,2) );
AA( 1:size(A1,1), 1:size(A1,2) )=A1;
AA( size(A1,1)+1:end, size(A1,2)+1:end )=A2;
C= AA' * C * AA;

N=min(size(A1,2),size(A2,2)); % number of canonical components

% PCA of the sphered joint covariance.
[V, S] = eig(C);
V=real(V); S=real(S);
[E, idx] = sort(diag(S)', 'descend');
topcs = V(:,idx);

% The first N eigenvectors hold the x weights in their upper part and the
% y weights in their lower part. The sqrt(2) factor compensates for the
% eigenvectors being unit-norm over the concatenated [x,y] space.
A=A1*topcs(1:size(A1,2),1:N)*sqrt(2);
B=A2*topcs(size(A1,2)+1:end,1:N)*sqrt(2);
R=E(1:N)-1; % eigenvalues are 1 for uncorrelated dimensions, hence the offset
0130
0131
0132
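% Why does it work?
% x and y are first whitened separately, then a PCA is applied to the
% covariance of the concatenated whitened data, giving eigenvalues E.
% If x and y were uncorrelated, the eigenvalues E would all be equal to one.
% Correlated dimensions (the canonical correlates) give eigenvalues E>1,
% i.e. they map to the first principal components (PCs).
% To obtain the canonical components (CCs) we just select the first N PCs.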
0138
0139
0140
0141
0142
0143
if 0
    % Basic test: y shares some columns with x, with increasing amounts of
    % added noise, then inspect the transforms and the resulting covariances.
    clear
    nsamples=10000;
    x=randn(nsamples,20);
    y=randn(nsamples,8);
    y(:,1:2)=x(:,1:2);                        % exact copies of x columns
    y(:,3:4)=x(:,3:4)+randn(nsamples,2);      % x plus unit-variance noise
    y(:,5:6)=x(:,5:6)+3*randn(nsamples,2);    % x plus noise three times larger
    y(:,7:8)=randn(nsamples,2);               % unrelated to x
    [A,B,R]=nt_cca(x,y);
    figure(1); clf
    subplot(3,2,1); imagesc(A); title('A');
    subplot(3,2,2); imagesc(B); title('B');
    subplot(3,2,3); plot(R, '.-'); title('R')
    subplot(3,2,4); nt_imagescc((x*A)'*(x*A)); title ('covariance of x*A');
    subplot(3,2,5); nt_imagescc((y*B)'*(y*B)); title ('covariance of y*B');
    subplot(3,2,6); nt_imagescc([x*A,y*B]'*[x*A,y*B]); title ('covariance of [x*A,y*B]');
end
0162
if 0
    % Compare with canoncorr on zero-mean random data.
    clear
    x=randn(1000,11);
    y=randn(1000,9);
    % remove the column means of both data sets
    x=bsxfun(@minus,x,mean(x));
    y=bsxfun(@minus,y,mean(y));
    [A1,B1,R1]=canoncorr(x,y);
    [A2,B2,R2]=nt_cca(x,y);
    % rescale the nt_cca transforms by sqrt(number of samples) before comparing
    A2=sqrt(size(x,1))*A2;
    B2=sqrt(size(y,1))*B2;
    names={'canoncorr', 'nt_cca'};
    figure(1); clf;
    subplot(2,1,1);
    plot([R1' R2']); title('R'); legend(names, 'Interpreter','none');
    % flip A2 if its first column has the opposite polarity to that of A1
    if mean(A1(:,1).*A2(:,1))<0
        A2=-A2;
    end
    subplot(2,1,2);
    plot(([x*A1(:,1),x*A2(:,1)])); title('first component'); legend(names, 'Interpreter','none');
    figure(2); clf; set(gcf,'defaulttextinterpreter','none')
    subplot(1,2,1);
    nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]); title('canoncorr');
    subplot(1,2,2);
    nt_imagescc([x*A2,y*B2]'*[x*A2,y*B2]); title('nt_cca');
end
0186
if 0
    % Compare computation time with canoncorr on the same data.
    x=randn(100000,100);
    tic;
    [A,B,R]=nt_cca(x,x);
    disp('nt_cca time: ');
    toc
    % Restart the stopwatch: without this second tic, toc would report the
    % time elapsed since the first tic, i.e. nt_cca plus canoncorr.
    tic;
    [A,B,R]=canoncorr(x,x);
    disp('canoncorr time: ');
    toc
end
0201
if 0
    % CCA over a range of lags: the first three columns of y are copies of
    % the corresponding columns of x.
    x=randn(1000,10);
    y=randn(1000,10);
    y(:,1:3)=x(:,1:3);
    lags=-10:10;
    [A1,B1,R1]=nt_cca(x,y,lags);
    % one curve per canonical component, as a function of lag
    figure(1); clf
    plot(lags,R1');
    xlabel('lag');
    ylabel('R');
end
0212
if 0
    % y is x with its columns in random order (so x and y span the same
    % space); display the covariance of the transformed data.
    x=randn(1000,10);
    y=x(:,randperm(10));
    [A1,B1,R1]=nt_cca(x,y);
    figure(1); clf
    nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]);
end
0221
if 0
    % Check that cell-array inputs (several data matrices per argument)
    % run through without error.
    x=randn(1000,10);
    y=randn(1000,10);
    xx={x,x,x};
    yy={x,y,y};
    [A,B,R]=nt_cca(xx,yy);
    disp('seems to work...');
end
0230
0231
0232