#include "kinematics.h"

#ifdef RTAPI

#include "rtapi_math.h"
#include "rtapi.h"      /* RTAPI realtime OS API */
#include "rtapi_app.h"      /* RTAPI realtime module decls */
#include "hal.h"

//#include <stdio.h>
//#include <stdlib.h>
//#include <math.h>

// comp --install millkins.c 


/*
 * Per-instance HAL data for this component.
 * Nothing in this file connects `skew` to a HAL pin (the pin creation in
 * rtapi_app_main() is commented out), so it should not be dereferenced.
 */
struct haldata {
    hal_float_t *skew;   /* intended X/Y skew input; currently unused */
};

/* Allocated with hal_malloc() in rtapi_app_main(). */
struct haldata *haldata;

/* Debug switch; every use in this file is commented out. */
#define DEBUG 0

/* Angle conversion factors. */
double d2r = 0.017453292519943295474371680598;    /* degrees -> radians */
double r2d = 57.295779513082322864647721871734;   /* radians -> degrees */

/* Z arm and the rotating Z disk. */
double zlen = 2100;      /* length of the Z arm */
double zradius = 70;     /* radius at which the Z arm pivots on the disk */
double zang = 10;        /* disk angle, degrees; around 0, +-45 */

/* Arm ("stick") lengths. */
double xlen = 1020;      /* X arm */
double ylen = 1000;      /* Y arm */

/* Rail geometry, measured from the nominal tool location. */
double xdsth = -30;      /* X rail height: 30 mm below nominal tool location */
double ydsth = -30;      /* Y rail height: 30 mm below nominal tool location */
double xdstx = -1067 + 50;   /* X offset to the X rail pivot at travel = 0 */
double xdsty = -50;      /* Y offset of the X rail; keeps the arm from snagging */
double ydsty = 1036 - 40;    /* Y offset to the Y rail pivot at travel = 0 */
double ydstx = 300;      /* X offset of the Y rail; keeps the arm from snagging */
double xdst = 200;       /* X actuator travel, 0..200 mm (or whatever physics allows) */
double ydst = 200;       /* Y actuator travel, 0..200 mm */

/* Centre of the fixed Z pivot, measured from the nominal tool location. */
double xctrz = 70;
double yctrz = 50;
double zctrz = 2100;

/* Last values printed by the disabled debug traces. */
double lastX = 0.0;
double lastY = 0.0;
double lastZ = 0.0;
double lastA = 0.0;
double lastB = 0.0;
double lastC = 0.0;

/* Tool offset: added to computed Cartesian coordinates,
 * subtracted from requested Cartesian coordinates. */
double toolX = 0;
double toolY = -20;
double toolZ = 0;

/* Per-joint offsets between joint values and actuator values. */
double rackX = 0;
double rackY = 0;
double rackZ = 0;

double dst(double a, double b, double c)
 {
   return sqrt( a*a+b*b+c*c );
 }

// turns out the wording is the opposite of what I thought, forward is joint->euclid.
/*
 * Rotate the 2-D vector (*u, *v) in place by `ang` radians:
 *     u' = cos(ang) * u - sin(ang) * v
 *     v' = sin(ang) * u + cos(ang) * v
 * kinematicsForward() applies this to different coordinate pairs to rotate
 * about the x, y or z axis.
 */
static void millkins_fwd_rotate(double *u, double *v, double ang)
{
    const double nu = cos(ang) * *u - sin(ang) * *v;
    const double nv = sin(ang) * *u + cos(ang) * *v;

    *u = nu;
    *v = nv;
}

