Home > NoiseTools > nt_regw.m

nt_regw

PURPOSE ^

[b,z]=nt_regw(y,x,w) - weighted regression

SYNOPSIS ^

function [b,z]=nt_regw(y,x,w)

DESCRIPTION ^

[b,z]=nt_regw(y,x,w) - weighted regression

  b: regression matrix (apply to x to approximate y)
  z: regression (approximation of y fitted from x, with the weighted mean of y added back)

  y: data
  x: regressor
  w: weight to apply to y

  w is either a matrix of same size as y, or a column vector to be applied
  to each column of y

 NoiseTools

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [b,z]=nt_regw(y,x,w)
%[b,z]=nt_regw(y,x,w) - weighted regression
%
%  b: regression matrix, size(x,2) by size(y,2), expressed in the space of
%     the columns of x: nt_demean(x,w)*b approximates nt_demean(y,w)
%  z: regression (approximation of y fitted from x, with the weighted mean
%     of y added back); passed through nt_fold with the original size(y,1)
%
%  y: data
%  x: regressor
%  w: weight to apply to y (default: [], i.e. unweighted)
%
%  w is either a matrix of same size as y, or a column vector to be applied
%  to each column of y
%
%  Before fitting, x is projected on the eigenvectors of its (weighted)
%  covariance; eigenvectors whose eigenvalue divided by the largest
%  eigenvalue does not exceed PCA_THRESH are discarded.
%
%  Errors (message '!') are raised if fewer than two arguments are given,
%  if x and y differ in number of rows after nt_unfold, or if the size of
%  w is inconsistent with x / y.
%
% NoiseTools

PCA_THRESH=0.0000001; % discard dimensions of x with relative eigenvalue lower than this

if nargin<3; w=[]; end
if nargin<2; error('!'); end

%% check/fix sizes
m=size(y,1); % row count before unfolding, needed by nt_fold at the end
x=nt_unfold(x);
y=nt_unfold(y);
if size(x,1)~=size(y,1); error('!'); end

%% save weighted mean
if nargout>1
    mn=y-nt_demean(y,w); % what nt_demean removed from y; added back to z
end

%%
if isempty(w)
    %% simple regression
    xx=nt_demean(x);
    yy=nt_demean(y);
    % take the eigenvalues as a vector so that the division below is a
    % per-eigenvalue ratio (with the full matrix D it would be a matrix
    % right division)
    [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
    topcs=V(:,find(D/max(D) > PCA_THRESH)); % discard weak dims
    xxx=xx*topcs;
    b=(yy'*xxx) / (xxx'*xxx);
    b=topcs*b'; % map coefficients from PC space back to the columns of x
    if nargout>1; z=nt_demean(x,w)*b; z=z+mn; end
else
    %% weighted regression
    if size(w,1)~=size(x,1); error('!'); end
    if size(w,2)==1
        %% same weight for all channels
        yy=nt_demean(y,w).*repmat(w,1,size(y,2));
        xc=nt_demean(x,w); % weighted-mean-removed regressor, reused for z
        xx=xc.*repmat(w,1,size(x,2));
        [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
        topcs=V(:,find(D/max(D) > PCA_THRESH)); % discard weak dims
        xxx=xx*topcs;
        b=(yy'*xxx) / (xxx'*xxx);
        b=topcs*b'; % map coefficients from PC space back to the columns of x
        if nargout>1; z=xc*b; z=z+mn; end
    else
        %% each channel has own weight
        if size(w,2) ~= size(y,2); error('!'); end
        b=zeros(size(x,2),size(y,2)); % one column of coefficients per channel of y
        if nargout>1; z=zeros(size(y)); end
        for iChan=1:size(y,2)
            wc=w(:,iChan);
            yy=nt_demean(y(:,iChan),wc) .* wc;
            % remove channel-specific-weighted mean from regressor; kept in a
            % local variable so that x itself is not modified across iterations
            xc=nt_demean(x,wc);
            xx=xc.*repmat(wc,1,size(x,2));
            [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
            topcs=V(:,find(D/max(D) > PCA_THRESH)); % discard weak dims
            xxx=xx*topcs;
            % coefficients in PC space (row vector), mapped back to the columns of x
            b(:,iChan)=topcs*((yy'*xxx) / (xxx'*xxx))';
            if nargout>1; z(:,iChan)=xc*b(:,iChan) + mn(:,iChan); end
        end
    end
end

%%
if nargout>1
    z=nt_fold(z,m);
end
0075 
%% test code
% The snippets below are disabled (if 0); enable one at a time to try the
% function interactively.
if 0
    % basic: random data, unweighted / self-regression / unit weights
    x=randn(100,10);
    y=randn(100,10);
    b1=nt_regw(x,y);
    b2=nt_regw(x,x);
    b3=nt_regw(x,y,ones(size(x)));
    figure(1);
    subplot(1,3,1); nt_imagescc(b1);
    subplot(1,3,2); nt_imagescc(b2);
    subplot(1,3,3); nt_imagescc(b3);
end
if 0
    % fit random walk with a cubic polynomial of the sample index
    y=cumsum(randn(1000,1));
    t=(1:1000)';
    x=[t,t.^2,t.^3];
    [b,z]=nt_regw(y,x);
    figure(1);
    clf;
    plot([y,z]);
end
if 0
    % weights, random
    y=cumsum(randn(1000,1));
    t=(1:1000)';
    x=[t,t.^2,t.^3];
    w=rand(size(y));
    [b,z]=nt_regw(y,x,w);
    figure(1);
    clf;
    plot([y,z]);
end
if 0
    % weights, 1st vs 2nd half: first 500 samples get zero weight
    y=cumsum(randn(1000,1))+1000;
    t=(1:1000)';
    x=[t,t.^2,t.^3];
    w=ones(size(y));
    w(1:500,:)=0;
    [b,z]=nt_regw(y,x,w);
    figure(1);
    clf;
    plot([y,z]);
end
if 0
    % multichannel: two data channels, weight matrix of the same size as y
    y=cumsum(randn(1000,2));
    t=(1:1000)';
    x=[t,t.^2,t.^3];
    w=ones(size(y));
    [b,z]=nt_regw(y,x,w);
    figure(1);
    clf;
    plot([y,z]);
end
0110

Generated on Mon 27-Feb-2017 15:36:07 by m2html © 2005