/**
 * Trainer (Header+Implementation)
 * 
 * Copyright 2013 Fabian Schrodt, FSchrodt@gmx.de
 * 
 * This file is part of RNNPBlib.
 * 
 * RNNPBlib is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License version 3 as published by the Free Software Foundation.
 * 
 * RNNPBlib is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with RNNPBlib. If not, see http://www.gnu.org/licenses/.
 */

#pragma once

#include "RNNPB_Definitions.h"

#include <vector>
#include <iostream>
#include <fstream>
#include <time.h>
using namespace std;

#ifdef WIN32
//we need a win port for posix gettime:
#include "gettimefix.h"
#endif

#include "RNNPB_ApplicationInterface.h"
#include "RNNPB_NetworkContainer.h"
#include "RNNPB_Behaviour.h"
#include "RNNPB_MatlabWriter.h"

class RNNPB_CVTrainer
{
public:
	RNNPB_MatlabWriter writer;	//one writer per trainer

protected:
	RNNPB_ApplicationInterface* application;

	RNNPB_NetworkContainer* network;

	vector <RNNPB_Behaviour*> behaviours;

	int iteration;					//internal, has nothing to do with FeedbackInterface::iteration
	long unsigned int overall_iterations;		//epoch-global sum of iterations

	unsigned int selected_behaviour;		//behaviour-counter

	unsigned int epoch_now;				//epoch-counter
	//unsigned int generate_epochs;			//for subclass OfflineTrainer

	double error;					//for MSE

	//for scatter-plotting of 3D-functions:
	/*ofstream fileout;
	ofstream fileout_target;*/

	//call only once
	void init_behaviour(RNNPB_Behaviour* behaviour)
	{
		for(unsigned int p=0;p<(*network->concept_layers).size();p++)
		{
			behaviour->concept_vector.push_back(RNNPB_Vector((*network->concept_layers)[p]->size()));
			behaviour->concept_vector_update.push_back(RNNPB_Vector((*network->concept_layers)[p]->size()));

			//Initialize PB-vector:
			for(unsigned int i=0;i<(*network->concept_layers)[p]->size();i++)
			{
				#ifdef ENABLE_DEBUG
				if((*network->concept_layers)[p]->neuron[i]->outgoing_weights.size()>1) //there should be only one weight
				{
					cout<<"Error: Only one weight per parametric neuron allowed!\n";
					//__asm{int 3};
				}
				#endif

				switch((*network->concept_layers)[p]->neuron[i]->outgoing_weights[0]->weightType)
				{
				case ParametricBiasCV:
					behaviour->concept_vector[p][i] = DEFAULT_INIT_BIAS_CV;
					break;
				case SecondOrderCV:
					behaviour->concept_vector[p][i] = DEFAULT_INIT_SECONDORDER_CV;
					break;
				case ModulatedCV:
					behaviour->concept_vector[p][i] = DEFAULT_INIT_MODULATION_CV;
					break;
				default:
					cout<<"Weight type inappropriate for parametric layer!\n";
					break;
				}

				behaviour->concept_vector_update[p][i]=0.0;
			}

			cout<<"Concept Vector: ";
			for(unsigned int i=0;i<(*network->concept_layers)[p]->size();i++)
				cout<<behaviour->concept_vector[p][i]<<" ";
			cout<<endl;
		}
	}

	void learn(unsigned int behaviour)
	{
		//adapt weights:
		network->learnWeights.func();

		for(unsigned int p=0;p<(*network->concept_layers).size();p++)
		{
			behaviours[behaviour]->concept_vector[p] = (*network->concept_layers)[p]->swapWeights();

			//get weight update for momentum
			for(unsigned int i=0;i<behaviours[selected_behaviour]->concept_vector_update[p].size;i++)
				behaviours[behaviour]->concept_vector_update[p][i] = (*network->concept_layers)[p]->neuron[i]->outgoing_weights[0]->weightUpdate;

			#ifdef ENABLE_EXPERIMENT
				//write out CV and momentum after learning
				stringstream exp;
				exp << "CV-phase"<<application->get_phase()<<"-b" << behaviour << "-p" << p;
				writer.write(exp, &(behaviours[behaviour]->concept_vector[p]));
				exp.str("");
				exp << "CV-MOMENTUM-phase"<<application->get_phase()<<"-b" << behaviour << "-p" << p;
				writer.write(exp, &(behaviours[behaviour]->concept_vector_update[p]));
			#endif
		}		
	}

