/**
 * Offline Trainer (Header+Implementation)
 * 
 * Copyright 2013 Fabian Schrodt, FSchrodt@gmx.de
 * 
 * This file is part of RNNPBlib.
 * 
 * RNNPBlib is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License version 3 as published by the Free Software Foundation.
 * 
 * RNNPBlib is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with RNNPBlib. If not, see http://www.gnu.org/licenses/.
 */

#pragma once

#include "RNNPB_Definitions.h"

#include "RNNPB_TrainingData.h"
#include "RNNPB_CVTrainer.h"

class RNNPB_CVOfflineTrainer : public RNNPB_CVTrainer
{
	unsigned int training_iterations;	//(maximum) number of training samples to generate

	virtual unsigned int getMaxEpoch(unsigned int startb, unsigned int stopb)
	{
		//determine behaviour with most epochs
		unsigned int max_epoch=0;
		for(unsigned int b=startb;b<stopb;b++)
		{
			unsigned int tmp=trainingdata->nrOfEpochs(b)-1;
			if(tmp > max_epoch)
				max_epoch=tmp;
		}
		//cout<<"MAXEPO"<<max_epoch<<endl;
		return max_epoch;
	}

	virtual unsigned int getMaxEpochForBehaviour(unsigned int behaviour)
	{
		return trainingdata->nrOfEpochs(behaviour)-1;
	}

public:

	RNNPB_TrainingDataPerBehaviourContainer* trainingdata;

	void addBehaviour(deque <RNNPB_TrainingDataContainer*> add_epochs)
	{
		behaviours.push_back(new RNNPB_Behaviour());

		init_behaviour(behaviours.back());

		for(unsigned int i=0;i<add_epochs.size();i++)
			trainingdata->trainingEpoch(trainingdata->nrOfBehaviours(), i, add_epochs[i]);
	}

	/*
	 * Checks if the epoch is exceeded.
	 * */
	virtual bool abortEpochConditionMet()
	{
		if(iteration+1 >= (int) trainingdata->data[selected_behaviour][epoch_now]->size())
		{
			return true;
		}
		return false;
	}

	virtual RNNPB_Vector feedNNandApp()
	{
		RNNPB_Vector sensor = *(trainingdata->trainingInput(iteration,selected_behaviour,epoch_now));
		network->input->setInput(sensor);
		//sensor.print();

		RNNPB_Vector action = network->output->getActivation();
		application->action(action);	//CHECK

		RNNPB_Vector feedback = *(trainingdata->trainingTarget(iteration,selected_behaviour,epoch_now)) - action;
		network->output->setFeedback(feedback);
		//feedback.print();

		#ifdef ENABLE_EXPERIMENT
			if(application->get_phase()>1)
			{
			//Write out feedback for each sample
			stringstream exp;
			exp << "Feedback-phase" << application->get_phase() << "-b" << selected_behaviour << "-t" << epoch_now;
			writer.write(exp, &feedback);
			}
		#endif

		return feedback;
	}

	/*
	 * (Re)generate Trainingdata
	 * */
	void generateTrainingData(unsigned int behaviour, unsigned int epoch)
	{
		trainingdata->clearTrainingEpoch(behaviour, epoch);

		application->switch_feedback_method_to(behaviour);

		//initialize simulation
		application->initEpoch();

		unsigned int c=0;
		while(!application->abortEpochConditionMet())
		{
			//cout<<"Training: Iteration "<<c<<", Behaviour "<<behaviour<<", Timeseries "<<epoch<<endl;

			application->iterate();

			application->get_nn_input(trainingdata->trainingInput(c,behaviour,epoch));

			//CHECK: application->action(output) with Fake-Output!??!
			application->action(RNNPB_Vector(0));

			application->get_nn_target(trainingdata->trainingTarget(c,behaviour,epoch));

			c++;
		}

		application->stopEpoch();

		//Repeat failed runs when "goal-reached-criteria" not met:
		if(!(application->epochGoalReached()))
			generateTrainingData(behaviour, epoch);	//repeat
	}

	/*
	 * Generates Trainingdata. Same nr of epochs for each behaviour.
	 * */
	void generateAllTrainingData(unsigned int nr_of_epochs)
	{
		cout<<"RNNPB PLUGIN: OfflineTrainer is gathering Training Data...\n";

		if(trainingdata!=NULL)
			delete trainingdata;

		trainingdata=new RNNPB_TrainingDataPerBehaviourContainer(network->input->size(), network->output->size());

		application->phase=0;

		//GENERATE TRAINING DATA WITHOUT USING THE NETWORK

		//nr. of steps to record:
		for(unsigned int t=0;t<nr_of_epochs;t++)
		{
			for(unsigned int b=0;b<behaviours.size();b++)
			{
				epoch_now=t;
				selected_behaviour=b;
				generateTrainingData(b,t);
			}
		}

		iteration=-1;

		//reset application-iterator:
		application->switch_feedback_method_to(0);
	}

	/*
	 * Save Trainingdata to a file.
	 * */
	void saveTrainingData(const char* filename)
	{
		cout<<"RNNPB PLUGIN: Saving Training Data...\n";

		if(trainingdata==NULL)
		{
			cout<<"Error: No trainingdata to save!\n";
			return;
		}

		std::ofstream outfile(filename);

		// save data to archive
		boost::archive::binary_oarchive boost_archive(outfile);
		// write class instance to archive
		boost_archive << *trainingdata;
		// archive and stream closed when destructors are called
	}

	/*
	 * Loads Trainingdata from a file!
	 * */
	void loadTrainingData(const char* filename)
	{
		std::ifstream infile(filename);
		//TODO: Catch errors!

		if(trainingdata!=NULL)
			delete trainingdata;

		trainingdata=new RNNPB_TrainingDataPerBehaviourContainer(network->input->size(),network->output->size());

		//load boost archive...
		boost::archive::binary_iarchive boost_archive(infile);
		boost_archive >> *trainingdata;

		cout<<"RNNPB PLUGIN: Loaded trainingdata from file '"<<filename<<"' with "<<trainingdata->nrOfBehaviours()<<" behaviors.\n";
	}

	RNNPB_CVOfflineTrainer
		(
				RNNPB_NetworkContainer* set_network,
				RNNPB_ApplicationInterface* set_application,
				unsigned int set_training_iterations,
				unsigned int set_nr_behaviours
		 ) :
			 RNNPB_CVTrainer(set_network, set_application, set_nr_behaviours)
	{
		iteration=-1;

		trainingdata=NULL;

		training_iterations=set_training_iterations;

		//reset application-iterator
		application->switch_feedback_method_to(0);
	}
};
