/* RCS revision identifier; not referenced elsewhere in this file. */
static char rcsid[] = "$Id: curvematch.c,v 1.5 1997/07/18 03:02:36 dhb Exp $";

/*
** $Log: curvematch.c,v $
** Revision 1.5  1997/07/18 03:02:36  dhb
** Fix for getopt problem; getopt(), optopt and optind are now
** G_getopt(), G_optopt and G_optind.
**
** Revision 1.4  1994/08/08 22:15:55  dhb
** Changes from Upi.
**
** Revision 1.4  1994/06/13  22:45:08  bhalla
** Replaced the missing options for using either a table or a file
**
 * Revision 1.2  1994/06/06  15:58:35  bhalla
 * Changed to use lftab (load_file_or_table) rather than just fopen
 *
 * Revision 1.1  1994/05/31  19:23:45  bhalla
 * Initial revision
 *
** Revision 1.3  1993/02/24  19:15:52  dhb
** Fixed initialization of default values.
**
** Revision 1.2  1993/02/24  18:11:12  dhb
** 1.4 to 2.0 command argument changes.
**
** Revision 1.1  1992/12/11  19:05:51  dhb
** Initial revision
**
*/


/*
** curvematch.c : estimates how closely a simulated spike train matches
** a reference spike train.
** By Upinder S. Bhalla, 1992, Caltech
*/

#include <stdio.h>
#include <math.h>
#include "header.h"

/*
** Forward declarations. These are deliberately old-style (no parameter
** lists) because the definitions below are K&R-style with float
** parameters: callers pass float arguments promoted to double. Do not
** add prototypes with float parameters here unless the definitions are
** converted as well. find_max(), find_min() and linterp() are not
** declared here and are called before their definitions.
** altspline() is only defined when OLD is defined.
*/
int find_spike();
float curvematch();
float	est_max();
float	est_min();
float	do_stats();
void	altspline();


/* find_spike() stops searching once its start index exceeds MAXPTS - 3. */
#define MAXPTS 5000
/* Capacity of each per-spike bookkeeping array allocated in curvematch(). */
#define MAXSPKS 500
/* Minimum number of samples, beyond a spike's steep part, that must
** separate two simulated spikes for the interval to be compared. */
#define MIN_SHAPE_SAMPLES 8


/*
** do_curvematch : command entry point for comparing two waveforms.
**
** Usage: ref-file-or-table sim-file-or-table [-amplitude] [-time] [-shape]
**        [-pk-to-pk] [-weight s t a p] [-plot] [-verbose]
**
** By default all four match terms (shape, time, amplitude, peak-to-peak)
** have weight 1. -shape, -amplitude, -time and -pk-to-pk each keep only
** that term (weight 1, others 0). -weight sets the four weights in the
** order s t a p. -plot and -verbose are passed on to curvematch().
**
** Both waveforms are loaded with load_file_or_table(); the arrays it
** returns are freed here. The comparison uses automatic peak estimation
** and windows of 0.03 for both peaks and troughs.
**
** Returns the curvematch() score. Returns 0 on a usage error or if either
** waveform fails to load.
** NOTE(review): 0 is also the score for a perfect match, so callers cannot
** tell an error from a perfect fit -- consider a distinct error value.
*/
float do_curvematch(argc,argv)
	int		argc;
	char	**argv;
{
	float	match;
	float	*sim,*simt,*ref,*reft;
	int 	nsim,nref;
	float	ws,wt,wa,wp;
	int	plot_flag=0,verbose_flag=0;
	int	status;
	float	Atof();

	/* Every match term counts equally unless an option says otherwise. */
	ws=wt=wa=wp=1.0;

	initopt(argc, argv, "ref-file-or-table sim-file-or-table -amplitude -time -shape -pk-to-pk -weight s t a p -plot -verbose");
	while ((status = G_getopt(argc, argv)) == 1)
	  {
	    if (strcmp(G_optopt, "-shape") == 0)
	      {
		ws=1.0;wt=wa=wp=0.0;
	      }
	    else if (strcmp(G_optopt, "-amplitude") == 0)
	      {
		wa=1.0;wt=ws=wp=0.0;
	      }
	    else if (strcmp(G_optopt, "-time") == 0)
	      {
		wt=1.0;wa=ws=wp=0.0;
	      }
	    else if (strcmp(G_optopt, "-pk-to-pk") == 0)
	      {
		wp=1.0;wt=wa=ws=0.0;
	      }
	    else if (strcmp(G_optopt, "-weight") == 0)
	      {
		    /* Option arguments follow the order given in the usage: s t a p */
		    ws=Atof(optargv[1]);
		    wt=Atof(optargv[2]);
		    wa=Atof(optargv[3]);
		    wp=Atof(optargv[4]);
	      }
	    else if (strcmp(G_optopt, "-plot") == 0)
		    plot_flag=1;
	    else if (strcmp(G_optopt, "-verbose") == 0)
		    verbose_flag=1;
	  }

	if (status < 0) {
		printoptusage(argc, argv);
		return(0);
	}

	if ((nref = (load_file_or_table(optargv[1],&reft, &ref))) < 1)
		return(0);
	if ((nsim = (load_file_or_table(optargv[2],&simt, &sim))) < 1) {
		/* The reference waveform was loaded; release it before bailing out. */
		free(ref);
		free(reft);
		return(0);
	}

	match=curvematch(ref,reft,nref,sim,simt,nsim,
		1,0.0,0.0,0.03,0.03,ws,wt,wa,wp,plot_flag,verbose_flag);

	free(sim);
	free(simt);
	free(ref);
	free(reft);
	return(match);
}