/*
 * Forward kinematics: joint positions -> Cartesian tool position.
 *
 * joints[0]    X-rail actuator travel, mm (nominally 0..200), minus rackX
 * joints[1]    Y-rail actuator travel, mm (nominally 0..200), minus rackY
 * joints[2]    Z disk angle, degrees (around 0, +-45); rackZ is subtracted
 *              and the result negated, mirroring kinematicsInverse()
 * joints[3..8] copied straight through to pos->a .. pos->w
 *
 * The tool point is where the X arm (xlen, from the X-rail pivot), the Z arm
 * (zlen, from the pivot on the Z disk) and the Y arm (ylen, from the Y-rail
 * pivot) meet:
 *   1. Locate the Z-disk pivot and the X-rail pivot in world coordinates.
 *   2. Cosine rule on the X pivot / Z pivot / tool triangle gives k, the tool's
 *      distance along the X pivot -> Z pivot line, and e, its distance from it.
 *   3. Rotate about the X pivot so the Z pivot lies on +z, carrying the Y-rail
 *      pivot along.
 *   4. The tool now lies on a circle of radius e at height k; a second cosine
 *      rule finds where the Y arm meets that circle.
 *   5. Undo the rotations, return to world coordinates, add toolX/Y/Z.
 *
 * Only locals are modified; no file-scope state is written.
 *
 * Returns 0 on success.  Returns -1, leaving pos->tran untouched, when the
 * joints describe a configuration the arms cannot reach. In that case the
 * sqrt()/acos() arguments leave their domains and the result would be NaN.
 */
int kinematicsForward(const double *joints,
              EmcPose * pos,
              const KINEMATICS_FORWARD_FLAGS * fflags,
              KINEMATICS_INVERSE_FLAGS * iflags)
{
    const double xtravel = joints[0] - rackX;
    const double ytravel = joints[1] - rackY;
    const double zangle = -(joints[2] - rackZ);

    double xh, yh, zh;      /* Z-arm pivot on the Z disk (world) */
    double xw, yw, zw;      /* X-arm pivot on the X rail (world) */
    double tx, ty, tz;      /* Z-arm pivot relative to the X pivot */
    double rxl, ryl, rzl;   /* Y-arm pivot relative to the X pivot */
    double b, k, e;
    double tay, tax, taz;
    double yang, yc, ya, ac;
    double px, py, pz;

    (void)fflags;
    (void)iflags;

    pos->a = joints[3];
    pos->b = joints[4];
    pos->c = joints[5];
    pos->u = joints[6];
    pos->v = joints[7];
    pos->w = joints[8];

    /* 1a. Pivot on the Z disk: (-zradius, 0) in the disk's x-z plane, rotated
     * by the disk angle, then moved to the fixed disk centre. */
    xh = -zradius;
    zh = 0.0;
    millkins_fwd_rotate(&xh, &zh, zangle * d2r);
    xh += xctrz;
    yh = yctrz;
    zh += zctrz;

    /* 1b. Pivot on the X rail. */
    xw = xtravel + xdstx;
    yw = xdsty;
    /* NOTE(review): kinematicsInverse() uses xdsth for the X-rail height.
     * Both constants are currently -30, so the two directions agree. */
    zw = ydsth;

    /* 2. Triangle with base b = |Z pivot - X pivot| and sides xlen, zlen
     * meeting at the tool.  k is the X arm projected onto the base and e is
     * the tool's distance from the base. */
    tx = xh - xw;
    ty = yh - yw;
    tz = zh - zw;
    b = sqrt(tx * tx + ty * ty + tz * tz);
    k = xlen * (b * b + xlen * xlen - zlen * zlen) / (2 * b * xlen);
    e = sqrt(xlen * xlen - k * k);

    /* 3. Rotate about y, then about x, so the Z pivot lands on +z.  The Y-rail
     * pivot (relative to the X pivot) is carried through the same rotations. */
    rxl = ydstx - xw;
    ryl = (ytravel + ydsty) - yw;
    rzl = ydsth - zw;

    tay = -atan2(tx, tz);
    millkins_fwd_rotate(&tz, &tx, tay);
    millkins_fwd_rotate(&rzl, &rxl, tay);

    tax = -atan2(ty, tz);
    millkins_fwd_rotate(&tz, &ty, tax);
    millkins_fwd_rotate(&rzl, &ryl, tax);

    /* NOTE(review): after the two rotations above, tx and ty are zero up to
     * rounding, so taz is taken from near-zero components.  The solution is
     * symmetric about z and this rotation is undone in step 5, so its value
     * should not affect the result.  Kept to match the original arithmetic. */
    taz = -atan2(ty, tx);
    millkins_fwd_rotate(&rxl, &ryl, taz);

    /* 4. In the x-y plane, the tool lies on a circle of radius e about the
     * origin.  It also lies within sqrt(ylen^2 - (k - rzl)^2) of the Y pivot.
     * The cosine rule gives the angle ac between the Y pivot's bearing and the
     * tool.  Of the two intersections, the one at (bearing - ac) is used. */
    yang = atan2(ryl, rxl);
    yc = sqrt(ryl * ryl + rxl * rxl);
    ya = sqrt(ylen * ylen - (k - rzl) * (k - rzl));
    ac = acos((e * e + yc * yc - ya * ya) / (2 * e * yc));
    yang -= ac;

    px = cos(yang) * e;
    py = sin(yang) * e;
    pz = k;

    /* 5. Undo the rotations in reverse order, then return to world frame. */
    millkins_fwd_rotate(&px, &py, -taz);
    millkins_fwd_rotate(&pz, &py, -tax);
    millkins_fwd_rotate(&pz, &px, -tay);
    px += xw;
    py += yw;
    pz += zw;

    px += toolX;
    py += toolY;
    pz += toolZ;

    /* Unreachable joint combinations drive sqrt()/acos() out of their
     * domains; report failure instead of publishing NaN. */
    if (isnan(px) || isnan(py) || isnan(pz)) {
        return -1;
    }

    pos->tran.x = px;
    pos->tran.y = py;
    pos->tran.z = pz;

    return 0;
}

