Home > NoiseTools > nt_regw.m

nt_regw

PURPOSE ^

[b,z]=nt_regw(y,x,w) - weighted regression

SYNOPSIS ^

function [b,z]=nt_regw(y,x,w)

DESCRIPTION ^

[b,z]=nt_regw(y,x,w) - weighted regression

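  b: regression coefficients, expressed in the basis of the principal
     components of the (demeaned, weighted) regressor x that survive the
     PCA threshold, rather than in the basis of the columns of x
  z: fitted values (demeaned x projected on those components, times b,
     plus the weighted mean of y)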

  y: data
  x: regressor
  w: weights (applied to both y and x)

  w is either a matrix of same size as y, or a column vector to be applied
  to each column of y

 NoiseTools

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

function [b,z]=nt_regw(y,x,w)
%[b,z]=nt_regw(y,x,w) - weighted regression
%
%  b: regression coefficients, expressed in the basis of the principal
%     components of the (demeaned, weighted) regressor that survive the
%     PCA_THRESH test -- NOT in the basis of the columns of x.
%     - no weights / single weight column: (nRetainedPCs x size(y,2))
%     - per-channel weights: one ROW per channel of y, padded with zeros
%       (each channel may retain a different set of PCs)
%  z: fitted values: demeaned x projected on the retained PCs, times b,
%     plus the (weighted) mean of y
%
%  y: data
%  x: regressor
%  w: weights (applied to both y and x before solving)
%
%  w is either a matrix of same size as y, or a column vector to be applied
%  to each column of y
%
%  Dimensions of x whose eigenvalue is below PCA_THRESH times the largest
%  eigenvalue are discarded before solving (guards against ill-conditioning).
%
% NoiseTools

PCA_THRESH=0.0000001; % discard dimensions of x with relative eigenvalue lower than this

if nargin<3; w=[]; end
if nargin<2; error('!'); end

%% check/fix sizes
m=size(y,1);
x=nt_unfold(x);
y=nt_unfold(y);
if size(x,1)~=size(y,1)
    disp(size(x)); disp(size(y)); error('!');
end

%% save weighted mean (added back to the fit at the end)
if nargout>1
    mn=y-nt_demean(y,w);
end

%%
if isempty(w)
    %% simple regression
    xx=nt_demean(x);
    yy=nt_demean(y);
    % take the eigenvalues as a VECTOR: dividing the eigenvalue matrix by
    % its column maxima (as before) does not implement the relative threshold
    [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
    topcs=V(:,D/max(D) > PCA_THRESH); % discard weak dims
    xxx=xx*topcs;
    b=(yy'*xxx) / (xxx'*xxx); b=b';
    if nargout>1; z=nt_demean(x,w)*topcs*b; z=z+mn; end
else
    %% weighted regression
    if size(w,1)~=size(x,1); error('!'); end
    if size(w,2)==1
        %% same weight for all channels
        if sum(w(:))==0
            %warning('weights all zero');
            b=0;
            % nothing to fit: the regression term is zero, so z reduces to mn
            % (previously this path referenced an undefined topcs)
            if nargout>1; z=mn; end
        else
            yy=nt_demean(y,w).*repmat(w,1,size(y,2));
            xx=nt_demean(x,w).*repmat(w,1,size(x,2));
            [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
            topcs=V(:,D/max(D) > PCA_THRESH); % discard weak dims
            xxx=xx*topcs;
            b=(yy'*xxx) / (xxx'*xxx); b=b';
            if nargout>1; z=nt_demean(x,w)*topcs*b; z=z+mn; end
        end
    else
        %% each channel has own weight
        if size(w,2) ~= size(y,2); error('!'); end
        if nargout>1; z=zeros(size(y)); end
        % one row per channel of y; columns grow to the largest number of
        % retained PCs, zero-padded for channels that retain fewer
        b=zeros(size(y,2),0);
        for iChan=1:size(y,2)
            if sum(w(:,iChan))==0
                %warning('weights all zero');
                % leave this channel's coefficients at zero without touching
                % the other channels (previously all of b was reset here)
                if nargout>1; z(:,iChan)=mn(:,iChan); end
            else
                yy=nt_demean(y(:,iChan),w(:,iChan)) .* w(:,iChan);
                % channel-specific-weighted demeaned regressor; kept in a
                % local so the input x is not modified across iterations
                xd=nt_demean(x,w(:,iChan));
                xx=xd.*repmat(w(:,iChan),1,size(x,2));
                [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
                topcs=V(:,D/max(D) > PCA_THRESH); % discard weak dims
                xxx=xx*topcs;
                nPC=size(topcs,2);
                b(iChan,1:nPC)=(yy'*xxx) / (xxx'*xxx);
                if nargout>1; z(:,iChan)=xd*(topcs*b(iChan,1:nPC)') + mn(:,iChan); end
            end
        end
    end
end

%%
if nargout>1
    z=nt_fold(z,m);
end

%% test code
if 0
    % basic
    x=randn(100,10); y=randn(100,10); 
    b1=nt_regw(x,y); b2=nt_regw(x,x); b3=nt_regw(x,y,ones(size(x))); 
    figure(1); subplot 131; nt_imagescc(b1); subplot 132; nt_imagescc(b2); subplot 133; nt_imagescc(b3);
end
if 0
    % fit random walk
    y=cumsum(randn(1000,1)); x=(1:1000)'; x=[x,x.^2,x.^3];
    [b,z]=nt_regw(y,x); 
    figure(1); clf; plot([y,z]);
end
if 0
    % weights, random
    y=cumsum(randn(1000,1)); x=(1:1000)'; x=[x,x.^2,x.^3];
    w=rand(size(y));
    [b,z]=nt_regw(y,x,w); 
    figure(1); clf; plot([y,z]);
end
if 0
    % weights, 1st vs 2nd half
    y=cumsum(randn(1000,1))+1000; x=(1:1000)'; x=[x,x.^2,x.^3];
    w=ones(size(y)); w(1:500,:)=0;
    [b,z]=nt_regw(y,x,w); 
    figure(1); clf; plot([y,z]);
end
if 0
    % multichannel
    y=cumsum(randn(1000,2)); x=(1:1000)'; x=[x,x.^2,x.^3];
    w=ones(size(y)); 
    [b,z]=nt_regw(y,x,w); 
    figure(1); clf; plot([y,z]);
end

Generated on Tue 09-Oct-2018 10:58:04 by m2html © 2005