function [A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag)
%[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation analysis
%
%  A, B: transform matrices for x and y (one column per canonical component)
%  R: score of each component (eigenvalue of the whitened joint covariance, minus 1)
%
%  x, y: data matrices (rows are samples, columns are channels), or cell
%        arrays holding one such matrix per trial
%  shifts: lags, forwarded to nt_cov_lags [default: 0]
%  C: covariance matrix of [x,y]; may be 3-D, in which case each page is
%     analysed independently and A, B, R gain a trailing page dimension
%  m: number of columns of x within C
%  thresh: principal components of x (resp. y) whose eigenvalue, relative to
%          the largest one, is not above thresh are discarded [default: 1e-12]
%  demeanflag: if true, remove the means before computing the covariance
%              [default: 1]
%
%  Usage 1 (data):        [A,B,R]=nt_cca(x,y);
%  Usage 2 (data, lags):  [A,B,R]=nt_cca(x,y,shifts);
%  Usage 3 (covariance):  C=[x,y]'*[x,y]; [A,B,R]=nt_cca([],[],[],C,size(x,2));
%
%  In usage 3 the covariance is taken as given: no mean removal is applied
%  here, and x, y, shifts must be empty.
%
%  If x and y differ in number of rows, both are truncated to the shorter one.
%
%  Called with data and no output argument, and with more than one shift,
%  the function plots R as a function of lag in figure 1.

nt_greetings;

% Optional arguments may be omitted or passed as [] to request the default.
if nargin<7 || isempty(demeanflag); demeanflag=1; end
if nargin<6 || isempty(thresh); thresh=10.^-12; end

%% ---- Data form: build the covariance, then recurse on the covariance form
if nargin>=1 && ~isempty(x)

    if nargin<2; error('!'); end
    if nargin<3 || isempty(shifts); shifts=0; end

    % Truncate x and y to a common number of rows. For cell arrays this is
    % done trial by trial, so that every mismatched trial is trimmed.
    if iscell(x)
        for iTrial=1:numel(x)
            if size(x{iTrial},1) ~= size(y{iTrial},1)
                nRows=min(size(x{iTrial},1),size(y{iTrial},1));
                x{iTrial}=x{iTrial}(1:nRows,:,:);
                y{iTrial}=y{iTrial}(1:nRows,:,:);
            end
        end
    else
        nRows=min(size(x,1),size(y,1));
        x=x(1:nRows,:,:);
        y=y(1:nRows,:,:);
    end

    if numel(shifts)==1 && shifts==0 && isnumeric(x) && ndims(x)==2
        % Simple case (single 2-D matrix, no lag): compute covariance directly.
        if demeanflag
            x=nt_demean(x);
            y=nt_demean(y);
        end
        xy=[x,y];
        C=xy'*xy;
        m=size(x,2);
    else
        % General case (lags, 3-D data, or cell arrays): delegate.
        [C,~,m]=nt_cov_lags(x,y,shifts,demeanflag);
    end

    [A,B,R]=nt_cca([],[],[],C,m,thresh);

    if nargout==0
        % No output requested: show scores as a function of lag.
        if length(shifts)>1
            figure(1); clf;
            plot(R'); title('correlation for each CC'); xlabel('lag'); ylabel('correlation');
        end
    end
    return
end

%% ---- Covariance form: x, y, shifts must be empty; C and m are required
if nargin<5; error('!'); end
if ~isempty(x) || ~isempty(y) || ~isempty(shifts); error('!'); end
if isempty(C); error('!'); end
if size(C,1)~=size(C,2); error('!'); end
if ndims(C)>3; error('!'); end

if ndims(C)==3
    % One CCA per page. Outputs are zero-padded because thresholding may
    % keep fewer components on some pages than on others.
    nPages=size(C,3);
    N=min(m,size(C,1)-m);
    A=zeros(m,N,nPages);
    B=zeros(size(C,1)-m,N,nPages);
    R=zeros(N,nPages);
    for k=1:nPages
        % Forward thresh so that the caller's setting applies to every page.
        [AA,BB,RR]=nt_cca([],[],[],C(:,:,k),m,thresh);
        A(1:size(AA,1),1:size(AA,2),k)=AA;
        B(1:size(BB,1),1:size(BB,2),k)=BB;
        R(1:numel(RR),k)=RR;
    end
    return
end

%% ---- Core algorithm (2-D covariance)

% Exponent marginally below 1 applied to the eigenvalues before whitening.
% NOTE(review): presumably a tiny perturbation to break the degeneracy that
% arises when x and y are perfectly correlated - confirm against upstream.
EXP=1-10^-12;

% Whiten x: PCA of its covariance, drop weak PCs, scale the rest to unit variance.
Cx=C(1:m,1:m);
[V,S]=eig(Cx);
V=real(V); S=real(S);
[E,idx]=sort(diag(S)','descend');
keep=find(E/max(E)>thresh);
topcs=V(:,idx(keep));
E=E(keep);
E=E.^EXP;
A1=topcs*diag(sqrt(1./E));

% Whiten y in the same way.
Cy=C(m+1:end,m+1:end);
[V,S]=eig(Cy);
V=real(V); S=real(S);
[E,idx]=sort(diag(S)','descend');
keep=find(E/max(E)>thresh);
topcs=V(:,idx(keep));
E=E(keep);
E=E.^EXP;
A2=topcs*diag(sqrt(1./E));

% Apply both whitening matrices to the joint covariance (block-diagonal transform).
AA=blkdiag(A1,A2);
C=AA'*C*AA;

N=min(size(A1,2),size(A2,2)); % number of canonical components

% PCA of the whitened joint covariance; leading PCs are the canonical pairs.
[V,S]=eig(C);
V=real(V); S=real(S);
[E,idx]=sort(diag(S)','descend');
topcs=V(:,idx);

% Split each joint eigenvector into its x part and y part. The sqrt(2)
% factor compensates for the unit norm of the joint eigenvector being shared
% between the two parts.
A=A1*topcs(1:size(A1,2),1:N)*sqrt(2);
B=A2*topcs(size(A1,2)+1:end,1:N)*sqrt(2);
R=E(1:N)-1; % eigenvalues are 1 for uncorrelated dimensions, so subtract 1
0155
0156
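% Why does it work?
% After whitening, the x block and the y block of the joint covariance are
% each the identity. If x and y were uncorrelated, the whole matrix would be
% the identity and its eigenvalues E would all be equal to one.
% Correlated dimensions (the canonical correlates) instead give values E>1,
% i.e. they map to the first PCs of the whitened joint covariance.
% To obtain the CCs we therefore just select the first N PCs.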
0162
0163
0164
0165
0166
0167
if 0
    % Demo (disabled): y shares columns 1:2 with x exactly, columns 3:4 with
    % unit-variance noise added, columns 5:6 with noise scaled by 3, and
    % columns 7:8 are independent noise.
    clear
    x=randn(10000,20);
    y=randn(10000,8);
    y(:,1:2)=x(:,1:2);
    y(:,3:4)=x(:,3:4)+randn(10000,2);
    y(:,5:6)=x(:,5:6)+randn(10000,2)*3;
    y(:,7:8)=randn(10000,2);
    [A,B,R]=nt_cca(x,y);

    % Project the data onto the canonical components once, then inspect.
    xa=x*A;
    yb=y*B;
    joint=[xa,yb];

    figure(1); clf
    subplot(3,2,1); imagesc(A); title('A');
    subplot(3,2,2); imagesc(B); title('B');
    subplot(3,2,3); plot(R, '.-'); title('R')
    subplot(3,2,4); nt_imagescc(xa'*xa); title ('covariance of x*A');
    subplot(3,2,5); nt_imagescc(yb'*yb); title ('covariance of y*B');
    subplot(3,2,6); nt_imagescc(joint'*joint); title ('covariance of [x*A,y*B]');
end
0186
if 0
    % Demo (disabled): compare nt_cca with canoncorr on random data.
    clear
    x=randn(1000,11);
    y=randn(1000,9);
    % Remove column means explicitly before both analyses.
    x=bsxfun(@minus,x,mean(x));
    y=bsxfun(@minus,y,mean(y));
    [A1,B1,R1]=canoncorr(x,y);
    [A2,B2,R2]=nt_cca(x,y);
    % Rescale the nt_cca transforms by sqrt(number of rows);
    % presumably this matches the normalization used by canoncorr - confirm.
    A2=A2*sqrt(size(x,1));
    B2=B2*sqrt(size(y,1));

    labels={'canoncorr', 'nt_cca'};

    figure(1); clf;
    subplot(2,1,1);
    plot([R1' R2']); title('R'); legend(labels, 'Interpreter','none');
    % Component signs are arbitrary: flip nt_cca's if they disagree.
    if mean(A1(:,1).*A2(:,1))<0
        A2=-A2;
    end
    subplot(2,1,2);
    plot(([x*A1(:,1),x*A2(:,1)])); title('first component'); legend(labels, 'Interpreter','none');

    figure(2); clf;set(gcf,'defaulttextinterpreter','none')
    subplot(1,2,1);
    z1=[x*A1,y*B1];
    nt_imagescc(z1'*z1); title('canoncorr');
    subplot(1,2,2);
    z2=[x*A2,y*B2];
    nt_imagescc(z2'*z2); title('nt_cca');
end
0210
if 0
    % Demo (disabled): compare the run time of nt_cca and canoncorr.
    % Each call gets its own timer, so the second figure is the time of
    % canoncorr alone rather than the cumulative time of both calls.
    x=randn(100000,100);

    tStart=tic;
    [A,B,R]=nt_cca(x,x);
    disp('nt_cca time: ');
    toc(tStart)

    tStart=tic;
    [A,B,R]=canoncorr(x,x);
    disp('canoncorr time: ');
    toc(tStart)
end
0225
if 0
    % Demo (disabled): CCA over a range of lags; y shares its first three
    % columns with x.
    x=randn(1000,10);
    y=randn(1000,10);
    y(:,1:3)=x(:,1:3);
    shifts=-10:10;
    [A1,B1,R1]=nt_cca(x,y,shifts);
    % One curve per canonical component, plotted against lag.
    figure(1); clf
    plot(shifts,R1');
    xlabel('lag');
    ylabel('R');
end
0236
if 0
    % Demo (disabled): y is x with its columns randomly permuted.
    x=randn(1000,10);
    y=x(:,randperm(10));
    [A1,B1,R1]=nt_cca(x,y);
    % Covariance of the concatenated canonical components.
    z=[x*A1,y*B1];
    figure(1); clf
    nt_imagescc(z'*z);
end
0245
if 0
    % Demo (disabled): smoke test with cell-array (multi-trial) input.
    x=randn(1000,10);
    y=randn(1000,10);
    xx={x,x,x};
    yy={x,y,y};
    [A,B,R]=nt_cca(xx,yy);
    disp('seems to work...');
end
0254
0255
0256