/*
 * Inverse kinematics: Cartesian tool position -> joint positions.
 *
 * Each actuator is solved independently. The starting point is the requested
 * tool point, which is pos->tran minus the tool offset toolX/Y/Z:
 *   joints[0]    X-rail actuator travel, mm, plus rackX
 *   joints[1]    Y-rail actuator travel, mm, plus rackY
 *   joints[2]    Z disk angle, degrees, plus rackZ, then negated to match the
 *                sign convention used by kinematicsForward()
 *   joints[3..8] copied straight through from pos->a .. pos->w
 *
 * Returns 0 on success.  Returns -1, leaving joints[0..2] untouched, when the
 * pose is out of reach. In that case a sqrt()/acos() argument leaves its
 * domain and the result would be NaN.
 */
int kinematicsInverse(const EmcPose * pos,
              double *joints,
              const KINEMATICS_INVERSE_FLAGS * iflags,
              KINEMATICS_FORWARD_FLAGS * fflags)
{
    const double inputX = pos->tran.x - toolX;
    const double inputY = pos->tran.y - toolY;
    const double inputZ = pos->tran.z - toolZ;
    double x, y, z;
    double ang, nx, ny;
    double a, b, ac;
    double outputA, outputB, outputC;

    (void)iflags;
    (void)fflags;

    joints[3] = pos->a;
    joints[4] = pos->b;
    joints[5] = pos->c;
    joints[6] = pos->u;
    joints[7] = pos->v;
    joints[8] = pos->w;

    /* X actuator.  Work relative to the X-rail pivot at zero travel.  Rotate
     * about the rail (x) axis so the tool lies in the z = 0 plane; its distance
     * from the rail is then ny.  The X arm (xlen) is the hypotenuse of a right
     * triangle with legs ny and the along-rail offset, so the pivot is at
     * x - sqrt(xlen^2 - ny^2).  Because x is already measured from the
     * zero-travel position (xdstx), that value is the actuator travel.
     * Only one rotation is used, so this assumes the X and Y rails are
     * perpendicular. */
    x = inputX - xdstx;
    y = inputY - xdsty;
    z = inputZ - xdsth;
    ang = -atan2(z, y);
    ny = cos(ang) * y - sin(ang) * z;
    outputA = x - sqrt(xlen * xlen - ny * ny);

    /* Y actuator: the same construction about the Y rail (rotate about y so
     * the tool lies in z = 0; its distance from the rail is then nx). */
    x = inputX - ydstx;
    y = inputY - ydsty;
    z = inputZ - ydsth;
    ang = -atan2(z, x);
    nx = cos(ang) * x - sin(ang) * z;
    /* NOTE(review): the original code had separate y > 0 and y <= 0 branches
     * with identical bodies.  Only this single solution is produced; confirm
     * whether the other side of the rail needs different handling. */
    outputB = -(fabs(y) - sqrt(ylen * ylen - nx * nx));

    /* Z disk angle.  Work relative to the fixed disk centre; the disk lies in
     * the x-z plane.  In that plane the tool is b from the centre, the arm
     * pivot is zradius from the centre, and the Z arm projects to length
     * a = sqrt(zlen^2 - y^2).  The cosine rule gives the angle ac at the
     * centre.  Adding the tool's bearing and converting to degrees gives the
     * disk angle. */
    x = inputX - xctrz;
    y = inputY - yctrz;
    z = inputZ - zctrz;
    ang = -atan2(z, x);
    b = sqrt(x * x + z * z);
    a = sqrt(zlen * zlen - y * y);
    ac = acos((b * b + zradius * zradius - a * a) / (2 * b * zradius));
    outputC = -((ang + ac) * r2d - 180.0);

    /* Out-of-reach poses drive sqrt()/acos() out of their domains; refuse
     * rather than commanding NaN joint positions. */
    if (isnan(outputA) || isnan(outputB) || isnan(outputC)) {
        return -1;
    }

    joints[0] = outputA + rackX;
    joints[1] = outputB + rackY;
    joints[2] = -(outputC + rackZ);

    return 0;
}

