function [IDX,TODSS,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N)
%[IDX,TODSS,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster samples with joint diagonalization
%
%  IDX: cell array, IDX{k} holds the row indices of x assigned to cluster k
%  TODSS: cell array, TODSS{k} = DSS matrix from nt_dss0 contrasting the
%         covariance of cluster k with that of all rows
%  SCORE: matrix, SCORE(k,:) = per-component power ratio pwr1./pwr0 for cluster k
%  COVS: cell array, COVS{k} = nt_cov of cluster k divided by its number of rows
%
%  x: data matrix, rows are samples (must be 2D with more than one column;
%     columns presumably channels)
%  dsr: factor by which the cross-product time series is coarser than x
%  smooth: smoothing window passed to nt_smooth [default: 1]
%  flags: string or cell array of strings [default: none]
%     'norm'  - normalize the cross-product rows with nt_normrow
%     'norm2' - normalize each cross product by its diagonal terms (norm2)
%     'pwr'   - refine clusters on DSS cross products as they are
%     'amp'   - refine clusters on the square root of the DSS cross products
%     (neither 'pwr' nor 'amp': refine on log2 of the cross products)
%  init: initial row indices of cluster 1 [default: derived from the data]
%  verbose: print cluster statistics and plot [default: 0]
%  depth: recursion depth; depth>1 splits clusters recursively [default: 1]
%  N: maximum number of clusters returned when depth>1 [default: 2^depth]
%
%  depth==1: x is split into IDX{1} (reported as the "low amp" cluster) and
%  its complement IDX{2}. If IDX{1} comes out empty or covers every row, a
%  single cluster holding all rows is returned (TODSS{1}=nan, SCORE=nan).
%  With no output arguments (or verbose) the result is plotted.
%
%  depth>1: the two depth-1 clusters are split recursively (clusters with at
%  most 2*dsr rows are kept whole), then the two closest clusters according
%  to covdists are merged until at most N remain.

if nargin<2; error('!'); end
if nargin<3||isempty(smooth); smooth=1; end
if nargin<4||isempty(flags); flags=[]; end
if nargin<5; init=[]; end
if nargin<6||isempty(verbose); verbose=0; end
if nargin<7||isempty(depth); depth=1; end
if nargin<8||isempty(N); N=2^depth; end

if ndims(x)>2 || size(x,2)==1
    error('should be 2D matrix');
end

if depth>1
    % Top-level split. init refers to rows of x, so it is used here only.
    I=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,1);

    % Split each half recursively and map the subset-relative indices
    % returned by the recursive calls back to rows of x.
    IDX={};
    for c=1:numel(I)
        rows=I{c};
        if numel(rows)>2*dsr
            % init must not be forwarded: its indices do not refer to x(rows,:)
            sub=nt_cluster_jd(x(rows,:),dsr,smooth,flags,[],verbose,depth-1);
        else
            sub={(1:numel(rows))'};   % too few rows to split further
        end
        for k=1:numel(sub)
            IDX{end+1}=rows(sub{k}); %#ok<AGROW>
        end
    end
    checkindex(IDX,size(x,1));

    % Merge the two closest clusters until at most N remain.
    while numel(IDX)>N
        covs=cell(1,numel(IDX));
        for k=1:numel(IDX)
            covs{k}=cov_norm(x(IDX{k},:));
        end
        B=covdists(covs);
        [~,idx]=min(B(:));            % min ignores the NaN diagonal
        [k1,k2]=ind2sub(size(B),idx);
        IDX{k1}=[IDX{k1}(:);IDX{k2}(:)];
        IDX(k2)=[];
        checkindex(IDX,size(x,1));
    end

    if nargout>1
        % Fresh containers: exactly one entry per final cluster.
        c0=cov_norm(x);
        TODSS=cell(1,numel(IDX));
        COVS=cell(1,numel(IDX));
        SCORE=[];
        for k=1:numel(IDX)
            COVS{k}=cov_norm(x(IDX{k},:));
            [TODSS{k},pwr0,pwr1]=nt_dss0(c0,COVS{k});
            SCORE(k,1:numel(pwr1))=pwr1./pwr0; %#ok<AGROW>
        end
    end
    return
end

% Calculate the time series of cross products (terms of the covariance
% matrix). This time series has a coarser temporal resolution than x, by a
% factor dsr.
[xx,ind]=nt_xprod(x,'lower',dsr);

if any(strcmp(flags,'norm'))
    xx=nt_normrow(xx);
end
if any(strcmp(flags,'norm2'))
    xx=norm2(xx,size(x,2),ind);
end

xx=nt_smooth(xx,smooth,[],1);

% Cluster each column of the time series of cross products, choose the
% column with the best score (reduction in energy), and use its cluster
% index to initialize the first JD analysis.
if isempty(init)
    [~,A,score]=nt_cluster1D(xx);
    [~,idx]=min(score);
    A=upsample_labels(A(:,idx),dsr,size(x,1));
    IDX{1}=find(A==0);
else
    IDX{1}=init;
end

if isempty(IDX{1})
    [IDX,TODSS,SCORE,COVS]=single_cluster(x);
    return
end

c0=nt_cov(x);
c1=nt_cov(x(IDX{1},:));
[todss,~,~]=nt_dss0(c0,c1);
z=nt_mmat(x,todss(:,[1 end]));    % first and last DSS components

PLOT_FIG2=0;    % debugging switch
if PLOT_FIG2
    figure(2); clf; set(gcf, 'name','in nt_cluster_jd');
    A=zeros(size(x,1),1); A(IDX{1})=1;
    subplot 511; plot(x); title('data');
    subplot 512; plot(A,'.-'); title('initial cluster map');
    subplot 513; plot(z(:,1)); title('initial DSS1');
    subplot 514; plot(z(:,2)); title('initial DSS2');
    drawnow; pause;
end

% Refine: re-cluster using cross products of the first and last DSS
% components, recompute the DSS from the new cluster, and stop when the
% cluster no longer changes (at most 10 iterations).
old_IDX=IDX{1};
for iter=1:10
    zz=nt_xprod(z,[],dsr);
    zz=zz(:,1:2);

    if any(strcmp(flags,'pwr'))
        [C,A]=nt_cluster1D(zz);
        [~,idx]=max(abs(diff(log2(C+eps))));
    elseif any(strcmp(flags,'amp'))
        [C,A]=nt_cluster1D(sqrt(zz));
        [~,idx]=max(abs(diff(log2(C+eps))));
    else
        [C,A]=nt_cluster1D(log2(zz+eps));
        [~,idx]=max(abs(diff(C)));
    end
    % keep the column whose two C values differ most
    A=A(:,idx);
    C=C(:,idx);

    if C(1)<C(2); A=1-A; end

    % NOTE(review): if nt_smooth is a moving average of length smooth, this
    % sets a block to 1 when any label within the window is 1 - confirm.
    A=double(nt_smooth(A,smooth,[],1)>=1/smooth);

    A=upsample_labels(A,dsr,size(x,1));
    IDX{1}=find(A==0);

    if isempty(IDX{1})
        [IDX,TODSS,SCORE,COVS]=single_cluster(x);
        return
    end

    c0=cov_norm(x);
    c1=cov_norm(x(IDX{1},:));
    [todss,pwr0,pwr1]=nt_dss0(c0,c1);
    z=nt_mmat(x,todss(:,[1 end]));

    if ~nargout||verbose
        disp(['low amp cluster: ', num2str((100*numel(IDX{1})/size(x,1)), 2), ' % of samples, power ratio: ' num2str(pwr1(end)/pwr0(end), 3)]);
        disp(['hi amp cluster: ', num2str((100-100*numel(IDX{1})/size(x,1)), 2), ' % of samples, power ratio: ' num2str(pwr1(1)/pwr0(1), 3)]);
    end

    if PLOT_FIG2
        figure(2);
        subplot 515; plot(A,'.-'); title('final cluster map'); pause
    end
    if isequal(old_IDX,IDX{1}); break; end
    old_IDX=IDX{1};
end

IDX{2}=setdiff((1:size(x,1))', IDX{1});
if isempty(IDX{2})
    % cluster 1 covers every row: there is no second cluster to contrast
    [IDX,TODSS,SCORE,COVS]=single_cluster(x);
    return
end

% Final DSS of each cluster against all rows. pwr0/pwr1 keep the values of
% cluster 2, which the plot below uses.
c0=cov_norm(x);
for k=1:2
    COVS{k}=cov_norm(x(IDX{k},:));
    [TODSS{k},pwr0,pwr1]=nt_dss0(c0,COVS{k});
    SCORE(k,1:numel(pwr1))=pwr1./pwr0;
end

if nargout==0||verbose
    z1=nt_mmat(x,TODSS{1}(:,1));

    figure(101); clf ;
    subplot 221;
    plot(pwr1./pwr0,'.-'); xlabel('component'); ylabel('score'); title('DSS cluster vs all');
    subplot 222;
    wsize=min(1024,size(z1,1));
    hold on
    nt_spect_plot(z1/sqrt(mean(z1(:).^2)),wsize,[],[],1);
    nt_spect_plot(x/sqrt(mean(x(:).^2)),wsize,[],[],1);
    xlim([0 .5]);
    nt_linecolors([],[1 3 2]);
    legend('cluster','all'); legend boxoff
    hold off

    z=nt_mmat(x,todss);
    z=nt_normcol(z);
    subplot 223; nt_imagescc(nt_cov(z(IDX{1},:))); title('cluster 1');
    subplot 224; nt_imagescc(nt_cov(z)-nt_cov(z(IDX{1},:))); title('cluster 2');

    figure(102); clf
    subplot 311;
    plot(x); hold on
    xx=x; xx(IDX{1},:)=nan;
    plot(xx,'k');
    axis tight
    title('black: cluster [high amp]');
    subplot 312;
    plot(z1); axis tight
    title('DSS 1');
    subplot 313;
    nt_sgram(z1,128,1); axis tight
    title('DSS 1');

    if nargout==0; clear IDX SCORE TODSS; end
end


function c=cov_norm(x)
% nt_cov of x divided by the number of rows of x.
c=nt_cov(x)/size(x,1);


function A=upsample_labels(A,dsr,n)
% Expand block labels to sample resolution: repeat each label dsr times,
% then pad with the last label (or truncate) to exactly n samples.
A=repmat(A(:)',[dsr,1]);
A=A(:);
A(end:n)=A(end);
A=A(1:n);


function [IDX,TODSS,SCORE,COVS]=single_cluster(x)
% Degenerate result: one cluster holding every row of x (column of indices),
% with outputs shaped like the regular two-cluster result.
IDX={(1:size(x,1))'};
TODSS={nan};
SCORE=nan;
COVS={cov_norm(x)};
0296
0297
function y=norm2(x,nchans,ind)
% Normalize each cross-product column by the geometric mean of two
% per-channel terms.
%  x: cross-product time series, one column per entry of ind; columns
%     1..nchans of x are used as the per-channel terms (NOTE(review): assumes
%     the producer puts the diagonal products first - confirm)
%  nchans: number of channels of the original data
%  ind: linear indices into an nchans x nchans matrix, one per column of x
%  y: same size as x, y(:,m) = x(:,m)./sqrt(x(:,I(m)).*x(:,J(m)))
%     where [I,J]=ind2sub([nchans,nchans],ind)
% Vectorized over rows, so an input with zero rows yields an empty result
% instead of an unassigned output.
[I,J]=ind2sub([nchans,nchans],ind);
y=x./sqrt(x(:,I).*x(:,J));
0305
0306
function B=covdists(C)
% Pairwise dissimilarity between the matrices of cell array C.
%  B(j,k) = max(abs(log2(lambda))), where lambda are the generalized
%  eigenvalues of abs(C{j}-C{k}) (elementwise abs) relative to the sum of
%  all matrices in C. B is symmetric with NaN on the diagonal.
nC=numel(C);
total=zeros(size(C{1}));
for m=1:nC
    total=total+C{m};
end
B=nan(nC);
for col=2:nC
    for row=1:col-1
        lambda=eig(abs(C{row}-C{col}),total);
        d=max(abs(log2(lambda)));
        B(row,col)=d;
        B(col,row)=d;
    end
end
0318
function checkindex(IDX,n)
% Error out if any index appears in more than one cell of IDX.
%  IDX: cell array of index vectors into 1..n
%  n: size of the index space
% Repeats inside a single cell are not checked.
seen=false(n,1);
for c=1:numel(IDX)
    rows=IDX{c};
    if any(seen(rows)); error('!'); end
    seen(rows)=true;
end
0327
0328