function [net trainacc]=ppinetwork(x,y,varargin)
% PPINETWORK  Build a network of nodes based on a projection-pursuit (PP) index.
%
%   [net, trainacc] = ppinetwork(x, y, ...) repeatedly:
%     1. finds a projection direction for the remaining data with QPCTRAIN,
%     2. cuts a cluster along that direction with CLUSTEROPTIMIZE and, while
%        the newest sub-node's ratio np/nn is below 'clusterPurity', refines
%        the cluster with another CLUSTEROPTIMIZE run from a random direction,
%     3. removes the clustered vectors and appends a catch-all node (see
%        LASTNODE) labelled with the most frequent remaining class.
%   Learning stops when at most one class is left in the remaining data, or
%   when the accuracy returned by NETTEST on the original data does not
%   increase. In the latter case the last cluster node is dropped and the
%   previous catch-all node is restored.
%
%   Inputs:
%     x - data matrix, one vector per row.
%     y - class labels, one per row of x.
%
%   Optional positional arguments:
%     function - 'triangle' | 'f_x4' | 'fx4' | 'bicentral' (default 'f_x4'),
%                forwarded to QPCTRAIN.
%     beta     - positive number (default 2), forwarded to QPCTRAIN.
%
%   Name/value options that are used:
%     learningRate, eps, maxIterations, initiations, save
%                   - forwarded to QPCTRAIN.
%     display       - 'none' | 'short' | 'all' (case-insensitive). 'short'
%                     prints one line per sub-node; 'all' additionally
%                     prints the best QPCTRAIN direction.
%     clusterPurity - required np/nn ratio of a refined cluster
%                     (default 0.95).
%   Name/value options that are validated but currently NOT used here:
%     checkPeriod, stopCriterium, initWeights, lambda, ortoWeights,
%     indGmax, logFile, dataName, savedir, plot.
%
%   Side effects: figure 1 is cleared and CLUSTEROPTIMIZE is called with
%   'plot','all'. Figures 2, 3, ... are drawn with BGRAPH3, and each one is
%   saved as '<figure number>.eps' in the current directory.
%
%   Outputs:
%     net      - cell array of nodes; the last element is a catch-all node.
%     trainacc - NETTEST(x, y, net) evaluated on the original training data.
%
%   TODO
%   * what to do when a neuron cannot separate the data? Product of
%     neurons, or retraining the conflicting cases in an orthogonal subspace.

param = inputParser;
% data
param.addRequired('x',@isnumeric);
param.addRequired('y',@isnumeric);

% inner function
param.addOptional('function','f_x4',@(x)any(strcmpi(x,{'triangle','f_x4','fx4','bicentral'})));
param.addOptional('beta',2, @(x)isnumeric(x) && x>0);