/*
 * Home kinematics.  The joint positions take precedence: both flag words are
 * cleared, and the world pose is derived from the joints via
 * kinematicsForward(), whose status is returned unchanged.
 */
int kinematicsHome(EmcPose * world,
           double *joint,
           KINEMATICS_FORWARD_FLAGS * fflags,
           KINEMATICS_INVERSE_FLAGS * iflags)
{
    int status;

    *iflags = 0;
    *fflags = 0;

    status = kinematicsForward(joint, world, fflags, iflags);
    return status;
}

/*
 * Report the kinematics capability of this module: it implements both
 * kinematicsForward() and kinematicsInverse().
 * Declared with a (void) prototype; empty parentheses are an obsolescent,
 * non-prototype declarator in C.
 */
KINEMATICS_TYPE kinematicsType(void)
{
    return KINEMATICS_BOTH;
}

/* Make the kinematics entry points visible to other modules.
 * (kinematicsHome is not exported here.) */
EXPORT_SYMBOL(kinematicsForward);
EXPORT_SYMBOL(kinematicsInverse);
EXPORT_SYMBOL(kinematicsType);

MODULE_LICENSE("GPL");

/* HAL component id returned by hal_init().  File-local so it cannot collide
 * with the identically named variable found in most other HAL components. */
static int comp_id;

/*
 * Module entry point: register the "millkins" HAL component and allocate its
 * haldata block.
 *
 * Returns 0 on success.  On failure, any component that was already
 * registered is released with hal_exit(), and a negative value is returned so
 * the load is reported as failed.  (Previously a failed hal_malloc() returned
 * the non-negative comp_id, which looked like success.)
 */
int rtapi_app_main(void)
{
    int res;

    comp_id = hal_init("millkins");
    if (comp_id < 0) {
        return comp_id;
    }

    haldata = hal_malloc(sizeof *haldata);
    if (!haldata) {
        res = -1;
        goto fail;
    }

    /* The "millkins.skew" input pin is intentionally not created; the code
     * that would consume it is commented out in kinematicsInverse(). */

    res = hal_ready(comp_id);
    if (res < 0) {
        goto fail;
    }

    return 0;

fail:
    hal_exit(comp_id);
    return res;
}

/* Module unload hook: release the HAL component registered in
 * rtapi_app_main(). */
void rtapi_app_exit(void)
{
    (void) hal_exit(comp_id);
}

#endif
