Home > NoiseTools > nt_cluster_jd.m

nt_cluster_jd

PURPOSE ^

[IDX,todss,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster with joint diagonalization

SYNOPSIS ^

function [IDX,TODSS,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose, depth,N)

DESCRIPTION ^

[IDX,todss,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster with joint diagonalization

  IDX: cluster ownership (IDX{1}: low amp, IDX{2}: high amp)
  TODSS: DSS matrix (1st column --> discriminating component)
  SCORE: score (smaller means better contrast)
  COVS: covariance for each cluster

  x: data (time*channels)
  dsr: downsample ratio for cross product series
  smooth: further smoothing of cross-product series
  flags: see below
  init: provide initial clustering
  verbose: display & plot (default=no)
  depth: cluster recursively into 2^depth clusters
  N: target number of clusters [default: 2^depth]

 Flags:
  'norm', 'norm2': give each slice the same weight
  'amp', 'pwr': cluster amplitude or power instead of log (default)
 See nt_bias_cluster, nt_cluster1D

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [IDX,TODSS,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N)
%[IDX,todss,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster with joint diagonalization
%
%  IDX: cluster ownership (IDX{1}: low amp, IDX{2}: high amp); each cell
%       holds a column vector of sample indices into x
%  TODSS: DSS matrix (1st column --> discriminating component)
%  SCORE: score (smaller means better contrast)
%  COVS: covariance for each cluster (divided by the cluster's sample count)
%
%  x: data (time*channels)
%  dsr: downsample ratio for cross product series
%  smooth: further smoothing of cross-product series [default: 1]
%  flags: see below
%  init: provide initial clustering (indices of samples of the low-amp cluster)
%  verbose: display & plot (default=no)
%  depth: cluster recursively into 2^depth clusters [default: 1]
%  N: target number of clusters [default: 2^depth]
%
% If clustering fails, a single cluster holding all samples is returned,
% with TODSS{1} and SCORE{1} set to NaN.
%
% Flags:
%  'norm', 'norm2': give each slice the same weight
%  'amp', 'pwr': cluster amplitude or power instead of log (default)
% See nt_bias_cluster, nt_cluster1D

if nargin<2; error('!'); end
if nargin<3||isempty(smooth); smooth=1; end
if nargin<4||isempty(flags); flags=[]; end
if nargin<5; init=[]; end
if nargin<6||isempty(verbose); verbose=0; end
if nargin<7||isempty(depth); depth=1; end
if nargin<8||isempty(N); N=2^depth; end

if ndims(x)>2 || size(x,2)==1
    error('should be 2D matrix');
end

if depth>1
    % split into 2 clusters
    I=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,1);

    % Recurse on each part. init indexes the full data, not the subsets,
    % so it must not be passed down.
    IDX={};
    for iPart=1:numel(I)
        if numel(I{iPart})>2*dsr
            Isub=nt_cluster_jd(x(I{iPart},:),dsr,smooth,flags,[],verbose,depth-1); % recurse
        else
            Isub={(1:numel(I{iPart}))}; % too small
        end
        % resolve the cluster indices: subset indices --> indices into x
        for k=1:numel(Isub)
            IDX{end+1}=I{iPart}(Isub{k}); %#ok<AGROW>
        end
    end
    checkindex(IDX,size(x,1))

    % merge the two most similar clusters until at most N remain
    while numel(IDX)>N
        covs=cell(1,numel(IDX));
        for k=1:numel(IDX)
            covs{k}=nt_cov(x(IDX{k},:))/numel(IDX{k});
        end
        B=covdists(covs);
        [~,idx]=min(B(:)); % min ignores the NaN diagonal
        [k1,k2]=ind2sub(size(B),idx);
        IDX{k1}=[IDX{k1};IDX{k2}];
        IDX(k2)=[];
        checkindex(IDX,size(x,1))
    end
    if nargout>1
        c0=nt_cov(x)/size(x,1);
        TODSS=cell(1,numel(IDX));
        COVS=cell(1,numel(IDX)); % fresh, so no stale entries survive from merging
        SCORE=[];
        for k=1:numel(IDX)
            c1=nt_cov(x(IDX{k},:))/numel(IDX{k});
            [TODSS{k},pwr0,pwr1]=nt_dss0(c0,c1);
            SCORE(k,1:numel(pwr1))=pwr1./pwr0; %#ok<AGROW>
            COVS{k}=c1;
        end
    end
    return
end

% Time series of cross products (terms of the covariance matrix). Its
% temporal resolution is coarser than that of x by a factor dsr.
[xx,ind]=nt_xprod(x,'lower',dsr);

% option: give each slice the same weight (counters amplitude variations)
if any(strcmp(flags,'norm'))
    xx=nt_normrow(xx);
end
if any(strcmp(flags,'norm2'))
    xx=norm2(xx,size(x,2),ind);
end

xx=nt_smooth(xx,smooth,[],1);

% Cluster each column of the cross-product time series, choose the column
% with the best score (reduction in energy), and use its cluster index to
% initialize the first JD analysis.
if isempty(init)
    [~,A,score]=nt_cluster1D(xx); % cluster all columns of cross products
    [~,idx]=min(score); % select column with best score (tightest clusters)
    A=upsample_index(A(:,idx),dsr,size(x,1));
    IDX{1}=find(A==0);
else
    IDX{1}=init;
end

if isempty(IDX{1}) % clustering failed, return just one cluster
    [IDX,TODSS,SCORE,COVS]=single_cluster(x);
    return
end

c0=nt_cov(x);
c1=nt_cov(x(IDX{1},:));
todss=nt_dss0(c0,c1);
z=nt_mmat(x,todss(:,[1 end])); % keep only first and last components

PLOT_FIG2=0; % debug toggle
if PLOT_FIG2
    figure(2);  clf; set(gcf, 'name','in nt_cluster_jd');
    A=zeros(size(x,1),1); A(IDX{1})=1;
    subplot 511; plot(x); title('data');
    subplot 512; plot(A,'.-'); title('initial cluster map');
    subplot 513; plot(z(:,1)); title('initial DSS1');
    subplot 514; plot(z(:,2)); title('initial DSS2');
    drawnow; pause;
end

% iterate until stable (at most 10 iterations)
old_IDX=IDX{1};
for iter=1:10
    [zz,~]=nt_xprod(z,[],dsr);
    zz=zz(:,1:2);       % keep only the squares

    if any(strcmp(flags,'pwr')) % cluster in power
        [C,A]=nt_cluster1D(zz);
        [~,idx]=max(abs(diff(log2(C+eps)))); % choose first or last
    elseif any(strcmp(flags,'amp')) % cluster in amplitude
        [C,A]=nt_cluster1D(sqrt(zz));
        [~,idx]=max(abs(diff(log2(C+eps)))); % choose first or last
    else % cluster in log domain
        [C,A]=nt_cluster1D(log2(zz+eps));
        [~,idx]=max(abs(diff(C))); % choose first or last
    end
    A=A(:,idx);
    C=C(:,idx);
    if C(1)<C(2); A=1-A; end % ensure that first cluster has low amplitude

    A=double(nt_smooth(A,smooth,[],1)>=1/smooth); % extend ownership to include effect of smoothing

    % upsample the cluster ownership index so we can apply it to x
    A=upsample_index(A,dsr,size(x,1));
    IDX{1}=find(A==0); % 0: low values, 1: high values

    if isempty(IDX{1}) % clustering failed, return just one cluster
        [IDX,TODSS,SCORE,COVS]=single_cluster(x);
        return
    end

    % DSS to contrast clusters
    c0=nt_cov(x)/size(x,1);
    c1=nt_cov(x(IDX{1},:))/numel(IDX{1});
    [todss,pwr0,pwr1]=nt_dss0(c0,c1);
    z=nt_mmat(x,todss(:,[1 end])); % keep first and last

    if ~nargout||verbose
        disp(['low amp cluster: ', num2str((100*numel(IDX{1})/size(x,1)), 2), ' % of samples, power ratio: ' num2str(pwr1(end)/pwr0(end), 3)]); 
        disp(['hi amp cluster: ', num2str((100-100*numel(IDX{1})/size(x,1)), 2), ' % of samples, power ratio: ' num2str(pwr1(1)/pwr0(1), 3)]); 
    end

    if PLOT_FIG2
        figure(2);  
        subplot 515; plot(A,'.-'); title('final cluster map'); pause
    end
    if isequal(old_IDX,IDX{1}); break; end
    old_IDX=IDX{1};
end
IDX{2}=setdiff((1:size(x,1))', IDX{1});

if isempty(IDX{2}) % every sample fell in the low cluster: clustering failed
    [IDX,TODSS,SCORE,COVS]=single_cluster(x);
    return
end

% final DSS, one per cluster (pwr0/pwr1 of cluster 2 remain for plotting)
c0=nt_cov(x)/size(x,1);
TODSS=cell(1,2);
COVS=cell(1,2);
SCORE=[];
for k=1:2
    c1=nt_cov(x(IDX{k},:))/numel(IDX{k});
    COVS{k}=c1;
    [TODSS{k},pwr0,pwr1]=nt_dss0(c0,c1);
    SCORE(k,1:numel(pwr1))=pwr1./pwr0; %#ok<AGROW>
end

if nargout==0||verbose
    % no output, just plot
    z1=nt_mmat(x,TODSS{1}(:,1));

    figure(101); clf ;
    subplot 221;
    plot(pwr1./pwr0,'.-'); xlabel('component'); ylabel('score'); title('DSS cluster vs all');
    subplot 222;
    wsize=min(1024,size(z1,1));
    hold on
    nt_spect_plot(z1/sqrt(mean(z1(:).^2)),wsize,[],[],1);
    nt_spect_plot(x/sqrt(mean(x(:).^2)),wsize,[],[],1);
    xlim([0 .5]);
    nt_linecolors([],[1 3 2]);
    legend('cluster','all'); legend boxoff
    hold off

    z=nt_mmat(x,todss); 
    z=nt_normcol(z);
    subplot 223; nt_imagescc(nt_cov(z(IDX{1},:))); title('cluster 1'); 
    subplot 224; nt_imagescc(nt_cov(z)-nt_cov(z(IDX{1},:))); title('cluster 2');

    figure(102); clf
    subplot 311;
    plot(x); hold on
    xx=x; xx(IDX{1},:)=nan;
    plot(xx,'k');
    axis tight
    title('black: cluster [high amp]');
    subplot 312;
    plot(z1); axis tight
    title('DSS 1');
    subplot 313;
    nt_sgram(z1,128,1); axis tight
    title('DSS 1');

    if nargout==0; clear IDX SCORE TODSS; end
end

function [IDX,TODSS,SCORE,COVS]=single_cluster(x)
% Outputs used when clustering fails: one cluster holding every sample (as
% a column, like all other IDX entries), NaN DSS/score, and the covariance
% normalized by the sample count like every other COVS entry.
IDX={(1:size(x,1))'};
TODSS={nan};
SCORE={nan};
COVS={nt_cov(x)/size(x,1)};

function A=upsample_index(A,dsr,n)
% Expand a cluster-ownership vector defined on the downsampled
% cross-product series (one value per dsr samples) to one value per sample
% of x (n samples). Each value is repeated dsr times, the last value is
% extended over any uncovered tail, and the result is truncated to exactly
% n entries so the indices derived from it stay within x.
A=repmat(A(:)',[dsr,1]);
A=A(:);
A(end:n)=A(end);
A=A(1:n);
% norm2: normalize cross products by the amplitudes of the channels involved
function y=norm2(x,nchans,ind)
%y=norm2(x,nchans,ind) - scale each cross product by its channels' squares
%
%  x: cross-product series (slices * pairs)
%  nchans: number of channels
%  ind: linear indices (into an nchans*nchans matrix) of the channel pair
%       represented by each column of x
%  y: x(:,m)./sqrt(s_i.*s_j), where (i,j) is the pair of column m and s_c is
%     x(:,c), i.e. the first nchans columns are taken as the squared terms
%     of channels 1..nchans. NOTE(review): this relies on nt_xprod placing
%     the squares first in channel order -- confirm.
%
% Vectorized over slices; an input with zero rows yields an empty output.
[I,J]=ind2sub([nchans,nchans],ind); % linear --> matrix indices
sq=x(:,1:nchans);                   % per-slice squared terms
y=x./sqrt(sq(:,I).*sq(:,J));
% covdists: pairwise distances between covariance matrices
function B=covdists(C)
%B=covdists(C) - matrix of pairwise covariance distances
%
%  C: cell array of covariance matrices, all of the same size
%  B: symmetric numel(C)*numel(C) matrix. B(j,k) is the largest absolute
%     log2 generalized eigenvalue of abs(C{j}-C{k}) relative to the sum of
%     all matrices in C. The diagonal is left as NaN.
nC=numel(C);
total=zeros(size(C{1}));
for m=1:nC
    total=total+C{m};
end
B=nan(nC);
for row=2:nC
    for col=1:row-1
        ev=eig(abs(C{col}-C{row}),total);
        d=max(abs(log2(ev)));
        B(col,row)=d;
        B(row,col)=d;
    end
end
function checkindex(IDX,n)
%checkindex(IDX,n) - verify that clusters do not overlap
%
%  IDX: cell array of index vectors into 1..n
%  n: number of samples
% Raises error('!') as soon as a cluster contains an index already claimed
% by an earlier cluster in IDX.
claimed=false(n,1);
for c=1:numel(IDX)
    members=IDX{c};
    if any(claimed(members))
        error('!');
    end
    claimed(members)=true;
end

Generated on Sat 29-Apr-2023 17:15:46 by m2html © 2005