/*
** curvematch : score how well a simulated spike train matches a reference.
**
** Spikes are located in both traces. A spike is a local peak within
** maxwindow of the estimated maximum. For each spike, the samples from the
** peak to the first local minimum after it are treated as too steep to
** compare. The segment between consecutive spikes (excluding that many
** samples after the previous peak and before the next one) is compared.
** Each simulated segment is rescaled in time and amplitude onto the
** matching reference segment. The accumulated differences are combined
** by do_stats().
**
** ref, reft, nref     : reference values, sample times and sample count
** sim, simt, nsim     : simulated values, sample times and sample count
** autoest             : nonzero -> estimate peak/trough levels from each
**                       trace; zero -> use user_maxest / user_minest
** maxwindow,minwindow : half-widths of the windows around the estimated
**                       peak/trough level in which extrema are accepted
** ws,wt,wa,wp         : weights of the shape, time, amplitude and
**                       peak-to-peak terms (see do_stats)
** plot_flag           : nonzero -> write the compared points to the files
**                       "rout" and "sout" in the current directory
** verbose_flag        : passed to do_stats() to print the component terms
**
** Returns the weighted score (0 is a perfect match). Returns 10 when the
** traces give no usable comparison or memory cannot be allocated.
** At most MAXSPKS - 2 spikes per trace are considered.
*/
float curvematch(ref,reft,nref,sim,simt,nsim,
	autoest,user_maxest, user_minest, maxwindow,minwindow,ws,wt,wa,wp,
	plot_flag,verbose_flag)
	float	*ref;
	float	*reft;
	int		nref;
	float	*sim;
	float	*simt;
	int		nsim;
	int		autoest;
	float	user_maxest,user_minest,maxwindow,minwindow;
	float	ws,wt,wa,wp; /* weights for the 4 match terms */
	int		plot_flag,verbose_flag;
{
	float	match = 10.0;	/* "no usable comparison" until do_stats runs */
	int		i,j;
	int		nrefspk,nsimspk;
	int		spikeindex;
	float	maxest,minest;
	float	refmax_mean;
	float	refmax_err;
	float	refmin_mean;
	float	refmin_err;
	float	simmax_mean;
	float	simmax_err;
	float	simmin_mean;
	float	simmin_err;
	float	max,min,refmin;
	float	refspkwidth,simspkwidth;
	int		hi,lo;
	float	x,y,dy;
	float	sumsqdy=0.0,sumdy=0.0;
	int		*refspk,*simspk,*badspk;
	int		start,end;
	float	ret;
	float	simlo,reflo;
	int		*simsteepsamples,*refsteepsamples;
	float	temp;
	float	sumsqscale=0.0,sumscale=0.0;
	int		npts=0;
	int		nbadspk=0;
	float	*refptp,*simptp;
	FILE	*rout=NULL,*sout=NULL;

	refspk=(int *)calloc(MAXSPKS,sizeof(int));
	simspk=(int *)calloc(MAXSPKS,sizeof(int));
	badspk=(int *)calloc(MAXSPKS,sizeof(int));
	refptp=(float *)calloc(MAXSPKS,sizeof(float));
	simptp=(float *)calloc(MAXSPKS,sizeof(float));
	refsteepsamples=(int *)calloc(MAXSPKS,sizeof(int));
	simsteepsamples=(int *)calloc(MAXSPKS,sizeof(int));
	if (!refspk || !simspk || !badspk || !refptp || !simptp ||
		!refsteepsamples || !simsteepsamples) {
		if (IsSilent() < 2)
			printf("Error in Curvematch: out of memory\n");
		goto cleanup;
	}

	if (plot_flag) {
		rout=fopen("rout","w");
		sout=fopen("sout","w");
		if (rout == NULL || sout == NULL) {
			/* Carry on without plotting rather than write through NULL. */
			if (IsSilent() < 2)
				printf("Error in Curvematch: cannot open plot files\n");
			if (rout) fclose(rout);
			if (sout) fclose(sout);
			rout = sout = NULL;
			plot_flag = 0;
		}
	}

	/* Setting the max and min values for the spike peaks */
	if (autoest) {
		maxest = est_max(ref,nref);
		minest = est_min(ref,nref);
	} else {
		maxest = user_maxest;
		minest = user_minest;
	}
	/* Finding the stats for means for reference spikes */
	find_max(maxest,maxwindow,ref,nref,&refmax_mean,&refmax_err);
	find_min(minest,minwindow,ref,nref,&refmin_mean,&refmin_err);

	/*
	** Scan the reference spikes. Entry 0 is a sentinel at sample 0. Each
	** later entry holds the peak index, the number of steep samples (peak
	** to the next local minimum) and the peak-to-trough height. The count
	** is capped so the end-of-trace sentinel still fits in the arrays.
	*/
	refsteepsamples[0]=refspk[0]=spikeindex=0;
	i = 1;
	while (i < MAXSPKS - 1 &&
		(spikeindex = find_spike(ref,spikeindex+3,nref,maxest,maxwindow))) {
		/* Find the first local min after the spike */
		for (j=spikeindex+1;j<(nref-2);j++){
			min=ref[j];
			if (min<ref[j-1] && min<=ref[j+1] && min<=ref[j+2]) break;
		}
		if (j>=(nref-2))
			break;	/* no trough before the end of the trace */
		refsteepsamples[i]=j-spikeindex;
		refspk[i]=spikeindex;
		refptp[i]=ref[spikeindex]-min;
		i++;
	}
	refspk[i]=nref-1;
	nrefspk = i+1;

	/* Do it all over again for the sim spikes */
	if (autoest) {
		maxest = est_max(sim,nsim);
		minest = est_min(sim,nsim);
	}
	find_max(maxest,maxwindow,sim,nsim,&simmax_mean,&simmax_err);
	find_min(minest,minwindow,sim,nsim,&simmin_mean,&simmin_err);

	/*
	** NOTE(review): the sim scan restarts at spikeindex+2 while the ref
	** scan uses spikeindex+3; kept as-is to preserve existing scores.
	*/
	simsteepsamples[0]=simspk[0]=spikeindex=0;
	i = 1;
	while (i < nrefspk && i < MAXSPKS - 1 &&
		(spikeindex = find_spike(sim,spikeindex+2,nsim,maxest,maxwindow))) {
		/* Find the first local min after the spike */
		for (j=spikeindex+1;j<(nsim-2);j++){
			min=sim[j];
			if (min<sim[j-1] && min<=sim[j+1] && min<=sim[j+2]) break;
		}
		if (j>=(nsim-2))
			break;	/* no trough before the end of the trace */
		simsteepsamples[i]=j-spikeindex;
		simspk[i]=spikeindex;
		simptp[i]=sim[spikeindex]-min;
		i++;
	}
	simspk[i]=nsim-1;
	nsimspk = i+1;

	/*
	** Intervals too short to leave MIN_SHAPE_SAMPLES samples for the shape
	** comparison are marked bad and penalised in the timing sums.
	*/
	for(i=1;i<nsimspk;i++) {
		if (simspk[i]-simspk[i-1]<simsteepsamples[i]+MIN_SHAPE_SAMPLES){
			badspk[i-1]=1;
			nbadspk++;
			sumsqscale+=100.0/(float)(nsimspk*nsimspk);
			sumscale+=10.0/(float)nsimspk;
		}
	}

	/* Take the smaller number of spikes */
	if (nrefspk > nsimspk)
		nrefspk = nsimspk;

	for (i = 1 ; i < nrefspk;i++) {
		if (badspk[i-1]) continue;
		/* Locate the interspike region in both traces */
		start = refspk[i-1]+refsteepsamples[i-1];
		end = refspk[i]-refsteepsamples[i];
		lo = simspk[i-1]+simsteepsamples[i-1];
		hi = simspk[i]-simsteepsamples[i];
		/* Avoid any errors due to missing points */
		if (start>=end || lo>=hi || lo<0 || hi<0 || start<0 || end<0) {
			if (IsSilent() < 2)
				printf (
					"Error in Curvematch: spike does not repolarize\n");
			if (i>1) { /* we have had at least one spike go OK */
				/* Factor in the error due to the truncated trace */
				sumscale+=10.0* (float)(nsim-simspk[i])/(float)(nsim);
				/* Count up the spikes that were missed out */
				nbadspk+=nrefspk-i;
				break;
			} else { /* This one gave us no useful info at all */
				match = 10.0;
				goto cleanup;
			}
		}
		/* find scale factor for interspike region */
		reflo = reft[start];
		simlo = simt[lo];
		refspkwidth = reft[end]-reflo;
		simspkwidth = simt[hi]-simlo;
		if (simspkwidth>0.0 && refspkwidth>0.0) {
			temp=refspkwidth/simspkwidth;
			/* r + 1/r - 2 is zero for equal widths, positive otherwise */
			temp+=1.0/temp - 2.0;
			sumsqscale+=temp*temp;
			sumscale+=temp;
		}
		/* The time mapping below divides by simspkwidth. */
		if (!(simspkwidth > 0.0))
			continue;

		max = min=ref[start];
		for (j=start;j<=end;j++) {
			if (min > ref[j]) {
				min = ref[j];
			} else if (max < ref[j]) {
				max = ref[j];
			}
		}
		refmin = min;
		temp = max - min;
		max = min=sim[lo];
		for (j=lo;j<=hi;j++) {
			if (min > sim[j]) {
				min = sim[j];
			} else if (max < sim[j]) {
				max = sim[j];
			}
		}
		/* temp becomes the ref-range / sim-range amplitude scale */
		if (max<=min)
			temp = 1.0;
		else
			temp /= max - min;
		for (j=lo;j<=hi;j++) {
			/* we scale x and y for this pt on the sim plot to
			** the equivalent on the ref plot
			*/
			x = reflo + refspkwidth * (simt[j]-simlo)/simspkwidth;
			if (x>reft[end]) {
				if (IsSilent() < 2)
				printf("Error : interpolating outside ref waveform\n");
				break;
			}
			y = refmin + temp * (sim[j]-min);

			if (linterp(reft+start,ref+start,end-start,x,&ret)) {
				dy = y - ret;
				sumsqdy += dy * dy;
				sumdy += dy;
				npts++;
				if (plot_flag) {
					fprintf(rout,"%f\t%f\n",x,ret);
					fprintf(sout,"%f\t%f\n",x,y);
				}
			}
		}
	}
	if (nbadspk > nrefspk/2) {
		if (IsSilent() < 2)
			printf("Warning : Too many bad spikes = %d out of %d\n",
				nbadspk,nrefspk);
		if (nbadspk>=nrefspk)
			goto cleanup;	/* match is still 10 */
	}
	match = do_stats(sumsqdy,sumdy,sumsqscale,sumscale,nrefspk-nbadspk,npts,refmax_mean,refmax_err,
		refmin_mean,refmin_err,simmax_mean,simmax_err,
		simmin_mean,simmin_err,refptp,simptp,
		ws,wt,wa,wp,verbose_flag);

cleanup:
	/* Single exit so every allocation and plot file is released. */
	free(refspk);
	free(simspk);
	free(badspk);
	free(refptp);
	free(simptp);
	free(refsteepsamples);
	free(simsteepsamples);
	if (rout)
		fclose(rout);
	if (sout)
		fclose(sout);
	return(match);
}

