Home > NoiseTools > nt_cluster_jd.m

nt_cluster_jd

PURPOSE ^

[IDX,todss,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster with joint diagonalization

SYNOPSIS ^

function [IDX,TODSS,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose, depth,N)

DESCRIPTION ^

[IDX,todss,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster with joint diagonalization

  IDX: cluster ownership (IDX{1}: low amp, IDX{2}: high amp)
  TODSS: DSS matrix (1st column --> discriminating component)
  SCORE: score (smaller means better contrast)
  COVS: covariance for each cluster

  x: data (time*channels)
  dsr: downsample ratio for cross product series
  smooth: further smoothing of cross-product series
  flags: see below
  init: provide initial clustering
  verbose: display & plot (default=no)
  depth: cluster recursively into 2^depth clusters
  N: target number of clusters [default: 2^depth]

 Flags:
  'norm', 'norm2': give each slice the same weight
  'amp', 'pwr': cluster amplitude or power instead of log (default)
 See nt_bias_cluster, nt_cluster1D

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [IDX,TODSS,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose, depth,N)
%[IDX,todss,SCORE,COVS]=nt_cluster_jd(x,dsr,smooth,flags,init,verbose,depth,N) - cluster with joint diagonalization
%
%  IDX: cluster ownership (IDX{1}: low amp, IDX{2}: high amp)
%  TODSS: DSS matrix (1st column --> discriminating component)
%  SCORE: score (smaller means better contrast)
%  COVS: covariance for each cluster
%
%  x: data (time*channels)
%  dsr: downsample ratio for cross product series
%  smooth: further smoothing of cross-product series [default: 1]
%  flags: see below
%  init: provide initial clustering (indices of the samples of cluster 1)
%  verbose: display & plot (default=no)
%  depth: cluster recursively into 2^depth clusters [default: 1]
%  N: target number of clusters [default: 2^depth]
%
% Flags:
%  'norm', 'norm2': give each slice the same weight
%  'amp', 'pwr': cluster amplitude or power instead of log (default)
%
% With depth>1, IDX holds up to N clusters; TODSS, SCORE and COVS (one
% entry/row per cluster) are then computed only if nargout>1.
% With no output arguments (or verbose), results are displayed and plotted.
% See nt_bias_cluster, nt_cluster1D

if nargin<2; error('nt_cluster_jd: x and dsr are required'); end
if nargin<3 ||isempty(smooth); smooth=1; end
if nargin<4 ||isempty(flags); flags=[]; end
if nargin<5; init=[]; end
if nargin<6||isempty(verbose); verbose=0; end
if nargin<7||isempty(depth); depth=1; end
if nargin<8||isempty(N); N=2^depth; end

if ndims(x)>2 || size(x,2) ==1;
    error('should be 2D matrix');
end

if depth>1
    % Recursive mode: split in two, split each half again (depth-1), then
    % merge the closest clusters until at most N remain.
    % NOTE(review): the splits call nt_cluster_jd2, not this function;
    % confirm it is on the path (it may be a stale name for nt_cluster_jd).
    I=nt_cluster_jd2(x,dsr,smooth,flags,init,verbose,1);
    IDX={};
    for h=1:2
        Ih=I{h};
        if numel(Ih)>2*dsr
            % init indexes rows of the full x, so it must not be passed
            % to the calls that operate on a subset of rows
            Isub=nt_cluster_jd2(x(Ih,:),dsr,smooth,flags,[],verbose,depth-1);
        else
            Isub={(1:numel(Ih))}; % too few samples to split further
        end
        for k=1:numel(Isub)
            IDX=[IDX, {Ih(Isub{k})}]; %#ok<AGROW>
        end
    end

    while numel(IDX)>N
        % merge the pair of clusters whose covariances are closest
        C=cell(1,numel(IDX));
        for k=1:numel(IDX)
            C{k}=cluster_cov(x,IDX{k});
        end
        B=covdists(C);
        [a,idx]=min(B(:)); % NaN entries (diagonal) are ignored by min
        [k1,k2]=ind2sub(size(B),idx);
        if verbose
            figure(1); clf; nt_imagescc(B);
            title(num2str([max(B(:)), a, k1, k2, size(B,1)])); drawnow;
        end
        IDX{k1}=[IDX{k1};IDX{k2}];
        IDX(k2)=[];
    end

    if nargout>1
        c0=nt_cov(x)/size(x,1);
        % fresh containers: the merge loop above had more clusters
        TODSS=cell(1,numel(IDX));
        COVS=cell(1,numel(IDX));
        for k=1:numel(IDX)
            COVS{k}=cluster_cov(x,IDX{k});
            [TODSS{k},pwr0,pwr1]=nt_dss0(c0,COVS{k});
            SCORE(k,:)=pwr1./pwr0; %#ok<AGROW>
        end
    end
    return
end

%{
 Calculate the time series of cross products (terms of the covariance matrix).
 This time series has coarser temporal resolution than x by a factor dsr.
%}
[xx,ind]=nt_xprod(x,'lower',dsr);

% option: give each slice the same weight (counters amplitude variations)
if any(strcmp(flags,'norm'))
    xx=nt_normrow(xx);
end
if any(strcmp(flags,'norm2'))
    xx=norm2(xx,size(x,2),ind);
end

xx=nt_smooth(xx,smooth,[],1);

%{
 Cluster each column of the time series of cross products,
 choose the column with best score (reduction in energy),
 and use its cluster index to initialize the first JD analysis.
%}
if isempty(init)
    [~,A,score]=nt_cluster1D(xx); % cluster all columns of cross products
    [~,idx]=min(score);           % select column with best score (tightest clusters)
    A=ownership_to_samples(A(:,idx),dsr,size(x,1));
    IDX{1}=find(A==0);
else
    IDX{1}=init;
end

% c0 does not change across iterations: compute it once
c0=nt_cov(x)/size(x,1);
c1=cluster_cov(x,IDX{1});
[todss,pwr0,pwr1]=nt_dss0(c0,c1);
z=nt_mmat(x,todss(:,[1 end])); % keep only first and last components

PLOT_FIG2=0; % debugging plots
if PLOT_FIG2
    figure(2);  clf; set(gcf, 'name','in nt_cluster_jd');
    A=zeros(size(x,1),1); A(IDX{1})=1;
    subplot 511; plot(x); title('data');
    subplot 512; plot(A,'.-'); title('initial cluster map');
    subplot 513; plot(z(:,1)); title('initial DSS1');
    subplot 514; plot(z(:,2)); title('initial DSS2');
    drawnow; pause;
end

% iterate until the cluster ownership is stable (at most 10 iterations)
old_IDX=IDX{1};
for iIter=1:10
    [zz,~]=nt_xprod(z,[],dsr);
    zz=zz(:,1:2);       % keep only the squares

    if any(strcmp(flags,'pwr'))         % cluster in power
        [C,A]=nt_cluster1D(zz);
        [~,idx]= max(abs(diff(log2(C+eps)))); % choose first or last
    elseif any(strcmp(flags,'amp'))     % cluster in amplitude
        [C,A]=nt_cluster1D(sqrt(zz));
        [~,idx]= max(abs(diff(log2(C+eps)))); % choose first or last
    else                                % cluster in log domain
        [C,A]=nt_cluster1D(log2(zz+eps));
        [~,idx]= max(abs(diff(C)));           % choose first or last
    end
    A=A(:,idx);
    C=C(:,idx);
    % ensure that first cluster has low amplitude
    % NOTE(review): correctness of this flip depends on how nt_cluster1D
    % pairs labels 0/1 with centroids C(1)/C(2) -- confirm.
    if C(1)<C(2); A=1-A; end

    A=double(nt_smooth(A,smooth, [],1)>=1/smooth); % extend ownership to include effect of smoothing

    A=ownership_to_samples(A,dsr,size(x,1));
    IDX{1}=find(A==0); % 0: low values, 1: high values

    % DSS to contrast clusters
    c1=cluster_cov(x,IDX{1});
    [todss,pwr0,pwr1]=nt_dss0(c0,c1);
    z=nt_mmat(x,todss(:,[1 end])); % keep first and last

    if ~nargout||verbose
        disp(['low amp cluster: ', num2str((100*numel(IDX{1})/size(x,1)), 2), ' % of samples, power ratio: ' num2str(pwr1(end)/pwr0(end), 3)]);
        disp(['hi amp cluster: ', num2str((100-100*numel(IDX{1})/size(x,1)), 2), ' % of samples, power ratio: ' num2str(pwr1(1)/pwr0(1), 3)]);
    end

    if PLOT_FIG2
        figure(2);
        subplot 515; plot(A,'.-'); title('final cluster map'); pause
    end
    if isequal(old_IDX,IDX{1}); break; end
    old_IDX=IDX{1};
end
IDX{2}=setdiff((1:size(x,1))', IDX{1});

% final DSS: one per cluster, each contrasted with all data
COVS=cell(1,2);
TODSS=cell(1,2);
for k=1:2
    COVS{k}=cluster_cov(x,IDX{k});
    [TODSS{k},pwr0,pwr1]=nt_dss0(c0,COVS{k});
    SCORE(k,:)=pwr1./pwr0; %#ok<AGROW>
end

if nargout==0||verbose
    % no output, just plot
    z1=nt_mmat(x,TODSS{1}(:,1));

    figure(101); clf ;
    subplot 221;
    plot(pwr1./pwr0,'.-'); xlabel('component'); ylabel('score'); title('DSS cluster vs all');
    subplot 222;
    wsize=min(1024,size(z1,1));
    hold on
    nt_spect_plot(z1/sqrt(mean(z1(:).^2)),wsize,[],[],1);
    nt_spect_plot(x/sqrt(mean(x(:).^2)),wsize,[],[],1);
    xlim([0 .5]);
    nt_linecolors([],[1 3 2]);
    legend('cluster','all'); legend boxoff
    hold off

    z=nt_normcol(nt_mmat(x,todss));
    subplot 223; nt_imagescc(nt_cov(z(IDX{1},:))); title('cluster 1');
    subplot 224; nt_imagescc(nt_cov(z)-nt_cov(z(IDX{1},:))); title('cluster 2');

    figure(102); clf
    subplot 311;
    plot(x); hold on
    xx=x; xx(IDX{1},:)=nan;
    plot(xx,'k');
    axis tight
    title('black: cluster [high amp]');
    subplot 312;
    plot(z1); axis tight
    title('DSS 1');
    subplot 313;
    nt_sgram(z1,128,1); axis tight
    title('DSS 1');

    if nargout==0; clear IDX SCORE TODSS; end
end

% covariance of rows x(idx,:), normalized by the number of selected rows
function c=cluster_cov(x,idx)
xs=x(idx,:);
c=nt_cov(xs)/size(xs,1);

% upsample a per-slice ownership vector by dsr so that it indexes rows of x;
% samples past the last full slice inherit the last value
function A=ownership_to_samples(A,dsr,nsamples)
A=repmat(A(:)',[dsr,1]);
A=A(:);
A=A(1:min(end,nsamples)); % never index past the end of x
A(end:nsamples)=A(end);
0261 
% y=norm2(x,nchans,ind) - divide each cross product by the geometric mean of
% the two matching squared terms, slice by slice.
%
%  x: cross products, one row per slice; assumes columns 1:nchans hold the
%     squared (diagonal) terms in channel order -- confirm against nt_xprod
%  nchans: number of channels of the original data
%  ind: linear indices (into an nchans*nchans matrix) of the cross product
%       held by each column of x; numel(ind) must equal size(x,2)
%
%  y: x(k,m)./sqrt(d(k,i).*d(k,j)), where (i,j)=ind2sub(.,ind(m)) and d are
%     the squared terms. When the assumption above holds, this turns each
%     slice's covariance terms into correlation-like terms.
function y=norm2(x,nchans,ind)
[I,J]=ind2sub([nchans,nchans],ind); % linear --> matrix indices
d=x(:,1:nchans);                    % squared terms of every slice
y=x./sqrt(d(:,I).*d(:,J));          % vectorized; also valid for 0 rows
0270 
% B=covdists(C) - symmetric matrix of pairwise distances between covariances
%
%  C: cell array of covariance matrices (all the same size)
%  B: numel(C)*numel(C) matrix; B(p,q) is the largest absolute log2 of the
%     generalized eigenvalues of the pair (C{p},C{q}). Diagonal entries are NaN.
function B=covdists(C)
n=numel(C);
B=nan(n);
for col=2:n
    for row=1:col-1
        lambda=eig(C{row},C{col});   % generalized eigenvalues
        d=max(abs(log2(lambda)));
        B(row,col)=d;                % fill both halves
        B(col,row)=d;
    end
end

Generated on Sun 11-Dec-2016 18:36:17 by m2html © 2005