	/*
	 * For offline-trainer only
	 * */
	virtual unsigned int getMaxEpoch(unsigned int startb, unsigned int stopb)
	{
		return 0;
	}

	/*
	 * For offline-trainer only
	 * */
	virtual unsigned int getMaxEpochForBehaviour(unsigned int behaviour)
	{
		return 1;
	}

	/*
	 * For offline-trainer only
	 * */
	virtual bool abortEpochConditionMet()
	{
		return false;
	}


	#ifdef ENABLE_RUNTIME
		double start_time, end_time;
		long int last_iterations;
	#endif

	void run(bool testing, unsigned int nr_of_runs, unsigned int printOutIterations, int behavior = -1)
	{
		iteration=-1;

		#ifdef ENABLE_RUNTIME
			//struct timespec timetmp;
			#ifdef WIN32
			struct timeval timetmp;
			#else
			struct timespec timetmp;			
			#endif
			last_iterations=-1;
		#endif

		for(unsigned int i=0;i<nr_of_runs && !(application->abortLearningConditionMet());i++)
		{
			if(i%printOutIterations==0)
			{
				cout<<"\nRuns: "<<i<<" ; Samples: "<<overall_iterations/1000<<" k"<<endl;

				#ifdef ENABLE_RUNTIME
				if(!testing)
				{						
					if(last_iterations>=0)
					{		
						clock_gettime(CLOCK_MONOTONIC, &timetmp);
						end_time=timetmp.tv_sec;
						#ifdef WIN32						
						end_time += (double) timetmp.tv_usec / 1000000.0;
						#else
						end_time += (double) timetmp.tv_nsec / 1000000000.0;						
						#endif

						cout << (double)(end_time-start_time) << " seconds passed.\n";
						cout << "Samples per second: " << ((double)(overall_iterations-last_iterations))/((double)(end_time-start_time)) << "; Connections per second: " << ((double)(overall_iterations-last_iterations)*((double)network->weightSize()))/(1000000.0*(double)(end_time-start_time)) << " million" << endl;
					}
					last_iterations=overall_iterations;

					clock_gettime(CLOCK_MONOTONIC, &timetmp);
					start_time=timetmp.tv_sec;
					#ifdef WIN32						
					start_time += (double) timetmp.tv_usec / 1000000.0;
					#else
					start_time += (double) timetmp.tv_nsec / 1000000000.0;
					#endif
				}
				#endif
			}

			unsigned int startb, stopb;
			if(behavior!=-1)
			{
				startb = behavior;
				stopb = behavior+1;
			}
			else
			{
				startb = 0;
				stopb = behaviours.size();
			}

			RNNPB_Vector mean_MSE(stopb-startb);
			RNNPB_Vector bepochs(stopb-startb);
			mean_MSE.clear();
			bepochs.clear();

			unsigned int max_epoch = getMaxEpoch(startb,stopb);

			if(testing==true)
				max_epoch=0;

			bool epoch_failed=false;

			for(unsigned int t=0;t<=max_epoch;t++)
			{
				for(unsigned int b=startb;b<stopb;b++)		//all behaviours
				{
					if(getMaxEpochForBehaviour(b)<t)
						continue;

					epoch_now=t;

					application->switch_feedback_method_to(b);
					setSelectedBehaviour(b, !testing && !epoch_failed);	//this also learns the weights...

					epoch_failed=false;

					application->initEpoch();

					while(!application->abortEpochConditionMet() && (!abortEpochConditionMet() || testing))
					{
						iterate(testing);
					}

					application->stopEpoch();

					epoch_failed=!(application->epochGoalReached());

					double MSE=getMSErrorSinceBehaviourStart();
					mean_MSE[b-startb]+=MSE;

					bepochs[b-startb]++;

					if(i%printOutIterations==0)
					{
						cout<<"Mean Squared Error for behaviour "<<b<<" (before training): "<<MSE<<endl;
					}

					/*
					 * Disabled.
					#ifdef ENABLE_EXPERIMENT
						//MSE nach jeder epoche ausschreiben
						stringstream exp;
						exp << "MSE-phase" << application->get_phase() << "-b" << b << "-t" << t;
						writer.write(exp, MSE);
					#endif
					*/

					if(epoch_failed)
						b--;	//repeat behaviour/epoch...
				}
			}

			#ifndef ENABLE_EXPERIMENT
			if(i%printOutIterations==0)
			#endif
			{
				if(i%printOutIterations==0)
					cout<<"----------------------------------------------------------------"<<endl;
				double meanmean_MSE=0;
				for(unsigned int b=startb;b<stopb;b++)
				{
					if(i%printOutIterations==0)
						cout<<"Mean MSE for behaviour "<<b<<" (before training): "<<mean_MSE[b-startb]/((double)(bepochs[b-startb]))<<endl;
					meanmean_MSE+=mean_MSE[b-startb]/((double)(bepochs[b-startb]));

					#ifdef ENABLE_EXPERIMENT
						//mean MSE for all epochs of one behaviour
						stringstream exp;
						exp << "MMSE-phase" << application->get_phase() << "-b" << b;
						writer.write<double>(exp, mean_MSE[b-startb]/((double)(bepochs[b-startb])));
					#endif
				}

				if(i%printOutIterations==0)
					cout<<"Mean MSE for all behaviours (before training): "<<meanmean_MSE/(stopb-startb)<<endl;

				#ifdef ENABLE_EXPERIMENT
					//mean MSE for all epochs of all behaviours
					stringstream exp;
					exp << "MMSE-phase" << application->get_phase() << "-bAll";
					writer.write<double>(exp, meanmean_MSE/(stopb-startb));
				#endif
			}
		}

		//learn last epoch
		setSelectedBehaviour(selected_behaviour, !testing);
	}