/*
** find_spike : index of the first spike peak at or after start+2, or 0.
**
** A peak is a sample strictly inside (maxest - maxwindow, maxest + maxwindow)
** that is higher than both of the two samples before it and no lower than
** either of the two samples after it. Only indices below npts-2 are
** examined, and nothing is searched once start exceeds MAXPTS - 3. For
** start >= 0 a peak is never reported below index 2, so 0 doubles as the
** "not found" result.
*/
int find_spike(arr,start,npts,maxest,maxwindow)
	float	*arr;
	int		start;
	int		npts;
	float	maxest;
	float	maxwindow;
{
	float	upper = maxest + maxwindow;
	float	lower = maxest - maxwindow;
	float	v;
	int		k;

	if (start >= MAXPTS - 2)
		return(0);

	for (k = start + 2; k < npts - 2; k++) {
		v = arr[k];
		if (!(v > lower && v < upper))
			continue;	/* outside the acceptance window */
		if (v > arr[k-2] && v > arr[k-1] && v >= arr[k+1] && v >= arr[k+2])
			return(k);
	}
	return(0);
}

find_max(est,window,arr,npts,mean,err)
	float est;
	float window;
	float *arr;
	int		npts;
	float *mean;
	float *err;
{
	int i;
	float hi,lo;
	float y;
	float sum=0.0,sumsq=0.0;
	float nsum=0.0;

	hi=est+window;
	lo=est-window;

	for (i=2;i<(npts - 2);i++) {
		y = arr[i];
		if (y > lo && y < hi) {
			if (y>arr[i-2] && y>arr[i-1] && y>arr[i+1] && y>arr[i+2]){
				sum+=y;
				nsum+=1.0;
				sumsq+=y*y;
			}
		}
	}

	*mean = 0.0;
	*err = 0.0;

	if (nsum > 0.5) {
		sumsq=sumsq-sum*sum/nsum;
		if (sumsq>0)
			*err=sqrt(sumsq)/nsum;
		*mean=sum/nsum;
	} else {
		*mean = est;
		*err=0.0;
	}
}

