/***************************************************************************                  
/* Jose V. Manjon - jmanjon@fis.upv.es                                     */
/* Universidad Politecnica de Valencia, Spain                              */
/* Pierrick Coupe - pierrick.coupe@gmail.com                               */
/* Brain Imaging Center, Montreal Neurological Institute.                  */
/* Mc Gill University                                                      */
/*                                                                         */
/* Copyright (C) 2010 Jose V. Manjon and Pierrick Coupe                    */
/*                                                                         */
/***************************************************************************
* MRI Superresolution Using Self Similarity and Image Priors               *
* Jose V. Manjon, Pierrick Coupe, Antonio Buades, D. Louis Collins         * 
* and Montserrat Robles                                                    * 
***************************************************************************/

#include "math.h"
#include "mex.h"
#include <stdlib.h>
#include "matrix.h"
// undef needed for LCC compiler
#undef EXTERN_C
#include <windows.h>
#include <process.h>  

/*
 * Per-thread work description passed to ThreadFunc through _beginthreadex.
 * All threads share the same image buffers; each one processes only the
 * slices [ini, fin) of the volume.
 */
typedef struct{
    int rows;            /* size of dimension 2 (dims[1]) */
    int cols;            /* size of dimension 1 (dims[0]), fastest-varying */
    int slices;          /* size of dimension 3 (dims[2]) */
    double * in_image;   /* image being filtered (read only) */
    double * out_image;  /* weighted-sum accumulator for the filtered image */
    double * ref_image;  /* reference image used for the intensity term (read only) */
    double * pesos;      /* "pesos" = weights: per-voxel sum of weights */
    int ini;             /* first slice handled by this thread */
    int fin;             /* one past the last slice handled by this thread */
    int radio;           /* "radio" = search radius in voxels */
    /* Both scales are double: they are computed from level^2, and storing
     * them in an int truncated e.g. 0.25 to 0, causing a division by zero. */
    double sigmaI;       /* intensity scale (level^2 in mexFunction) */
    double sigmaS;       /* patch-distance scale (256*level^2 in mexFunction) */
}myargument;

/*
 * Map a possibly out-of-range index n onto [0, size).
 * Negative indices are mirrored to -n; indices >= size become 2*size-n-1.
 * If the result is still negative (only possible when the patch is wider
 * than the volume) it is clamped to 0, so callers never index outside the
 * array.
 */
static int reflect_index(int n, int size)
{
    if (n < 0) n = -n;
    if (n >= size) n = 2 * size - n - 1;
    if (n < 0) n = 0;
    return n;
}

/*
 * Weighted mean squared difference between two (2f+1)^3 patches of ima.
 *
 * ima        - volume of size sx*sy*sz, indexed as ima[z*(sx*sy) + y*sx + x]
 * x, y, z    - centre of the first patch
 * nx, ny, nz - centre of the second patch
 * f          - patch half-width (ThreadFunc in this file passes 1)
 * sx, sy, sz - volume dimensions
 *
 * Patch coordinates outside the volume are reflected back in by
 * reflect_index. The centre voxel pair has weight 2 and every other pair has
 * weight 1. The weighted sum of squared differences is divided by the sum of
 * the weights. Swapping the two centres gives exactly the same value.
 */
static double distance(const double *ima, int x, int y, int z, int nx, int ny, int nz,
                       int f, int sx, int sy, int sz)
{
    const int plane = sx * sy;
    double sum = 0.0;
    double weight_sum = 0.0;
    int i, j, k;

    for (k = -f; k <= f; k++)
    {
        const int off1_k = reflect_index(z + k, sz) * plane;
        const int off2_k = reflect_index(nz + k, sz) * plane;

        for (i = -f; i <= f; i++)
        {
            const int off1_i = reflect_index(x + i, sx);
            const int off2_i = reflect_index(nx + i, sx);

            for (j = -f; j <= f; j++)
            {
                const int off1 = off1_k + reflect_index(y + j, sy) * sx + off1_i;
                const int off2 = off2_k + reflect_index(ny + j, sy) * sx + off2_i;
                /* The centre voxel pair counts twice. */
                const double inc = (i == 0 && j == 0 && k == 0) ? 2.0 : 1.0;
                const double diff = ima[off1] - ima[off2];

                sum += inc * (diff * diff);
                weight_sum += inc;
            }
        }
    }

    return sum / weight_sum;
}