	/*
	 * Learn and switch behavior
	 * */
	void setSelectedBehaviour(unsigned int set_behaviour, bool learn_now = true)
	{
		//learn last behavior on switch
		if(iteration>=0)
		{
			if(learn_now)
				learn(selected_behaviour);
			/*else if(network->input->size()==1 && network->output->size()==1)
			{
				char filename[80];

				sprintf(filename,"network-scatter-plot-%d.txt",set_behaviour);
				fileout.close();
				fileout.open(filename);

				cout<<"Writing file '"<<filename<<"' for scatter-plotting"<<endl;

				sprintf(filename,"network-scatter-plot-target-%d.txt",set_behaviour);
				fileout_target.close();
				fileout_target.open(filename);

				cout<<"Writing file '"<<filename<<"' for scatter-plotting"<<endl;
			}*/

			error=0.0;
			iteration=-1;

			network->clear.func();
		}		

		selected_behaviour=set_behaviour;

		//switch concept-vectors
		for(unsigned int p=0;p<(*network->concept_layers).size();p++)
		{
			(*network->concept_layers)[p]->setInput(behaviours[selected_behaviour]->concept_vector[p]);

			/*cout<<"set pb input ";
			behaviours[selected_behaviour]->concept_vector[p].print();
			cout<<endl;*/

			//set weight update for momentum
			for(unsigned int i=0;i<behaviours[selected_behaviour]->concept_vector_update[p].size;i++)
				(*network->concept_layers)[p]->neuron[i]->outgoing_weights[0]->weightUpdate=behaviours[selected_behaviour]->concept_vector_update[p][i];

			/*cout<<"SET CONCEPT UPDATE: ";
			behaviours[selected_behaviour]->concept_vector_update[p].print();*/
		}
	}

	/*
	 * Returns feedback
	 * */
	virtual RNNPB_Vector feedNNandApp()
	{
		RNNPB_Vector sensor(network->input->size());
		application->get_nn_input(&sensor);
		network->input->setInput(sensor);

		RNNPB_Vector action = network->output->getActivation();
		application->action(action);

		RNNPB_Vector target(network->output->size());
		application->get_nn_target(&target);
		RNNPB_Vector feedback=target-action;
		network->output->setFeedback(feedback);

		#ifdef ENABLE_EXPERIMENT
			if(application->get_phase()>1)	//Write out every feedback! This uses a lot of disk space! (disable?)
			{
			//write out feedback for every sample:
			stringstream exp;
			exp << "Feedback-phase" << application->get_phase() << "-b" << selected_behaviour << "-t" << epoch_now;
			writer.write(exp, &feedback);
			}
		#endif

		return feedback;
	}

