/* Programm, das ein Perzeptron trainiert und mit dem Tower-Algo erweitert */
/* Achtung: weil das dann noch mehr Flexibilitaet schenkt, gibt's immer */
/* Verbindungen von allen (!!!) Vorgaengerneuronen. In der */
/* Vorlesung ja immer nur vom letzten. */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <malloc.h>
#include <time.h>

#define MAX 1000 	/* maximum number of training cycles per perceptron() call */
#define MAXW 100 	/* limit on the number of perceptrons (neurons) in the tower */

int num_pat;    /* number of patterns read from the pattern file */
int DIM;	/* input dimension as read from the pattern file */
int DIMW;	/* current input dimension of the perceptron being trained;
                   grows by one with every perceptron added, because each
                   new perceptron also receives the outputs of all
                   previous ones */
int W;		/* number of perceptrons currently allocated in w */
float** pattern;	/* pattern[i]: DIMW inputs, then the 0/1 target at index DIMW */
float** w; 	/* w[k]: weights of perceptron k; its threshold is the last entry */

/**********************************************************************/
/* Allocate the pattern table and read it from the file `name`.       */
/*                                                                    */
/* File format (whitespace separated):                                */
/*   num_pat DIM                                                      */
/*   then num_pat groups of DIM input values and one target value;    */
/*   the target must be exactly 0 or 1.                               */
/* Sets num_pat, DIM, pattern and DIMW (= DIM). Returns 0; every      */
/* error (unreadable file, bad header, short/garbled data, invalid    */
/* target, out of memory) terminates the program with exit(1).        */
/**********************************************************************/

int get_pattern(char* name)
{
	FILE* fp;
	int i,j;
	if ((fp=fopen(name,"r"))==NULL)
	{ fprintf(stderr,"cannot open file\n"); exit(1); }
	/* fscanf returns EOF (-1) at premature end of input, which '!'
	   would not catch -- compare against the expected conversions */
	if (fscanf(fp,"%d",&num_pat)!=1)
	{ fprintf(stderr,"wrong pattern number\n"); exit(1); }
	if (num_pat<1)
	{ fprintf(stderr,"wrong pattern number\n"); exit(1); }
	if (fscanf(fp,"%d",&DIM)!=1)
	{ fprintf(stderr,"wrong input dimension\n"); exit(1); }
	if (DIM<1)
	{ fprintf(stderr,"wrong input dimension\n"); exit(1); }
	if ((pattern=calloc(num_pat,sizeof *pattern))==NULL)
	{ fprintf(stderr,"not enough space\n"); exit(1); }
	for (i=0;i<num_pat;i++)
		if ((pattern[i]=calloc(DIM+1,sizeof *pattern[i]))==NULL)
		{ fprintf(stderr,"not enough space\n"); exit(1); }
	for (i=0;i<num_pat;i++)
	{
		/* DIM inputs followed by the target value at index DIM */
		for (j=0;j<DIM+1;j++)
		{
			if (fscanf(fp,"%f",&(pattern[i][j]))!=1)
			{ fprintf(stderr,"wrong pattern\n"); exit(1); }
		}
		/* targets must be exactly 0 or 1 */
		if ((pattern[i][DIM]!=0.0)&&(pattern[i][DIM]!=1.0))
		{ fprintf(stderr,"wrong pattern\n"); exit(1); }
	}
	fclose(fp);
	/* the first perceptron sees only the raw inputs */
	DIMW=DIM;
	return(0);
}

/**********************************************************************/
/* Target value ("soll") of pattern p: the entry stored directly      */
/* after the DIMW current inputs.                                     */

float soll(float* p)
{
	float target = p[DIMW];
	return target;
}

/**********************************************************************/
/* Actual output ("ist") of the perceptron currently being trained,   */
/* w[W-1], for pattern p: 1.0 if the weighted sum of the first DIMW   */
/* inputs minus the threshold w[W-1][DIMW] is >= 0, otherwise 0.0.    */

float ist(float* p)
{
	float* cur = w[W-1];
	float net = 0.0;
	int k = 0;
	while (k < DIMW)
	{
		net += cur[k]*p[k];
		k++;
	}
	net -= cur[DIMW];
	return (net >= 0) ? 1.0 : 0.0;
}

/**********************************************************************/
/* Create the weight table holding a single perceptron (W = 1) with   */
/* DIM input weights plus a threshold, all set to zero. Returns 0;    */
/* out of memory terminates the program with exit(1).                 */