/*
 * Worker thread started with _beginthreadex. pArguments points to a
 * myargument that describes the whole volume and the slab of slices
 * [ini, fin) owned by this thread.
 *
 * For each voxel p in the slab, consider each neighbour q inside the
 * (2*radio+1)^3 window around p, with q != p and q inside the volume:
 *   d  = (ref_image[p] - ref_image[q])^2      intensity term
 *   dg = distance(in_image, p, q), f = 1      3x3x3 patch term
 * q is skipped when d > 3*sigmaI. Otherwise, with
 *   w  = exp(-(d/sigmaI + dg/sigmaS)),
 * the thread adds w*in_image[q] to out_image[p] and w to pesos[p].
 * The caller initialises out_image/pesos beforehand and divides out_image by
 * pesos afterwards.
 *
 * Each voxel gathers contributions from its full window, and only
 * out_image[p]/pesos[p] inside this thread's own slab are written.
 * Because w(p,q) == w(q,p) exactly, the sums are the same as scattering each
 * pair's contribution into both p and q.
 * Scattering into q, however, could cross into a neighbouring thread's slab
 * and race with it; gathering cannot.
 * The price is that each pair's weight is evaluated twice.
 */
unsigned __stdcall ThreadFunc( void* pArguments )
{
    const myargument *arg = (const myargument *) pArguments;
    double *ima = arg->in_image;     /* read only */
    double *fima = arg->out_image;
    double *ref = arg->ref_image;    /* read only */
    double *pesos = arg->pesos;
    const int rows = arg->rows;
    const int cols = arg->cols;
    const int slices = arg->slices;
    const int ini = arg->ini;
    const int fin = arg->fin;
    const int v = arg->radio;
    const double sigI = arg->sigmaI;
    const double sigS = arg->sigmaS;
    int i, j, k, ii, jj, kk, ni, nj, nk, p, p1;
    /* dg must be double: distance() returns a fractional value, and storing
     * it in an int silently truncated it. */
    double d, dg, w;

    for (k = ini; k < fin; k++)
    {
        for (j = 0; j < rows; j++)
        {
            for (i = 0; i < cols; i++)
            {
                p = (k * rows + j) * cols + i;

                for (kk = -v; kk <= v; kk++)
                {
                    nk = k + kk;
                    if (nk < 0 || nk >= slices) continue;

                    for (ii = -v; ii <= v; ii++)
                    {
                        ni = i + ii;
                        if (ni < 0 || ni >= cols) continue;

                        for (jj = -v; jj <= v; jj++)
                        {
                            nj = j + jj;
                            if (nj < 0 || nj >= rows) continue;
                            if (ii == 0 && jj == 0 && kk == 0) continue;  /* skip p itself */

                            p1 = (nk * rows + nj) * cols + ni;

                            d = ref[p] - ref[p1];
                            d = d * d;
                            /* Cheap intensity test first: skip the patch
                             * distance for clearly dissimilar voxels. */
                            if (d > 3 * sigI) continue;

                            dg = distance(ima, i, j, k, ni, nj, nk, 1, cols, rows, slices);
                            w = exp(-(d / sigI + dg / sigS));

                            fima[p] += w * ima[p1];
                            pesos[p] += w;
                        }
                    }
                }
            }
        }
    }

    _endthreadex( 0 );
    return 0;
}


