/**********************************************************************
 * rsa.c                                                    August 2005
 *
 * ASYM: An implementation of Asymmetric Cryptography in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file is based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#ifdef __KERNEL__
#include <linux/types.h>
#include <linux/random.h>
#else
#include "compat.h"
#endif


#include "rsa.h"
#include "pk.h"
#include "test_util.h"

#define unsxLen (bitLen / unsx_Bit_Length)
#define RETURN(RET) do { ret = (RET); goto rsa_keygen_ret; }while(1)

int rsa_keygen(rsa_key_t *pub, rsa_key_t *pri, int bitLen, int genOptions) 
{
  int ret = PK_OK;
  int pLen;
  UNSX_NEW(p, unsxLen);
  UNSX_NEW(q, unsxLen);
  UNSX_NEW(thi, unsxLen+4);
  UNSX_NEW(n, unsxLen);
  UNSX_NEW(e, unsxLen);
  UNSX_NEW(d, unsxLen);
  UNSX_NEW(tmp, unsxLen*2);

  if(!p || !q || !thi || !n || !e || !d || !tmp) {
    RETURN(-ENOMEM);
  }

#if 0
  unsx_setLow(p, 15, unsxLen);
  unsx_setLow(q, 26, unsxLen);
  unsx_modInv(tmp, p, q, unsxLen);

  printf("p = "); for (i=0; i<unsxLen; i++) printf("%.8x ", p[i]); printf("\n");
  printf("q = "); for (i=0; i<unsxLen; i++) printf("%.8x ", q[i]); printf("\n");
  printf("tmp = "); for (i=0; i<unsxLen; i++) printf("%.8x ", tmp[i]); printf("\n");
#endif
  unsx_setZero(tmp, unsxLen*2);
  
  pub->len = pri->len = unsxLen;
  pub->n = (unsx*)kmalloc(pub->len * sizeof(unsx), GFP_KERNEL);
  pub->e = (unsx*)kmalloc(pub->len * sizeof(unsx), GFP_KERNEL);
  pub->d = NULL;
  pri->n = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);
  pri->e = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);
  pri->d = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);

  if(!pub->n || !pub->e || !pri->n || !pri->e || !pri->d) {
    RETURN(-ENOMEM);
  }

  unsx_setLow(e, 65537, unsxLen);
  unsx_setZeroDual(p, q, unsxLen);
  unsx_setZeroDual(d, thi, unsxLen);

  /* I don't think we need do do this in the Kernel as
   * get_random_bytes() doesn't need seeding (I think) - NTT COMWARE */
#if 0
  {
    time_t x;
    do_gettimeofday(&x);
    srand(x);
  }
#endif

  get_random_bytes(p, unsxLen*sizeof(unsx)/2);
  get_random_bytes(q, unsxLen*sizeof(unsx)/2);

  ret = unsx_nextPrime(q, unsxLen);
  if(ret < 0)
    RETURN(ret);
  ret = unsx_nextPrime(p, unsxLen);
  if(ret < 0)
    RETURN(ret);

  pLen = unsx_countArr(p,unsxLen);

  ret = unsx_isPrimePRP(q,unsxLen);
  if(ret < 0)
    RETURN(ret);
  if (!ret)
    RETURN(PK_KEYGEN_FAILED);

  ret = unsx_isPrimePRP(p,unsxLen);
  if(ret < 0)
    RETURN(ret);
  if (!ret)
    RETURN(PK_KEYGEN_FAILED);

  ret = unsx_isPrimeSPRP(q,unsxLen);
  if(ret < 0)
    RETURN(ret);
  if (!ret)
    RETURN(PK_KEYGEN_FAILED);

  ret = unsx_isPrimeSPRP(p,unsxLen);
  if(ret < 0)
    RETURN(ret);
  if (!ret)
    RETURN(PK_KEYGEN_FAILED);

  ret = unsx_mult(n, p, q, pLen);              /* n = pq */
  if(ret < 0)
    RETURN(ret);
  unsx_dec(q, q, pLen);                        /* q = q-1 */
  unsx_dec(p, p, pLen);                        /* p = p-1 */
  ret = unsx_mult(thi, p, q, pLen);            /* thi = (p-1)(q-1) */
  if(ret < 0)
    RETURN(ret);
  ret = unsx_modInv(d, e, thi, unsxLen);       /* d * e == 1 (mod thi) */
  if(ret < 0)
    RETURN(ret);
  
  unsx_set(pub->n, n, unsxLen);
  unsx_set(pub->e, e, unsxLen);
  unsx_set(pri->n, n, unsxLen);
  unsx_set(pri->e, e, unsxLen);
  unsx_set(pri->d, d, unsxLen);