int init()
{
	int k;
	w=calloc(1,sizeof *w);
	if (w==NULL)
	{ fprintf(stderr,"not enough space\n"); exit(1); }
	W=1;
	w[0]=calloc(DIM+1,sizeof *w[0]);
	if (w[0]==NULL)
	{ fprintf(stderr,"not enough space\n"); exit(1); }
	for (k=DIM;k>=0;k--)
		w[0][k]=0.0;
	return(0);
}

/**********************************************************************/
/* Add a new perceptron to the tower and widen all patterns by one.   */
/*                                                                    */
/* The new perceptron w[W] gets DIMW+2 zero weights: one per existing */
/* input, one for the output of the perceptron just trained (w[W-1]), */
/* and the threshold. Every pattern grows to DIMW+2 entries: the      */
/* target moves from index DIMW to DIMW+1, and index DIMW receives    */
/* the output ist() of the perceptron just trained. Afterwards W and  */
/* DIMW are incremented. Returns 1; out of memory terminates the      */
/* program with exit(1).                                              */
/**********************************************************************/

int verl_w()
{
	int i;
	float** wnew;
	float* pnew;
	/* keep the old pointer intact until realloc is known to succeed */
	if ((wnew=realloc(w,(W+1)*sizeof *w))==NULL)
	{ fprintf(stderr,"not enough space\n"); exit(1); }
	w=wnew;
	if ((w[W]=calloc(DIMW+2,sizeof *w[W]))==NULL)
	{ fprintf(stderr,"not enough space\n"); exit(1); }
	/* explicit zeroing: calloc's all-bits-zero is 0.0f on IEEE-754,
	   but the C standard does not guarantee it for floats */
	for (i=0;i<DIMW+2;i++)
		w[W][i]=0.0;
	for (i=0;i<num_pat;i++)
	{
		if ((pnew=realloc(pattern[i],(DIMW+2)*sizeof *pnew))==NULL)
		{ fprintf(stderr,"not enough space\n"); exit(1); }
		pattern[i]=pnew;
		/* move the target one slot to the right ... */
		pattern[i][DIMW+1]=pattern[i][DIMW];
		/* ... and store the trained perceptron's output as a new input.
		   ist() reads only indices 0..DIMW-1 and must still see the old
		   W/DIMW, so this has to happen before they are incremented. */
		pattern[i][DIMW]=ist(pattern[i]);
	}
	W++;
	DIMW++;
#ifdef DEBUG
	{
		int j;
		for (i=0;i<num_pat;i++)
		{
			for (j=0;j<DIMW+1;j++)
				fprintf(stderr,"%f ",pattern[i][j]);
			fprintf(stderr,"\n");
		}
	}
#endif
	return(1);
}

/**********************************************************************/
/* Release every pattern row, the pattern table, the weight vectors   */
/* of the W perceptrons and the weight table. Returns 0.              */

int free_all()
{
	int k=num_pat;
	while (k-- > 0)
		free(pattern[k]);
	free(pattern);
	k=W;
	while (k-- > 0)
		free(w[k]);
	free(w);
	return(0);
}

/**********************************************************************/
/* Perceptron learning step (called by perceptron() for a             */
/* misclassified pattern p). If p's target is 1, add the inputs to    */
/* the current weights w[W-1]; if it is 0, subtract them. The         */
/* threshold w[W-1][DIMW] moves in the opposite direction.            */

void dw(float* p)
{
	float* cur = w[W-1];
	float sign;
	int k;
	if (p[DIMW]==0.0)
		sign=-1.0;
	else
		sign=1.0;
	for (k=0;k<DIMW;k++)
		cur[k]+=sign*p[k];
	cur[DIMW]-=sign;
}

/**********************************************************************/
/* Number of CORRECTLY classified patterns for the perceptron being   */
/* trained, i.e. patterns whose target soll() equals the output       */
/* ist(). Callers compare the result with num_pat to detect a         */
/* perfect classification.                                            */

int guete()
{
	int i;
	int anz=0;
	for (i=0;i<num_pat;i++)
	{
		if (soll(pattern[i])==ist(pattern[i]))
			anz++;
	}
	return (anz);
}

