function [bestw bestp besterr bestqpc]=llvqpctrain(x,y,varargin)
% LLVQPCTRAIN learns projection directions by combining GLVQ prototypes with
% QPC optimization (projection pursuit scheme).
%
%   [BESTW BESTP BESTERR BESTQPC] = LLVQPCTRAIN(X,Y,FUNCTION,BETA,PARAMETER,VALUE,...)
%
%   Prototypes are first trained on X with glvqtrain. Then, for every
%   projection d = 1..size(X,2):
%     * the directions found so far are projected out of the data (for
%       d > 1 the prototypes are projected too and retrained with llvqtrain),
%     * a new direction is searched with qpcoptimize2,
%     * the error of the data projected on all directions found so far is
%       evaluated with lvqerror.
%   The search stops as soon as a new projection does not lower that error;
%   the rejected projection is not returned.
%
%   X : dataset, one vector per row
%   Y : labels
%
%   Outputs:
%     BESTW   - accepted directions, one unit-norm row per projection
%               (each row is orthogonal to the previous ones)
%     BESTP   - prototypes in the projected space: one column per accepted
%               projection followed by a column with the prototype labels
%     BESTERR - error returned by lvqerror after each accepted projection
%     BESTQPC - QPC value reached by the last accepted projection
%
%   FUNCTION and BETA are optional positional arguments (given right after Y):
%     FUNCTION - 'gauss' (or 'f_gauss', default), 'triangle', 'fx4' (or
%                'f_x4') or 'bicentral', mapped to f_gauss, f_triangular,
%                f_x4 and f_bicentral respectively
%     BETA     - width of FUNCTION (default 0.1, seems to be a good value for
%                normalized data and f_gauss; for f_x4 try beta=2)
%
%   All other options are PARAMETER,VALUE pairs passed on to qpcoptimize2:
%
%     Optimization:
%       'maxIterations' - maximum number of iterations (default 1000)
%       'initiations'   - number of repetitions of the optimization; the best
%                         solution over those initiations is chosen at the
%                         end of training (default 5)
%       'learningRate'  - step of the gradient ascent procedure (default 0.1)
%       'initWeights'   - initial weights (starting point); when empty a
%                         random vector in [-1,1] is drawn for each initiation
%       'stopCriterium' - criterion for ending the optimization (default 2):
%                         1 - stop when |QPC(t)-QPC(t-1)| < eps*|QPC(t-1)|
%                         2 - stop when the relative difference between the
%                             averages of the last two blocks of
%                             'checkPeriod' values is less than eps
%       'checkPeriod'   - length of the averaging block used by
%                         'stopCriterium' 2 (default 5)
%       'eps'           - tolerance of the stop criterion (default 0.001)
%
%     Searching for an orthogonal solution:
%       'ortoWeights'   - weights for which an orthogonal solution is
%                         searched (default [])
%       'lambda'        - controls the influence of the orthogonalization
%                         (default 0.1)
%
%     Other:
%       'indGmax'       - index of the input vector for which the maximum of
%                         G is searched (under development, default [])
%
%     Plotting:
%       'plot'          - 'none' (default), 'all' (everything on a single
%                         figure for each iteration), 'ppi' (QPC value for
%                         each iteration) or 'last' (final projection)
%
%     Saving results:
%       'log'           - 'off' (default) or 'on'
%       'savedir'       - directory for log files and pictures
%                         (default 'yy.mm.dd.qpc_results', yy.mm.dd being
%                         the current date)
%       'save'          - 'none' (default), 'all' or 'last'; saves pictures
%                         in 'savedir' and turns logging on. 'all' also sets
%                         'plot' to 'all', 'last' sets it to 'last'.
%       'display'       - 'none', 'short' (default) or 'all': amount of
%                         information printed on screen
%       'dataName'      - name of the dataset used in log output
%                         (default 'data')
%
%     Parsed but currently unused: 'logFile', 'directions',
%     'orthogonalizationMethod'.
%
%   e.g.  w=llvqpctrain(x,y,'triangle',3,'plot','ppi')
%         uses the triangular function with beta = 3 and plots the QPC value
%         during learning.
%
% TODO:
%   * remove unnecessary prototypes (simple on a line; in higher-dimensional
%     spaces CNN or another reduction method could be used)
%   * other optimization methods, e.g. simulated annealing
%   * improve log saving and result display
%
param = inputParser;
% data
param.addRequired('x',@isnumeric);
param.addRequired('y',@isnumeric);