rsa_keygen_ret:
  if(p)
    UNSX_FREE(p, unsxLen);
  if(q)
    UNSX_FREE(q, unsxLen);
  if(thi)
    UNSX_FREE(thi, unsxLen);
  if(n)
    UNSX_FREE(n, unsxLen);
  if(e)
    UNSX_FREE(e, unsxLen);
  if(d)
    UNSX_FREE(d, unsxLen);
  if(tmp)
    UNSX_FREE(tmp, unsxLen);
  if(ret) {
    rsa_key_destroy_data(pub);
    rsa_key_destroy_data(pri);
  }

  return ret;
}
#undef unsxLen
#undef RETURN

/* Record an error code and jump to rsa_wrap()'s cleanup path. */
#define RETURN(RET)\
do {\
  ret = (RET);\
  goto rsa_wrap_ret;\
} while(0)

/*
 * rsa_wrap - pad a message PKCS#1 v1.5 style and apply unsx_modPow().
 * @out:      receives bytes = UNSX_LENGTH * sizeof(unsx) bytes of output
 * @outLen:   in: capacity of @out; out: number of bytes written (bytes)
 * @pub:      key supplying the modulus (pub->n, pub->len)
 * @exp:      exponent passed to unsx_modPow(); must not be NULL
 * @in:       message to wrap
 * @inLen:    length of @in in bytes
 * @wrapType: block-type byte. PK_PKCS1_ENCRYPT pads with random non-zero
 *            bytes; any other value pads with 0xff.
 *
 * The block is 0x00 || wrapType || PS || 0x00 || in, exactly `bytes` long.
 * PS must be at least 8 bytes, so inLen may be at most bytes - 11.
 *
 * Returns PK_OK on success. Returns PK_INVALID_INPUT if @out is too small or
 * @in is too long (or has a negative length), PK_INVALID_KEY if @exp is NULL,
 * -ENOMEM on allocation failure, or a negative code from unsx_modPow().
 */
int rsa_wrap(char *out, int *outLen,
                    const rsa_key_t *pub, const unsx *exp,
                    const char *in, int inLen,
                    int wrapType) {
  unsigned char *inBuf = NULL;
  UNSX_NEW(IN, UNSX_LENGTH);
  UNSX_NEW(OUT, UNSX_LENGTH);
  int pad;
  int ret = PK_OK;
  int bytes;

  printk(KERN_DEBUG "rsa_wrap enter\n");

  bytes = UNSX_LENGTH * sizeof(unsx);

  if (!IN || !OUT)
    RETURN(-ENOMEM);
  if (*outLen < bytes)
    RETURN(PK_INVALID_INPUT);
  if (exp == NULL)
    RETURN(PK_INVALID_KEY);

  /* PKCS#1 v1.5 needs at least 8 padding bytes for every block type. This
   * check also guarantees pad >= 0, which keeps the separator write and the
   * memcpy() below inside inBuf. The comparison is done on inLen so that
   * the pad subtraction cannot overflow. */
  if (inLen < 0 || inLen > bytes - 11) {
    printk(KERN_DEBUG "rsa_wrap: boom %d - 3 - %d!\n", bytes, inLen);
    RETURN(PK_INVALID_INPUT);
  }
  pad = bytes - 3 - inLen;

  inBuf = kmalloc(bytes, GFP_KERNEL);
  if (!inBuf)
    RETURN(-ENOMEM);

  inBuf[0] = 0;
  inBuf[1] = wrapType;
  if (wrapType == PK_PKCS1_ENCRYPT) {
    unsigned char *ps = inBuf + 2;
    int i;

    get_random_bytes(ps, pad);
    /* Every padding byte must be non-zero: redraw any zero bytes. */
    for (i = 0; i < pad; i++) {
      while (ps[i] == 0)
        get_random_bytes(&ps[i], 1);
    }
  } else {
    memset(inBuf + 2, 0xff, pad);
  }
  inBuf[pad + 2] = 0;                   /* separator before the payload */
  memcpy(inBuf + pad + 3, in, inLen);

  load_unsx_rev(IN, inBuf, UNSX_LENGTH);

  ret = unsx_modPow(OUT, IN, exp, pub->n, pub->len);
  if (ret < 0)
    RETURN(ret);
  unload_unsx_rev(out, OUT, UNSX_LENGTH);
  *outLen = bytes;
  ret = PK_OK;

rsa_wrap_ret:
  if (IN)
    UNSX_FREE(IN, UNSX_LENGTH);
  if (OUT)
    UNSX_FREE(OUT, UNSX_LENGTH);
  if (inBuf) {
    /* inBuf holds the padded plaintext; clear it before releasing it. */
    memset(inBuf, 0, bytes);
    kfree(inBuf);
  }

  return ret;
}
#undef RETURN