/**********************************************************************/
/* Train the current perceptron w[W-1] for at most MAX cycles.        */
/* Each cycle draws a random pattern and applies the learning rule    */
/* dw() if that pattern is misclassified. Training stops early once   */
/* all num_pat patterns are classified correctly.                     */
/* Deliberately plain random-pattern training: on average it worked   */
/* much better than the pocket algorithm, especially on "spirals".    */
/* Returns the number of cycles performed.                            */
/**********************************************************************/

int perceptron()
{
	int i,zyk=0,	/* no. cycles */
	    anz=0;	/* no. correctly classified patterns */
	float val1,val2;
	while ((zyk<MAX)&&(anz<num_pat))
	{
		zyk++;
		i=random()%num_pat;		/* draw a pattern */
		val1=soll(pattern[i]);
		val2=ist(pattern[i]);
		if (val1!=val2)			/* misclassified */
		{
			dw(pattern[i]);
		}
		/* guete() depends only on the weights, which only dw() changes:
		   recompute after an update, and once in the first cycle where
		   anz has not been computed yet */
		if ((val1!=val2)||(zyk==1))
			anz=guete();
#ifdef DEBUG3
		{
			int k;
			fprintf(stderr,"Zyklen %d, anz %d,\nvektor",zyk,anz);
			for (k=0;k<DIMW+1;k++)
				fprintf(stderr,"%f ",w[W-1][k]);
			fprintf(stderr,"\n");
		}
#endif
	}
	return(zyk);
}

/**********************************************************************/
/* Tower construction: train a perceptron. If it does not classify    */
/* all patterns correctly, add a new perceptron that also sees the    */
/* outputs of all previous ones (verl_w()) and train that one, and so */
/* on. Stops when all patterns are correct or when one more           */
/* perceptron would reach MAXW, so at most MAXW-1 are trained.        */
/* Returns the number of patterns the last perceptron classifies      */
/* correctly; afterwards W is the number of trained perceptrons.      */
/**********************************************************************/

int tower()
{
	int number=0;
	while (W<MAXW)
	{
		perceptron();
		number=guete();
		fprintf(stderr,"%d korrekt\n",number);
#ifdef DEBUG
		{
			int i;
			for (i=0;i<DIMW+1;i++)
				fprintf(stderr,"%f ",w[W-1][i]);
			fprintf(stderr,"\n");
		}
#endif
		/* grow the tower only when another perceptron will actually be
		   trained, so W, DIMW and the pattern width stay consistent
		   with the trained tower */
		if ((number>=num_pat)||(W+1>=MAXW))
			break;
		verl_w();
	}
	return(number);
}

/**********************************************************************/
/* Print the weights of all W perceptrons to stdout. Perceptron k     */
/* (0-based) has DIM+k input weights followed by its threshold.       */
/* Returns 0.                                                         */
/**********************************************************************/

int out()
{
	int n,k,len;
	for (n=0;n<W;n++)
	{
		fprintf(stdout,"Perzeptron Nummer %d\n",n+1);
		len=DIM+n+1;
		for (k=0;k<len;k++)
			fprintf(stdout,"%f ",w[n][k]);
		fputs("\n",stdout);
	}
	return(0);
}

/**********************************************************************/
/* Output of neuron I (1-based, valid for 1 <= I <= W) of the trained */
/* tower for input vector p. p must hold the DIM raw inputs followed  */
/* by the outputs of neurons 1..I-1 (DIM+I-1 values in total).        */
/* Returns 1.0 if the weighted sum minus the threshold is >= 0,       */
/* otherwise 0.0. For an out-of-range I it prints "haeh?" to stderr   */
/* and returns 0.0.                                                   */

float isti(float* p, int I)
{
	int i;
	int n;		/* number of inputs of neuron I */
	float wert=0.0;
	/* I<1 would index w[-1]; I>W is not allocated */
	if ((I<1)||(I>W))
	{
		fprintf(stderr,"haeh?"); return(0.0);
	}
	n=DIM+I-1;
	for (i=0;i<n;i++)
		wert+=w[I-1][i]*p[i];
	wert-=w[I-1][n];
#ifdef DEBUG3
	fprintf(stderr,"%d -- berechne:\n",I);
	for (i=0;i<n;i++)
		fprintf(stderr,"%f * %f +",w[I-1][i],p[i]);
	fprintf(stderr,"- %f\n",w[I-1][n]);
#endif
	if (wert>=0) return (1.0);
	return(0.0);
}