void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{

/*Declarations*/
mxArray *xData,*xtmp,*xpesos;
double *ima, *fima,*ref,*tmp,*pesos,*bipesos;
double average,totalweight,off,media,bw;
mxArray *pv;
double dg,h,w,wmax,d,var,t1,t1i,t2,hh,min,SNR,SNR2,level2,level;
int ini,fin,i,j,k,ii,jj,kk,ni,nj,nk,v,f,ndim,indice,Ndims,init,r,un,p,p1,Nthreads;
const int  *dims ;
bool salir;

myargument *ThreadArgs;  
HANDLE *ThreadList; // Handles to the worker threads     

/*Copy input pointer x*/
xData = prhs[0];

/*Get matrix x*/
ref = mxGetPr(xData);

ndim = mxGetNumberOfDimensions(prhs[0]);
dims= mxGetDimensions(prhs[0]);

/* image*/
xData = prhs[1];
ima = mxGetPr(xData);

pv = prhs[2];
/*Get the Integer*/
v = (int)(mxGetScalar(pv));

/*Copy input parameters*/
pv = prhs[3];
/*Get the Integer*/
f = (int)(mxGetScalar(pv));

pv = prhs[4];
level = (double)(mxGetScalar(pv));
level2=level*level;

pv = prhs[5];
/*Get the Integer*/
Nthreads = (int)(mxGetScalar(pv));

/*Allocate memory and assign output pointer*/

plhs[0] = mxCreateNumericArray(ndim,dims,mxDOUBLE_CLASS, mxREAL);
fima = mxGetPr(plhs[0]);

/* low resolution version */
xtmp = mxCreateNumericArray(ndim,dims,mxDOUBLE_CLASS, mxREAL);
tmp = mxGetPr(xtmp);
xpesos = mxCreateNumericArray(ndim,dims,mxDOUBLE_CLASS, mxREAL);
pesos = mxGetPr(xpesos);

for(i=0;i<dims[0]*dims[1]*dims[2];i++)
{
    pesos[i]=1.0;
    fima[i]=ima[i];
}

for(k=0;k<dims[2]/f;k++)
for(j=0;j<dims[1];j++)
for(i=0;i<dims[0];i++)
{
    media=0;
    for (ii=0;ii<f;ii++) media+=ima[(k*f+ii)*(dims[0]*dims[1])+(j*dims[0])+i];     
    tmp[k*(dims[0]*dims[1])+(j*dims[0])+i]=media/f;     
}

// Reserve room for handles of threads in ThreadList
ThreadList = (HANDLE*)malloc(Nthreads* sizeof( HANDLE ));
ThreadArgs = (myargument*) malloc( Nthreads*sizeof(myargument));
    

for (i=0; i<Nthreads; i++)
    {       
	// Make Thread Structure
    ini=(i*dims[2])/Nthreads;
    fin=((i+1)*dims[2])/Nthreads;
    
    ThreadArgs[i].rows=dims[1];
	ThreadArgs[i].cols=dims[0];
   	ThreadArgs[i].slices=dims[2];
    ThreadArgs[i].in_image=ima;
	ThreadArgs[i].out_image=fima;
  	ThreadArgs[i].ref_image=ref;
    ThreadArgs[i].pesos=pesos;
    ThreadArgs[i].ini=ini;
    ThreadArgs[i].fin=fin;
    ThreadArgs[i].radio=v;    
    ThreadArgs[i].sigmaI=level2;    
    ThreadArgs[i].sigmaS=256*level2;       
	
    ThreadList[i] = (HANDLE)_beginthreadex( NULL, 0, &ThreadFunc, &ThreadArgs[i] , 0, NULL );
  }
    
  for (i=0; i<Nthreads; i++) { WaitForSingleObject(ThreadList[i], INFINITE); }
  for (i=0; i<Nthreads; i++) { CloseHandle( ThreadList[i] ); }
    
  free(ThreadArgs); 
  free(ThreadList);


  
for(i=0;i<dims[0]*dims[1]*dims[2];i++)
{
    fima[i]/=pesos[i];
}

// apply constrains

for(k=0;k<dims[2];k=k+f)
for(j=0;j<dims[1];j++)
for(i=0;i<dims[0];i++)
{
    salir=false;
    media=0;
    for (ii=0;ii<f;ii++) media+=fima[(k+ii)*dims[0]*dims[1]+j*dims[0]+i];     
    media=media/f;       
    off=tmp[(k/f)*dims[0]*dims[1]+j*dims[0]+i]-media;
    for (ii=0;ii<f;ii++)
    {
        fima[(k+ii)*dims[0]*dims[1]+j*dims[0]+i]+=off;
        if(fima[(k+ii)*dims[0]*dims[1]+j*dims[0]+i]<0) salir=true;          
    }
    if(salir)
    {
        for (ii=0;ii<f;ii++) fima[(k+ii)*dims[0]*dims[1]+j*dims[0]+i]=ima[(k+ii)*dims[0]*dims[1]+j*dims[0]+i];    
    }
}

return;
}