% optimization - global
param.addParamValue('learningRate',0.1,@(x)isnumeric(x) && x >= 0);
param.addParamValue('eps',0.001,@(x)isnumeric(x) && x>0);
param.addParamValue('maxIterations',1000,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
param.addParamValue('initiations',5,@(x)isnumeric(x) && x > 0 && mod(x,1)==0);
param.addParamValue('checkPeriod',5,@(x)isnumeric(x) && x>0 && mod(x,1)==0);
param.addParamValue('stopCriterium',2,@(x)x==1|| x==2);

% optimization - case dependent
param.addParamValue('initWeights', [],@(x)isnumeric(x));
param.addParamValue('lambda', 0.1,@(x)isnumeric(x) && x > 0 );
param.addParamValue('ortoWeights',[],@isnumeric);
param.addParamValue('indGmax',[],@isnumeric);

% cluster refinement: required fraction of correct vectors in a cluster
param.addParamValue('clusterPurity',0.95,@(x)isnumeric(x) && isscalar(x) && x > 0 && x <= 1);

% logging
param.addParamValue('logFile',[],@ischar);
param.addParamValue('dataName','data',@ischar);
param.addParamValue('save','none',@(x)any(strcmpi(x,{'none','all','last'})));
param.addParamValue('savedir',[],@ischar);
param.addParamValue('display','none',@(x)any(strcmpi(x,{'none','all','short'})));

% plotting (accepted for interface compatibility; not used below)
param.addParamValue('plot','none',@(x)any(strcmpi(x,{'none','all','ppi','last'})));

param.parse(x,y,varargin{:});
parameters = param.Results;
clear param;

beta     = parameters.beta;
lrate    = parameters.learningRate;   % learning rate (step of gradient descent)
epsilon  = parameters.eps;            % forwarded to QPCTRAIN as 'eps'
nmax     = parameters.maxIterations;  % max. number of iterations
ninit    = parameters.initiations;
funcname = parameters.function;
ksi      = parameters.clusterPurity;  % required fraction of correct vectors in a cluster

% 'display' is validated case-insensitively (strcmpi), so normalize it before
% the case-sensitive switch. Named 'verbosity' to avoid shadowing DISPLAY.
switch lower(parameters.display)
    case 'all'
        verbosity = 2;
    case 'short'
        verbosity = 1;
    otherwise
        verbosity = 0;
end

figNo = 2;   % figure 1 is used for the CLUSTEROPTIMIZE plot

ory = y;
orx = x;

[vectorsCount features] = size(x);
[labels a index] = unique(y);
labelsCount = numel(labels);   % numel: also correct when y is a row vector
% labelsIndex(i,j) is 1 when the i-th vector belongs to the j-th class, else 0
labelsIndex = zeros(vectorsCount,labelsCount);
for i=1:labelsCount
    labelsIndex(:,i) = (index(:) == i);
end

% -Inf guarantees the first network is accepted, so lnode is always
% defined before the "accuracy did not improve" branch can use it.
lasttrainacc = -Inf;

k = 1;
net = cell(1,1);
while true
    % 1. best projection direction for the remaining data
    [wb ppib ib]=qpctrain(x,y,'learningRate',lrate,'initiations',ninit,'eps',epsilon,'beta',beta,'maxiterations',nmax,'function',funcname,'plot','none','save',parameters.save,'display','short');

    if verbosity > 1
        fprintf('Best direction: I=%6.3f [%s ]\n',ppib(1),sprintf(' %6.3f',wb(1,:)));
    end

    % 2. cut a cluster along that direction
    figure(1);
    clf;
    [pc node ic cl]=clusteroptimize(x,y,wb(1,:),'indGmax',ib(1),'plot','all','display','short');

    figure(figNo);
    clf;
    bgraph3(x*node.w',y,'borders',[node.a node.b]);
    % save under the figure's own number, then advance
    print('-depsc',sprintf('%d.eps',figNo));
    figNo = figNo + 1;

    % refine the cluster from random directions until it is pure enough
    while (node.np(node.n) / node.nn(node.n)) < ksi
        x2 = x(cl,:);
        y2 = y(cl);

        % position of the original Gmax vector inside the current cluster;
        % empty when it is no longer in the cluster (find never returns 0)
        inew = find(cl == ib(1));
        ww = 2.0*rand(1,features)-1;
        [pc node2 ic2 cl2]=clusteroptimize(x2,y2,ww,'indGmax',inew,'plot','none','display','short');

        % append the sub-node's statistics and parameters to the node
        node.np    = [node.np;    node2.np];
        node.nn    = [node.nn;    node2.nn];
        node.npall = [node.npall; node2.npall];
        node.all   = [node.all;   node2.all];
        node.n     = node.n + 1;

        figure(figNo);
        clf;
        bgraph3(x2*node2.w',y2,'borders',[node2.a node2.b]);
        print('-depsc',sprintf('%d.eps',figNo));
        figNo = figNo + 1;

        node.a = [node.a; node2.a];
        node.b = [node.b; node2.b];
        node.w = [node.w; node2.w];
        node.func{node.n} = node2.func{1};

        % cl2 indexes rows of x2; map it back to rows of x
        cl = cl(cl2);
    end

    % 3. remove the clustered vectors from the training set
    select = true(vectorsCount,1);
    select(cl) = false;
    ic = find(select);

    net{k} = node;
    k = k + 1;
    x = x(ic,:);
    y = y(ic);
    labelsIndex = labelsIndex(ic,:);
    labelsPerClassCount = sum(labelsIndex,1);
    vectorsCount = length(ic);

    % catch-all node for the vectors that are left
    currnode = lastnode(x,y);
    currnode.n = 1;
    currnode.np = labelsPerClassCount(labels == currnode.label);
    currnode.nn = vectorsCount;
    currnode.npall = currnode.np;
    currnode.all = vectorsCount;

    net{k} = currnode;

    currtrainacc = nettest(orx,ory,net);

    if verbosity > 0
        for z=1:node.n
            fprintf('%2dth node (%d) %4.1f%% %5.3f [%s ] (%6.3f,%6.3f) +%d/%d -%d/%d %4.1f%%  %4.2f\n',k-1,node.label,currtrainacc,pc,sprintf(' %6.3f',node.w(z,:)),node.a(z),node.b(z),node.np(z),node.npall(z),node.nn(z)-node.np(z),node.all(z)-node.npall(z),100*node.np(z)/node.all(z),node.np(z)/node.nn(z));
        end
    end

    % Stop when at most one class remains. "<= 1" also covers an empty
    % remainder, which would otherwise be passed back to QPCTRAIN.
    if sum(labelsPerClassCount > 0) <= 1
        break
    end

    % Stop when accuracy does not improve: drop the cluster node just added
    % and put the previous catch-all node back in its place.
    if currtrainacc <= lasttrainacc
        net{k-1} = lnode;
        net = net(1:k-1);
        break
    end
    lnode = currnode;
    lasttrainacc = currtrainacc;
end

trainacc = nettest(orx,ory,net);

if verbosity > 0
    fprintf(' Train acc. %6.2f\n',trainacc);
end


function node=lastnode(x,y)
% LASTNODE  Build a catch-all node for the given data.
%   node = lastnode(x, y) returns a struct with fields, in this order:
%     w     - 1-by-size(x,2) zero weight vector,
%     label - mode(y), the most frequent label (NaN when y is empty),
%     a, b  - fixed borders -1 and 1,
%     func  - 1x1 cell holding a function that returns a column of ones
%             with one entry per row of its input.
featureCount = size(x,2);
acceptAll = @(v) ones(size(v,1),1);
% the double braces keep 'func' a 1x1 cell instead of expanding a struct array
node = struct('w',     zeros(1,featureCount), ...
              'label', mode(y), ...
              'a',     -1, ...
              'b',     1, ...
              'func',  {{acceptAll}});