% inner function ('f_gauss' is the default and is accepted explicitly too)
param.addOptional('function','f_gauss',@(x)any(strcmpi(x,{'triangle','f_x4','fx4','bicentral','gauss','f_gauss'})));
param.addOptional('beta',0.1, @(x)isnumeric(x) && x>0);

% optimization - global
param.addParamValue('learningRate',0.1,@(x)isnumeric(x) && x >= 0);
param.addParamValue('eps',0.001,@(x)isnumeric(x) && x>0);
param.addParamValue('maxIterations',1000,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
param.addParamValue('initiations',5,@(x)isnumeric(x) && x > 0 && mod(x,1)==0);
param.addParamValue('checkPeriod',5,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
param.addParamValue('stopCriterium',2,@(x)x==1|| x==2);
param.addParamValue('directions',2,@(x)isnumeric(x) && x>0 && mod(x,1)==0);

% optimization - case dependent
param.addParamValue('initWeights', [],@(x)isnumeric(x));
param.addParamValue('lambda', 0.1,@(x)isnumeric(x) && x > 0 );
param.addParamValue('ortoWeights',[],@isnumeric);
param.addParamValue('indGmax',[],@isnumeric);
param.addParamValue('orthogonalizationMethod','projection',@(x)any(strcmpi(x,{'projection','error'})));

% logging
param.addParamValue('log','off',@(x)any(strcmpi(x,{'on','off'})));
param.addParamValue('logFile',[],@ischar);
param.addParamValue('dataName','data',@ischar);
param.addParamValue('save','none',@(x)any(strcmpi(x,{'none','all','last'})));
param.addParamValue('savedir',[],@ischar);
param.addParamValue('display','short',@(x)any(strcmpi(x,{'none','all','short'})));

% plotting
param.addParamValue('plot','none',@(x)any(strcmpi(x,{'none','all','ppi','last'})));
%
param.parse(x,y,varargin{:});

nf = size(x,2);

lvqlr = 0.05;            % learning rate of the prototype training
proj = nf;               % at most one projection per feature
logfilename = 'logfile'; % base name of the QPC logs (suffixed per projection)

[prototypes lvqerr]=glvqtrain(x,y,lvqlr);

k=size(prototypes,1);

currw   = zeros(proj,nf);   % directions found so far, one per row
currerr = zeros(proj,1);    % error after each projection
currqpc = zeros(proj,1);    % QPC value reached by each projection
currp   = zeros(k,proj+1);  % prototypes projected on each direction
I=eye(nf);
lasterr = 1.0;
bestpi=1;
prot = prototypes;

for d=1:proj  % number of projections
    fprintf('Projection %d\n',d);

    if d == 1
        xp = x;
    else
        % projection operator onto the orthogonal complement of the
        % directions found so far (valid because the rows of currw are
        % kept orthonormal below)
        P=I-currw'*currw;
        xp=x*P;
        initp=[prot(:,1:end-1)*P prot(:,end)];
        prot=llvqtrain(xp,y,'lvqlr',lvqlr,'linear','no','initp',initp);
    end
    % one log per projection, so consecutive projections do not overwrite
    % each other's log file
    [wb qpc]=qpcoptimize2(xp,y,prot,param,sprintf('%s.D%d',logfilename,d));

    if d > 1
        % wb can still contain components along earlier directions (e.g.
        % from its random initialization); remove them first and normalize
        % afterwards so the rows of currw stay orthonormal
        w = wb*P;
    else
        w = wb;
    end
    w = w/norm(w);

    currw(d,:)=w;
    currqpc(d)=qpc;
    % data and prototypes are projected with the same vector w
    currp(:,d)=prot(:,1:end-1)*w';
    currerr(d)=lvqerror(x*currw(1:d,:)',y,[currp(:,1:d) prot(:,end)]);
    fprintf('Accuracy [%%]:%s\n',sprintf(' %.2f',100*(1-currerr(1:d))));
    if d > 1 && lasterr <= currerr(d)
        fprintf('STOP: Next projection give worse results\n');
        break;
    end
    % projection d improved the error: accept it
    bestpi = d;
    lasterr = currerr(d);
end
bestw=currw(1:bestpi,:);
bestp=[currp(:,1:bestpi) prot(:,end)];
besterr=currerr(1:bestpi);
bestqpc=currqpc(bestpi);


% one direction search
function [bestw bestppi bestigmax]=qpcoptimize2(x,y,prototypes,parameters,logfilename)
% QPCOPTIMIZE2 searches for a single projection direction W that maximizes
% the QPC index of dataset X with labels Y, using gradient ascent.
%
%   [BESTW BESTPPI BESTIGMAX] = QPCOPTIMIZE2(X,Y,PROTOTYPES,PARAMETERS,LOGFILENAME)
%
%   X           - dataset, one vector per row
%   Y           - labels
%   PROTOTYPES  - prototype matrix: all columns but the last are prototype
%                 coordinates, the last column holds their labels
%   PARAMETERS  - parsed inputParser object (see llvqpctrain); its Results
%                 field supplies all optimization, display, logging and
%                 plotting options
%   LOGFILENAME - base name of the log file created inside the save
%                 directory when logging is enabled; if empty, a name is
%                 built from the dataset name, procedure, function and beta
%
%   BESTW     - best (unit-norm) direction over all initiations
%   BESTPPI   - index value reached by BESTW
%   BESTIGMAX - position of the maximum of G for BESTW (computed only when
%               three outputs are requested)
%
%   The optimization is repeated 'initiations' times. Each run starts from
%   'initWeights' (or a random vector in [-1,1]) and repeats
%   w <- w/|w|;  w <- w + learningRate*gradient
%   until the stop criterion or 'maxIterations' is reached.

[vx fx]=size(x);
px=prototypes(:,1:end-1);   % prototype coordinates
py=prototypes(:,end);       % prototype labels

beta        = parameters.Results.beta;
lrate       = parameters.Results.learningRate;   % learning rate (step of gradient ascent)
eps         = parameters.Results.eps;            % tolerance of the stop criterion
nmax        = parameters.Results.maxIterations;  % max. number of iterations
plotall     = 0;
plotlast    = 0;
iplot       = 0;
ninit       = parameters.Results.initiations;
ww          = parameters.Results.initWeights;
dataname    = parameters.Results.dataName;
avgtest     = parameters.Results.checkPeriod;
stopcriterium = parameters.Results.stopCriterium;
iGmax       = parameters.Results.indGmax;
wort        = parameters.Results.ortoWeights;
lambda      = parameters.Results.lambda;
orto        = ~isempty(wort);
funcname    = parameters.Results.function;
saveall     = 0;
savelast    = 0;
savedir     = strcat(datestr(now,'yy.mm.dd'),'.qpc_results');
log         = 0;

bestw = [];
bestppi = [];
bestinit = 0;
bestn = 0;
bestigmax = -1;

% the option validators use strcmpi, so the switches are case-insensitive too
switch lower(parameters.Results.display)
    case 'all'
        display = 2;
    case 'short'
        display = 1;
    case 'none'
        display = 0;
end

switch lower(parameters.Results.log)
    case 'on'
        log = 1;
    case 'off'
        log = 0;
end

switch lower(funcname)
    case { 'gauss' , 'f_gauss'}
        func = @(xx)f_gauss(xx,beta);
    case { 'fx4' , 'f_x4'}
        func = @(xx)f_x4(xx,beta);
    case 'triangle'
        func = @(xx)f_triangular(xx,beta);
    case 'bicentral'
        func = @(xx)f_bicentral(xx,beta,0,10);
    otherwise
        % "This error should not occur - but if it did, something is wrong
        % with the given function"
        error('Ten blad nie powinien wystapic - ale jesli wystapil to znaczy, ze jest cos nie tak z podana funkcja');
end

switch lower(parameters.Results.plot)
    case 'all'
        plotall = 1;
    case 'ppi'
        iplot = 1;
    case 'last'
        plotlast = 1;
end

if ~isempty(parameters.Results.savedir)
    savedir = parameters.Results.savedir;
end

% saving pictures implies plotting them and turns logging on
switch lower(parameters.Results.save)
    case 'all'
        plotall = 1;
        saveall = 1;
        log = 1;
    case 'last'
        plotlast = 1;
        savelast = 1;
        log = 1;
    case 'iplot'   % NOTE(review): not accepted by the 'save' validator
        iplot = 1;
end

% select the index function to optimize
if isempty(iGmax)
    if orto == 0
        procedure = 'ppi';
        ppifun = @(wx)qpcfunction(x,y,wx,[px py],func);    % basic form of the index
    else
        procedure = strcat('ppi-orto',sprintf('.%0.1f',lambda));   % index with orthogonalization (wort, lambda)
        wort=wort/norm(wort);
        ortoproj = x*wort';
        ppifun = @(wx)qpcfunction(x,y,wx,[],func,wort,lambda);
    end
else
    procedure = strcat('ppi-gmax',sprintf('.%d',iGmax));
    ppifun = @(wx)ppigmaxfunction2(x,y,wx,iGmax,func);
end


text = '';
if log == 1 || display > 0
    text = [sprintf('procedure     = %s\n',procedure)...
            sprintf('dataname      = %s\n',dataname)...
            sprintf('vectors       = %d\n',vx)...
            sprintf('features      = %d\n',fx)...
            sprintf('learningRate  = %f\n',lrate)...
            sprintf('eps           = %f\n',eps)...
            sprintf('initiations   = %d\n',ninit)...
            sprintf('maxIterations = %d\n',nmax)...
            sprintf('ppi function  = %s\n',func2str(ppifun))...
            sprintf('function      = %s\n',func2str(func))...
            sprintf('beta          = %f\n',beta)...
        ];
    if orto == 1
        % '\n' stays literal here and is expanded by fprintf(text) below
        text = [text 'ortogonal W   = ' sprintf('%f  ',wort) '\n'...
                sprintf('lambda        = %f\n',lambda)...
        ];
    end
    if ~isempty(iGmax)
        text = [text sprintf('Gmax index    = %d\n',iGmax)];
    end
end

if log == 1
    [s comment]=mkdir(savedir);
    if s == 0
        error(comment);
    end
    if isempty(logfilename)
        prefix = strcat(savedir,'/',dataname,sprintf('.%s-%s.%0.1f',procedure,funcname,beta));
        logfilename = strcat(prefix,'.log');
    else
        prefix=strcat(savedir,'/',logfilename);
        logfilename = strcat(prefix,'.log');
    end
    logfile = fopen(logfilename,'wt');
    if logfile == -1
        error('Error opening %s file\n',logfilename);
    end
    fprintf(logfile,text);
end

if display > 0
    fprintf(text);
end

%%%%%%%%%%%%%%   OK - the fun starts here   %%

for initcount=1:ninit
    n = 0;
    ppitable = zeros(1,nmax+1);   % ppitable(n+1) = index value at iteration n

    if display > 0 || log == 1
        text = sprintf('\nInitialization %d of %d\n',initcount,ninit);
        if display > 0
            fprintf(text);
        end
        if log == 1
            fprintf(logfile,text);
        end
    end

    if isempty(ww)
        w = rand(1,fx)*2-1;       % random initialization [-1,1]
    else
        w = ww;
    end

    avgppi = 0;
    lastavgppi = 0;

    while ( 1 )
        % every avgtest iterations: average of the last avgtest values
        if (mod(n,avgtest) == 0 && n > 0)
            lastavgppi = avgppi;
            avgppi = mean(ppitable(n-avgtest+1:n));
        end

        w=w/norm(w);
        if nargout > 2
            % NOTE(review): assumes ppifun returns G as its 4th output, as the
            % formerly commented-out call did -- confirm for qpcfunction and
            % ppigmaxfunction2
            [ppi ppid projection G] = ppifun(w);
            [gmax igmax] = max(G);
        else
            [ppi ppid projection] = ppifun(w);
        end

        ppitable(1,n+1) = ppi;

        if iplot == 1 && n > 0
            plotppi(ppitable,n,avgtest);
            drawnow;
        end

        if orto == 1
            w1 = w*wort';   % cosine between w and the normalized orthogonal weights
        end

        if display > 1 || (display == 1 && n == 0)
            fprintf('%3d  %10.6f   ',n,ppi);   fprintf('  %6.4f',w);
            if orto == 1
                fprintf('  [ %6.4f ] ',w1);
            end
            fprintf('\n');
        end
        if log == 1
            fprintf(logfile,'%3d  %10.6f   ',n,ppi);   fprintf(logfile,'  %6.4f',w);
            if orto == 1
                fprintf(logfile,'  [ %6.4f ] ',w1);
            end
            fprintf(logfile,'\n');
        end

        if plotall == 1
            str=cell(1,3);
            clf;
            set(gcf,'Color','w');
            if orto == 0
                str{2} = strcat('QPC = ',sprintf(' %.5f ',ppitable(n+1)));
                str{3} = strcat('N = ',sprintf(' %d ',n));

                bgraph3(projection,y,'position',[0.05 0.4 0.9 0.50],'sigma',0.01,'function',func,'str',str);
            else
                str{2} = strcat('\alpha = ',sprintf(' %.4f ',w1));
                str{1} = strcat('QPC = ',sprintf(' %.4f ',ppitable(n+1)));
                str{3} = strcat('N = ',sprintf(' %d ',n));

                scaterplot([projection ortoproj],y,'str',str);

                xlabel(strcat('w2'));
                ylabel(strcat('w1'));
            end
            axes('position',[0.1 0.1 0.85 0.25]);
            plotppi(ppitable,n,50);
            drawnow();

            if saveall == 1
                saveplot(strcat(prefix,sprintf('.i%d.frame%04d',initcount,n)));
            end
        end

        % criterion 1: relative change w.r.t. the previous iteration
        % (ppitable(n) holds iteration n-1); compared without division so a
        % zero previous value cannot produce NaN/Inf
        if (stopcriterium == 1 && n > 1 && abs(ppi - ppitable(n)) < eps*abs(ppitable(n)))
            break;
        end
        % criterion 2: relative change between the last two block averages
        if (stopcriterium == 2 && abs((avgppi - lastavgppi)/lastavgppi) < eps && n > 1 )
            break;
        end
        if ( n >= nmax )
            break;
        end

        w = w + lrate * ppid;   % gradient ascent step
        n = n + 1;

    end

    if (display == 1)  % if display == 1 then display only last result
        fprintf('%3d  %10.6f   ',n,ppi);
        fprintf('  %6.4f',w);
        if orto == 1
                fprintf('  [ %6.4f ] ',w1);
        end
        fprintf('\n');
    end

    if plotlast == 1
            str=cell(1,3);
            clf;
            set(gcf,'Color','w');
            if orto == 0
                % (printing all weights is unreadable for many features)
                str{1} = strcat('QPC = ',sprintf(' %.5f ',ppitable(n+1)));
                str{2} = strcat('N = ',sprintf(' %d ',n));

                bgraph3(projection,y,'sigma',0.01,'function',func,'str',str);
            else
                str{2} = strcat('\alpha = ',sprintf(' %.4f ',w1));
                str{1} = strcat('QPC = ',sprintf(' %.4f ',ppitable(n+1)));
                str{3} = strcat('N = ',sprintf(' %d ',n));

                scaterplot([projection ortoproj],y,'str',str);

                xlabel(strcat('w2'));
                ylabel(strcat('w1'));

            end
            drawnow();

            if savelast == 1
                saveplot(strcat(prefix,sprintf('.i%d.last',initcount)));
            end
    end

    % keep the best run; ppi belongs to w (the loop exits before updating w)
    if (initcount == 1) || (ppi > bestppi)
        bestw = w;
        bestppi = ppi;
        bestinit = initcount;
        bestn = n;
        if nargout > 2
            bestigmax = igmax;
        end
    end
end

if display > 0 || log == 1
    text = ['---------------------------------------------\n' ...
            sprintf('\nInitialization  %d was the best\n',bestinit)...
            sprintf('%d  %6.4f    ',bestn,bestppi)...
            sprintf('  %6.4f',bestw) sprintf('\n')];
    if display > 0
        fprintf(text);
    end
    if log == 1
        fprintf(logfile,text);
        fclose(logfile);
    end
end


function saveplot(prefix)
% SAVEPLOT writes the current figure as a 96 dpi PNG named PREFIX.png.
%   The suffix is appended with strcat, so trailing whitespace of PREFIX
%   is dropped from the file name.
    print('-dpng', '-r96', strcat(prefix, '.png'));

% Alternative outputs, kept for reference:
%   Octave only:  print('-dpng','-S640,480',strcat(prefix,'.png'));
%   EPS:          print('-depsc',strcat(prefix,'.eps'));

function plotppi(ppitable,n,scalefactor)
% PLOTPPI plots the index history PPITABLE(1:N+1) against iterations 0..N
% in the current axes.
%   PPITABLE    - vector of index values; PPITABLE(i) belongs to iteration i-1
%   N           - last iteration to plot (default: the last stored
%                 iteration, length(PPITABLE)-1)
%   SCALEFACTOR - the upper x-axis limit is the smallest multiple of
%                 SCALEFACTOR greater than N (default: 10)
    if nargin < 3
        scalefactor = 10;
    end
    if nargin < 2
        % ppitable(i) holds iteration i-1, so the last iteration is
        % length-1; using length itself would index past the end below
        n = length(ppitable) - 1;
    end
    xl = scalefactor.*(fix(n./scalefactor)+1);
    cla;
    hold on;
    xlim([0 xl]);
    xlabel('Iterations');
    ylabel('PPI value');
    plot(0:n,ppitable(1:n+1),'-b','LineWidth',2);
    box on;
    hold off;