find_min(est,window,arr,npts,mean,err)
	float est;
	float window;
	float *arr;
	int		npts;
	float *mean;
	float *err;
{
	int i;
	float hi,lo;
	float y;
	float sum=0.0,sumsq=0.0;
	float nsum=0.0;

	hi=est+window;
	lo=est-window;

	for (i=2;i<(npts - 2);i++) {
		y = arr[i];
		if (y > lo && y < hi) {
			if (y<arr[i-2] && y<arr[i-1] && y<arr[i+1] && y<arr[i+2]){
				sum+=y;
				nsum+=1.0;
				sumsq+=y*y;
			}
		}
	}

	if (nsum > 0.5) {
		sumsq=sumsq-sum*sum/nsum;
		if (sumsq>0)
			*err=sqrt(sumsq)/nsum;
		*mean=sum/nsum;
	} else {
		*mean = est;
		*err=0.0;
	}
}

/*
** est_max : largest value in arr[0..npts-1]; 0.0 when npts <= 0.
*/
float est_max(arr,npts)
	float *arr;
	int		npts;
{
	float	best;
	float	*p, *stop;

	if (npts < 1)
		return(0.0);

	best = *arr;
	for (p = arr + 1, stop = arr + npts; p < stop; ++p)
		best = (*p > best) ? *p : best;
	return(best);
}