#define RETURN(RET)\
do {\
  ret = RET;\
  goto rsa_unwrap_ret;\
} while(0)
int rsa_unwrap(char *out, int *outLen,
                      const rsa_key_t *pri, const unsx *exp,
                      const char *in, int inLen, int wrapType) 
{
  unsigned char *tmpBuf = NULL;
  unsigned char *p;
  UNSX_NEW(IN, UNSX_LENGTH);
  UNSX_NEW(OUT, UNSX_LENGTH);
  int i;
  int bytes;
  int ret=PK_OK;

  bytes = UNSX_LENGTH * sizeof(unsx);

  if(!IN || !OUT || !exp) {
    printk(KERN_DEBUG "rsa_unwrap: no IN, OUT or exp\n");
    RETURN(-ENOMEM);
  }

  tmpBuf = (char*)kmalloc(bytes, GFP_KERNEL);
  if(!tmpBuf) {
    printk(KERN_DEBUG "rsa_unwrap: kmalloc\n");
    RETURN(-ENOMEM);
  }

  unsx_setZero(IN, UNSX_LENGTH);
  /* XXX: Is bytes - inLen correct ? Should this always be 0? NTT COMWARE */
  load_unsx_rev(IN + bytes - inLen, in, inLen/sizeof(unsx));

#if 0 /* DEBUG */
  asym_print_unsx(KERN_DEBUG "rsa_unwrap: IN", IN, UNSX_LENGTH);
#endif
  ret = unsx_modPow(OUT, IN, exp, pri->n, pri->len);
  if(ret < 0) {
    printk(KERN_DEBUG "rsa_unwrap modulus failed\n");
    RETURN(ret);
  }
  unload_unsx_rev(tmpBuf, OUT, UNSX_LENGTH); 

#if 0 /* DEBUG */
  asym_print_unsx(KERN_DEBUG "rsa_unwrap: OUT", OUT, UNSX_LENGTH);
  asym_print_char(KERN_DEBUG "rsa_unwrap: tmpBuf", tmpBuf, bytes);
#endif

  p = tmpBuf;
  if ((*p++ != 0) || (*p++ != wrapType)) {
    printk(KERN_DEBUG "rsa_unwrap incorrect pkcs header %d\n", wrapType);
    RETURN(PK_UNWRAP_FAILED);
  }

  for (i=2; i<bytes; i++) {
    if (!*p) {
      break;
    }
    if (wrapType == PK_PKCS1_SIGN && *p != 0xff) {
      printk(KERN_DEBUG "rsa_unwrap incorrect pad (sign)\n");
      RETURN(PK_UNWRAP_FAILED);
    }
    p++;
  }
  if(i == bytes) {
    printk(KERN_DEBUG "rsa_unwrap entire message is pad\n");
    RETURN(PK_UNWRAP_FAILED);
  }

  if(*outLen < bytes-1-i) {
          printk(KERN_DEBUG "rsa_unwrap outLen too short for message: "
			  "%d < %d\n", *outLen, bytes-1-i);
	  RETURN(PK_INVALID_INPUT);
  }
  memcpy(out, p+1, bytes-1-i);

  *outLen = bytes-1-i;

  ret = 0;
rsa_unwrap_ret:
  if(IN)
    UNSX_FREE(IN, UNSX_LENGTH);
  if(OUT)
    UNSX_FREE(OUT, UNSX_LENGTH);
  if(tmpBuf) {
    memset(tmpBuf, 0, bytes);
    kfree(tmpBuf);
  }
  return ret;
}