	void iterate(bool testing=false)
	{
		network->iterate.func();

		application->iterate();

		iteration++;

		overall_iterations++;


		RNNPB_Vector feedback(network->output->size());
		if(testing)
		{
			feedback = RNNPB_CVTrainer::feedNNandApp();
		}
		else
		{
			feedback = feedNNandApp();

			network->cumulateWeightUpdates.func();
		}

		/*else if(network->input->size()==1 && network->output->size()==1)
		{
			fileout<<network->input->getActivation()[0]<<", "<<network->output->getActivation()[0]<<endl;
			fileout_target<<sensor[0]<<", "<<target[0]<<endl;
		}*/	//TODO: Enable plotting using MATLAB-WRITER!

		for(unsigned int i=0;i<network->output->size();i++)
			error+=pow(feedback[i],2.0);
	}

public:

	RNNPB_Behaviour* get_behaviour(unsigned int b)
	{
		if(b<behaviours.size())
			return behaviours[b];
		else return NULL;
	}

	void addBehaviour()
	{
		behaviours.push_back(new RNNPB_Behaviour());

		init_behaviour(behaviours.back());
	}

	unsigned int getSelectedBehaviour()
	{
		return selected_behaviour;
	}

	unsigned int getEpochNr()
	{
		return epoch_now;
	}

	double getMSErrorSinceBehaviourStart()
	{
		return error/double(iteration+1);
	}

	/*
	 * behavior = -1 means all behaviors
	 * */
	void testNetwork(unsigned int runs_per_behavior, int behavior = -1, unsigned int print_runs = 10)
	{
		cout<<"\nRNNPB: Trainer: Testing RNNPB Network ";

		if(behavior != -1)
		{
			cout<<"for behavior "<<behavior<<"...\n";
		}
		else
		{
			cout<<"for all behaviors...\n";
		}

		if(application->phase < 2)	//set to default testing phase if not set otherwise
			application->phase=2;

		run(true, runs_per_behavior, print_runs, behavior);
	}

	void trainNetwork(unsigned int nr_of_runs, int behavior=-1, unsigned int print_runs = 10)
	{
		cout<<"RNNPB: Trainer: Learning RNNPB Network...\n";

		if(application->phase < 2)	//set to default testing phase if not set otherwise
			application->phase=1;

		run(false, nr_of_runs, print_runs, behavior);
	}

	RNNPB_CVTrainer
		(
				RNNPB_NetworkContainer* set_network,
				RNNPB_ApplicationInterface* set_application,
				unsigned int set_nr_behaviours
		)
	{
		application = set_application;

		iteration=-1;

		overall_iterations=0;

		error = 0.0;

		selected_behaviour=0;

		epoch_now=1;

		network = set_network;

		/*
		 * Parse PB-Layers
		 * */
		for(unsigned int i=0;i<(*network->concept_layers).size();i++)
		{
			(*network->concept_layers)[i]->setConstant(true);

			//set the weights that will be read out for concept vectors to appropriate learning rates/momentums
			/*for(unsigned int j=0;j<(*network->concept_layers)[i]->size();j++)
			{
				for(unsigned int k=0;k<(*network->concept_layers)[i]->neuron[j].outgoing_weights.size();k++)
				{
					(*network->concept_layers)[i]->neuron[j].outgoing_weights[k]->momentum=DEFAULT_PARAMETRIC_MOMENTUM;
					(*network->concept_layers)[i]->neuron[j].outgoing_weights[k]->learningRate=DEFAULT_PARAMETRIC_LEARNING_RATE;
				}
			}*/
		}

		/*
		 * Init concept-vectors
		 * */
		for(unsigned int i=0;i<set_nr_behaviours;i++)
		{
			behaviours.push_back(new RNNPB_Behaviour());

			init_behaviour(behaviours[i]);
		}

		/*
		 * Print Network
		 * */
		//network->printNetwork();
	}
};