/*
** est_min : smallest value in arr[0..npts-1]; 0.0 when npts <= 0.
*/
float est_min(arr,npts)
	float *arr;
	int		npts;
{
	float	lowest;
	int		k = 0;

	if (!(npts > 0))
		return(0.0);

	lowest = arr[k];
	while (++k < npts) {
		if (arr[k] < lowest)
			lowest = arr[k];
	}
	return(lowest);
}


/*
** The idea is that each of the match parameters should have a
** range of 0 to 1 (perfect to no match) 
** The parms are : shape, durations, ampl
*/
float do_stats(sumsqdy,sumdy,sumsqscale,sumscale,nrefspk,npts,refmax_mean,refmax_err,
		refmin_mean,refmin_err,simmax_mean,simmax_err,
		simmin_mean,simmin_err,refptp,simptp,ws,wt,wa,wp,verbose_flag)
	float sumsqdy;
	float sumdy;
	float sumsqscale;
	float sumscale;
	int nrefspk;
	int	npts;
	float refmax_mean;
	float refmax_err;
	float refmin_mean;
	float refmin_err;
	float simmax_mean;
	float simmax_err;
	float simmin_mean;
	float simmin_err;
	float *refptp,*simptp;
	float ws,wt,wa,wp;
	int		verbose_flag;
{
	float shape_match;
	float time_match;
	float ampl_match;
	float ptp_match;
	float temp;
	int i;

	/*
	** Rationale : sumscale keeps track of a quantity which is zero
	** when there is a perfect match, and grows approx linearly with
	** scale mismatch. So we just need to divide by the number of
	** spikes to get the average diff from perfect match. Now we
	** need to condense into the range 0 to 1 (1 is no match). This
	** is simply 1-exp(-x).
	*/
	if (nrefspk >= 1) {
		sumscale/=(float)nrefspk;
		time_match=1.0-exp(-sumscale);
	} else
	time_match=0.0;

	/*
	if (nrefspk > 1) {
		sumsqscale-=sumscale*sumscale/(float)(nrefspk-1);
		if (sumsqscale > 0)
			sumsqscale=sqrt(sumsqscale);
		else
			sumsqscale=0;
	} else { 
		sumsqscale=0;
	}
	sumscale+=sumsqscale;
	sumscale/=(float)(nrefspk);
	time_match=1.0-2.0/(sumscale+1.0/sumscale);
	*/

	ampl_match =
		/* Diff of maxes */
		((refmax_mean-simmax_mean)*(refmax_mean-simmax_mean)+
		/* Diff of mins */
		(refmin_mean-simmin_mean)*(refmin_mean-simmin_mean)+
		/* Diff of amplitudes */
		2.0 * (refmax_mean-refmin_mean-(simmax_mean-simmin_mean)) * 
		(refmax_mean-refmin_mean-(simmax_mean-simmin_mean))) /
		/* amplitudes squared, summed, times 2 */
		(2.0 * ((refmax_mean-refmin_mean) * (refmax_mean-refmin_mean) +
		(simmax_mean-simmin_mean) * (simmax_mean-simmin_mean)) *
		/* The number of spikes being summed over */
		(float)nrefspk);
	/* To make it more linear */
	ampl_match=sqrt(ampl_match);

	ptp_match=0.0;
	if (wp>0.0) { /* normal operation */
		for(i=0;i<nrefspk;i++){
			temp=refptp[i]-simptp[i];
			ptp_match+=temp*temp;
		}
	} else { /* weight only if the ptp is too small */
		for(i=0;i<nrefspk;i++){
			temp=refptp[i]-simptp[i];
			if (temp > 0.0)
				ptp_match+=temp*temp;
		}
	}

	if (ptp_match>0 && nrefspk>= 1) {
		ptp_match=sqrt(ptp_match)/(float)nrefspk;
	}
	if (npts > 1) {
	/* Rationale : 0.1 is about peak-to-peak. The sqr err for each pt
	** should always be smaller. Since the values tend to be a LOT
	** smaller, we take an additional square root */
		if (sumsqdy>0.0) {
			sumsqdy=sqrt(sumsqdy);
			shape_match=sqrt(sumsqdy/(0.1*(float)npts));
		} else {
			shape_match = 0.0;
		}
		/*
		sumsqdy-=sumdy*sumdy/(float)(npts);
		if (sumsqdy > 0.0)
			shape_match=sumsqdy/sqrt((float)(npts));
		else
			shape_match=0.0;
		*/
	} else 
		shape_match=1.0;

	if (verbose_flag)
		printf("shape_match=%f time_match=%f ampl_match=%f ptp_match=%f\n",
			shape_match,time_match,ampl_match,ptp_match);
	return (wt*time_match+wa*ampl_match+ws*shape_match+
		ptp_match*fabs(wp));
}