/**********************************************************************/
/* Gnuplot helpers only (total() and show())                          */
/**********************************************************************/
/* Output of the 2-D tower restricted to neurons 1..was at the point  */
/* (x,y): neuron k receives x, y and the outputs of neurons 1..k-1.   */
/* For was >= 1 returns 1 if neuron `was` outputs 1, otherwise 0 (for */
/* was == 0 it just tests y > 0). Assumes DIM == 2, which show()      */
/* checks before calling. Out of memory terminates with exit(1).      */
/**********************************************************************/

int total(int was, float x, float y)
{
	float* act;
	int k;
	int fires;
	act=calloc(was+2,sizeof *act);
	if (act==NULL)
	{ fprintf(stderr,"not enough space\n"); exit(1); }
	act[0]=x;
	act[1]=y;
	for (k=1;k<=was;k++)
		act[k+1]=isti(act,k);
	fires=(act[was+1]>0) ? 1 : 0;
	free(act);
	return(fires);
}

/**********************************************************************/
/* Write every grid point (x,y) for which total(was,x,y) is 1 to the  */
/* file "out", one "x y" pair per line (for gnuplot). Only supported  */
/* for DIM == 2. The plot range is read interactively from stdin      */
/* until a valid range has been entered once; it is then kept for     */
/* later calls. The range is sampled on a (GRID+1) x (GRID+1) grid.   */
/* Returns 0 in all cases; problems are reported with "sorry!".       */
/**********************************************************************/

int show(int was)
{
	enum { GRID=50 };	/* number of steps per axis */
	static float dx,dy;
	static float vonx,vony,bisx,bisy;
	static int FIRST=1;
	float x,y;
	int outxy;
	FILE* fp;
	int i,j;
	if (DIM!=2)
	{ fprintf(stderr,"sorry!\n"); return(0); }
	if (FIRST)
	{
		int ok;
		/* flush each prompt so it is visible before scanf blocks */
		fprintf(stdout,"vonx? ");
		fflush(stdout);
		ok=(scanf("%f",&vonx)==1);
		fprintf(stdout,"\nvony? ");
		fflush(stdout);
		ok=(scanf("%f",&vony)==1)&&ok;
		fprintf(stdout,"\nbisx? ");
		fflush(stdout);
		ok=(scanf("%f",&bisx)==1)&&ok;
		fprintf(stdout,"\nbisy? ");
		fflush(stdout);
		ok=(scanf("%f",&bisy)==1)&&ok;
		fprintf(stdout,"\n");
		/* reject unreadable or empty ranges; FIRST stays set so the next
		   call asks again instead of plotting with invalid bounds */
		if ((!ok)||(vonx>=bisx)||(vony>=bisy))
		{ fprintf(stderr,"sorry!\n"); return(0); }
		dx=(bisx-vonx)/(double)GRID;
		dy=(bisy-vony)/(double)GRID;
		FIRST=0;
	}
	if ((fp=fopen("out","w"))==NULL)
	{ fprintf(stderr,"sorry!\n"); return(0); }
	for (i=0;i<=GRID;i++)
	for (j=0;j<=GRID;j++)
	{
		x=vonx+i*dx;
		y=vony+j*dy;
		outxy=total(was,x,y);
		if (outxy) fprintf(fp,"%f %f\n",x,y);
#ifdef DEBUG
		fprintf(stderr,"%d %d: %f %f -> %d\n",i,j,x,y,outxy);
#endif
	}
	fclose(fp);
	return(0);
}

/**********************************************************************/
/* Gnuplot special Ende */
/**********************************************************************/

/* Entry point. Usage: training <pattern file> [show]
   Reads the patterns and builds the tower. It prints the number of
   correctly classified patterns to stderr and the weights to stdout.
   With any second argument the positive region of the full tower is
   written to the file "out" via show(W). */
int main(int argc,char **argv)
{
	int number;
	if (argc<2)
	{
		fprintf(stderr,"usage: training <file with pattern>[show]\n");
		return(0);
	}
	srandom(time(NULL));
	get_pattern(argv[1]);
	init();
	number=tower();
	fprintf(stderr,"%d richtig erkannt\n",number);
	out();
#ifdef DEBUG
	{
		int i;
		for (i=0;i<W;i++)
		{
			show(i+1);
			/* wait for one character (e.g. Enter) before the next plot */
			getchar();
		}
	}
#endif
	if (argc>2)
		show(W);
	free_all();
	return(0);
}