#ifdef OLD

/*
** altspline : compute second derivatives y2[0..n-1] of a cubic spline
** through the n points (x[i], y[i]). Compiled only when OLD is defined.
** The y2 values are presumably consumed by altsplint() (its y2a argument).
**
** yp1, ypn : first derivative at x[0] / x[n-1]; a value > 0.99e30 instead
**            sets the second derivative at that end to zero.
** y2       : output, n second derivatives.
** Assumes n >= 2 and strictly increasing x (the code divides by x
** differences) -- not checked. The calloc() result is not checked either.
*/
void altspline(x,y,n,yp1,ypn,y2)
float *x,*y,yp1,ypn,*y2;
int n;
{
	int i,k;
	float p,qn,sig,un,*u;

	/* scratch array for the forward sweep */
	u = (float *)calloc(n,sizeof(float));

	/* boundary condition at x[0] */
	if (yp1 > 0.99e30)
		y2[0]=u[0]=0.0;
	else {
		y2[0] = -0.5;
		u[0]=(3.0/(x[1]-x[0]))*((y[1]-y[0])/(x[1]-x[0])-yp1);
	}
	/* forward sweep over the interior points */
	for (i=1;i<n-1;i++) {
		sig=(x[i]-x[i-1])/(x[i+1]-x[i-1]);
		p=sig*y2[i-1]+2.0;
		y2[i]=(sig-1.0)/p;
		u[i]=(y[i+1]-y[i])/(x[i+1]-x[i]) - (y[i]-y[i-1])/(x[i]-x[i-1]);
		u[i]=(6.0*u[i]/(x[i+1]-x[i-1])-sig*u[i-1])/p;
	}
	/* boundary condition at x[n-1] */
	if (ypn > 0.99e30)
		qn=un=0.0;
	else {
		qn=0.5;
		un=(3.0/(x[n-1]-x[n-2]))*(ypn-(y[n-1]-y[n-2])/(x[n-1]-x[n-2]));
	}
	y2[n-1]=(un-qn*u[n-2])/(qn*y2[n-2]+1.0);
	/* back-substitution */
	for (k=n-2;k>=0;k--)
		y2[k]=y2[k]*y2[k+1]+u[k];
	free(u);
}



/*
** altsplint is a version of splint: evaluate at x the cubic spline defined
** by the points xa[0..n-1], ya[0..n-1] and second derivatives y2a[0..n-1],
** storing the result in *y. Compiled only when OLD is defined.
** Returns 1 on success. Returns 0 after printing a message if the
** bracketing interval has zero width.
** Assumes xa is increasing -- TODO confirm with callers.
*/
int altsplint(xa,ya,y2a,n,x,y)
float *xa,*ya,*y2a,x,*y;
int n;
{
	int klo,khi,k;
	float h,b,a;
	float y2lo,y2hi;

	/* bisect for the table interval [klo, khi] that brackets x */
	klo=0;
	khi=n-1;
	while (khi-klo > 1) {
		k=(khi+klo) >> 1;
		if (xa[k] > x) khi=k;
		else klo=k;
	}
	h=xa[khi]-xa[klo];
	if (h == 0.0) {
		printf("phooo! : Bad XA input to altsplint in curvematch\n");
		return(0);
	}
	/* linear weights plus the cubic correction from the second derivatives */
	a=(xa[khi]-x)/h;
	b=(x-xa[klo])/h;
	y2lo=y2a[klo];
	y2hi=y2a[khi];
	*y=a*ya[klo]+b*ya[khi]+((a*a*a-a)*y2lo+(b*b*b-b)*y2hi)*(h*h)/6.0;
	return(1);
}
#endif

/*
** linterp : linear interpolation.
**
** Locates by bisection the pair of table points bracketing x among
** xa[0..n] -- note that n+1 points are used, so xa[n] and ya[n] must be
** valid -- and stores the interpolated value in *y. Outside the table
** the first or last segment is extrapolated. Returns 1 on success.
** Prints a message and returns 0 (leaving *y untouched) when the
** bracketing points share the same x.
** Assumes xa is non-decreasing -- TODO confirm with callers.
*/
int linterp(xa,ya,n,x,y)
float *xa,*ya,x,*y;
int n;
{
	int	below = 0;	/* index of the lower bracketing point */
	int	above = n;	/* index of the upper bracketing point */
	float	dx;

	while (above - below > 1) {
		int	probe = (above + below) / 2;

		if (x < xa[probe])
			above = probe;
		else
			below = probe;
	}

	dx = xa[above] - xa[below];
	if (dx == 0.0) {
		printf("phooo! : Bad XA input to linterp in curvematch\n");
		return(0);
	}
	*y = ya[below] + (ya[above] - ya[below]) * (x - xa[below]) / dx;
	return(1);
}

/* Drop the file-local limits defined near the top of this file. */
#undef MAXPTS
#undef MAXSPKS
#undef MIN_SHAPE_SAMPLES
