/**********************************************************************
 * core.c                                                   August 2005
 *
 * TCPS: TCP Splicing Module
 * This code is based on the source code of ipvs-1.0.9 and tcpsp.
 *
 * IPVS         An implementation of the IP virtual server support for the
 *              LINUX operating system.  IPVS is now implemented as a module
 *              over the Netfilter framework. IPVS can be used to build a
 *              high-performance and highly available server based on a
 *              cluster of servers.
 *
 * Authors:     Wensong Zhang <wensong@linuxvirtualserver.org>
 *              Peter Kese <peter.kese@ijs.si>
 *              Julian Anastasov <ja@ssi.bg>
 *
 * The IPVS code for kernel 2.2 was done by Wensong Zhang and Peter Kese,
 * with changes/fixes from Julian Anastasov, Lars Marowsky-Bree, Horms
 * and others. Much of the code here is taken from the IP MASQ code of
 * kernel 2.2.
 *
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 *
 * Changes:
 *     Hirotaka Sasaki <hiro1967@mti.biglobe.ne.jp>:
 *         - Supported Linux-2.6.
 *         - Rewrote to remove tcps module.
 *         - Rewrote to work without kernel patch.
 *
 **********************************************************************/

#include <linux/config.h>
#include <linux/version.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/init.h>
#include <linux/types.h>
#include <linux/errno.h>
#include <linux/fs.h>
#include <linux/sysctl.h>
#include <linux/proc_fs.h>
#include <linux/timer.h>
#include <linux/swap.h>
#include <linux/proc_fs.h>
#include <linux/file.h>
#include <linux/skbuff.h>               /* for struct sk_buff */
#include <linux/ip.h>                   /* for struct iphdr */
#include <net/tcp.h>                    /* for csum_tcpudp_magic */
#include <net/udp.h>
#include <net/icmp.h>                   /* for icmp_send */

#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>

#include <net/ip.h>
#include <net/sock.h>

#include <asm/uaccess.h>

#include "tcps.h"
#include "tcps_compat.h"

/*
 * Overlay of differently typed pointers to the same transport-header
 * address.  The hooks in this file store the header start in `raw` and
 * then read or write it through `th`.
 */
union tcps_tphdr {
	unsigned char *raw;	/* untyped byte view of the header */
	struct udphdr *uh;	/* header viewed as UDP */
	struct tcphdr *th;	/* header viewed as TCP */
	struct icmphdr *icmph;	/* header viewed as ICMP */
	__u16 *portp;		/* 16-bit word view (ports, judging by the name) */
};

/*
 * Fold the two 32-bit words `old` and `new` into an existing 16-bit
 * checksum.  The running sum is seeded with `oldsum ^ 0xFFFF`, the two
 * words are added with csum_partial() and the result is folded back to
 * 16 bits with csum_fold().
 */
static inline u16 tcps_check_diff(u32 old, u32 new, u16 oldsum)
{
	u32 words[2];

	words[0] = old;
	words[1] = new;

	return csum_fold(csum_partial((char *) words, sizeof(words),
				      oldsum ^ 0xFFFF));
}

/*
 * Make sure the skb can be modified and has enough headroom.
 *
 * The wanted headroom is max(headroom, 16).  If the skb already has that
 * much room and is not cloned, nothing happens.  Otherwise the head is
 * reallocated with pskb_expand_head() (growth rounded up to a multiple
 * of 16) and the caller's IP and transport header pointers are refreshed
 * through iph_p / t_p.
 *
 * Returns 0 on success, -ENOMEM if the head could not be expanded.
 */
static inline int
tcps_skb_cow(struct sk_buff *skb, unsigned int headroom,
		struct iphdr **iph_p, unsigned char **t_p)
{
	unsigned int wanted = headroom;
	int shortfall;

	if (wanted <= 16) {
		wanted = 16;
	}

	shortfall = wanted - skb_headroom(skb);
	if (shortfall < 0) {
		shortfall = 0;
	}

	/* enough room and sole owner of the data: nothing to do */
	if (!shortfall && !skb_cloned(skb)) {
		return 0;
	}

	if (pskb_expand_head(skb, (shortfall + 15) & ~15, 0, GFP_ATOMIC)) {
		return -ENOMEM;
	}

	/* the data area was reallocated, re-derive the header pointers */
	*iph_p = skb->nh.iph;
	*t_p = (unsigned char *) (*iph_p) + (*iph_p)->ihl * 4;
	return 0;
}

/*
 * Parse the TCP option area looking for a timestamp option (generic,
 * slow-path routine).
 *
 * We do not have to care about options that should appear only in SYN
 * packets because we splice established connections.
 *
 * th:   TCP header; the option area is the (doff * 4 - sizeof(*th)) bytes
 *       that follow it.
 * ts_p: on success, set to point at the timestamp option payload (the
 *       bytes right after the kind and length bytes).  Left untouched
 *       when no timestamp option is found.
 *
 * Returns 1 if a timestamp option of length TCPOLEN_TIMESTAMP was found,
 * 0 otherwise.  Scanning stops at an EOL option or at the first malformed
 * option; if several timestamp options are present the last one wins.
 */
static int
tcps_parse_timestamps(struct tcphdr *th, __u32 **ts_p)
{
	unsigned char *p = (unsigned char *)(th + 1);
	int len = (th->doff * 4) - sizeof(struct tcphdr);
	int found = 0;

	while (len > 0) {
		int opt = *p++;
		int olen;

		switch (opt) {
		case TCPOPT_EOL:
			return found;

		case TCPOPT_NOP:
			len--;
			continue;
		}

		/*
		 * Every other option carries a length byte.  Make sure that
		 * byte is still inside the option area before reading it;
		 * otherwise we would read one byte past the TCP header.
		 */
		if (len < 2) {
			return found;
		}

		olen = *p++;
		if (olen < 2 || olen > len) {
			/* malformed option, give up */
			return found;
		}

		if (opt == TCPOPT_TIMESTAMP && olen == TCPOLEN_TIMESTAMP) {
			*ts_p = (__u32 *)p;
			found = 1;
		}

		/* kind and length bytes were already consumed above */
		p += olen - 2;
		len -= olen;
	}
	return found;
}

/*
 * Locate the timestamp option, trying the common header layouts first.
 *
 *  - header without options: no timestamp, return 0;
 *  - header whose only options are NOP, NOP, TIMESTAMP (the aligned
 *    form): point *ts_p at the word following that prefix, return 1;
 *  - anything else: fall back to tcps_parse_timestamps().
 */
static inline int
tcps_fast_parse_timestamps(struct tcphdr *th, __u32 **ts_p)
{
	const unsigned int hdr_words = sizeof(struct tcphdr) >> 2;
	const unsigned int ts_words = TCPOLEN_TSTAMP_ALIGNED >> 2;
	__u32 *optw = (__u32 *)(th + 1);

	if (th->doff == hdr_words) {
		return 0;
	}

	/* the first option word is only inspected when the size matches */
	if (th->doff == hdr_words + ts_words &&
	    optw[0] == __constant_ntohl((TCPOPT_NOP << 24) |
					(TCPOPT_NOP << 16) |
					(TCPOPT_TIMESTAMP << 8) |
					TCPOLEN_TIMESTAMP)) {
		*ts_p = optw + 1;
		return 1;
	}

	return tcps_parse_timestamps(th, ts_p);
}

/*
 * Replace the TCP timestamp option in a header with NOP padding.
 *
 * Walks the option area of `th`.  On the first TCPOPT_TIMESTAMP option,
 * if its length is TCPOLEN_TIMESTAMP, the whole option (kind, length and
 * payload) is overwritten with TCPOPT_NOP bytes; in either case the walk
 * ends there.  The walk also ends at EOL or at a malformed option.
 *
 * The TCP checksum is not adjusted here; the caller has to recompute it.
 */
static void
tcps_timestamps_to_nop(struct tcphdr *th)
{
	unsigned char *p = (unsigned char *)(th + 1);
	int len = (th->doff * 4) - sizeof(struct tcphdr);

	while (len > 0) {
		int opt = *p;
		int olen;

		switch (opt) {
		case TCPOPT_EOL:
			return;

		case TCPOPT_NOP:
			p++;
			len--;
			continue;
		}

		/*
		 * A length byte must follow the kind byte.  Without this
		 * check a kind byte in the very last option position would
		 * make us read one byte past the TCP header.
		 */
		if (len < 2) {
			return;
		}

		olen = *(p + 1);
		if (olen < 2 || olen > len) {
			/* malformed option, give up */
			return;
		}

		if (opt == TCPOPT_TIMESTAMP) {
			if (olen == TCPOLEN_TIMESTAMP) {
				memset(p, TCPOPT_NOP, TCPOLEN_TIMESTAMP);
			}
			return;
		}

		p += olen;
		len -= olen;
	}
}

/*
 * Force a socket into the reset state.
 *
 * Nothing is done if the socket is already in TCP_CLOSE.  Otherwise the
 * socket error is set to ECONNRESET, the error is reported unless the
 * socket is flagged dead, and the connection is finished with tcp_done().
 */
void
tcps_reset_sock(struct sock *sk)
{
	TCPS_DBG("tcps_reset_sock: sk=%p\n", sk);

	if (TCPS_SK_STATE(sk) == TCP_CLOSE) {
		return;
	}

	TCPS_SK_ERR(sk) = ECONNRESET;

	if (!TCPS_SK_FL_DEAD(sk)) {
		TCPS_SK_ERR_REPORT(sk);
	}

	tcp_done(sk);
}

/*
 * Translate a file descriptor of the current process into its struct sock.
 *
 * On success the sock is returned with an extra reference (sock_hold());
 * the caller must drop it with sock_put().  The file reference taken by
 * fget() is released before returning.
 *
 * On failure NULL is returned and *err is set to
 *   -EBADF     if fd is not an open descriptor,
 *   -ENOTSOCK  if the descriptor does not refer to a socket.
 */
static struct sock *
tcps_sock_get(int fd, int *err)
{
	struct file *filp;
	struct inode *ino;
	struct socket *so;
	struct sock *sk;

	filp = fget(fd);
	if (filp == NULL) {
		*err = -EBADF;
		return NULL;
	}

	ino = filp->f_dentry->d_inode;
	so = ino->i_sock ? SOCKET_I(ino) : NULL;
	if (so == NULL) {
		*err = -ENOTSOCK;
		fput(filp);
		return NULL;
	}

	/* repair the back pointer if it does not match the file we hold */
	if (so->file != filp) {
		TCPS_ERR("tcps_sock_get: socket file changed!\n");
		so->file = filp;
	}

	sk = so->sk;
	sock_hold(sk);
	fput(filp);
	return sk;
}

/*
 * Snapshot the TCP state of `sk` into slot `which` of the connection
 * entry's tcpopt[] array.
 *
 * Both the initial values (sseq/rseq/tsv/tse) and the running values
 * (csseq/lack/ltse) start out from the socket's current snd_nxt,
 * rcv_nxt and received timestamp fields.
 */
void
tcps_conn_init_tcpopt(struct tcps_conn *tc, struct sock *sk, int which)
{
	struct tcp_opt *tcp = tcp_sk(sk);
	struct tcps_conn_tcpopt *slot = &tc->tcpopt[which];

	/* send side: next sequence number to be sent */
	slot->sseq = tcp->snd_nxt;
	slot->csseq = tcp->snd_nxt;

	/* receive side: next sequence number expected */
	slot->rseq = tcp->rcv_nxt;
	slot->lack = tcp->rcv_nxt;

	/* keep timestamps (option...) */
	slot->ts_ok = tcp->tstamp_ok;
	slot->tsv = tcp->rcv_tsval;
	slot->tse = tcp->rcv_tsecr;
	slot->ltse = tcp->rcv_tsval;
}

/*
 * Splice two TCP sockets together.
 *
 * sk1 is recorded as the client-side socket (tc->csk), sk2 as the
 * real-server-side socket (tc->rsk).  sk2 must own a tcps_info block with
 * TCPS_INFO_F_SOURCE_FAKE set, and a connection entry must exist for the
 * pair (peer of sk1, peer of sk2).
 *
 * No socket references are taken here; the sockets are only stored.
 * The caller in this file (tcps_setsockopt) holds both before calling
 * and drops them again if this function fails.
 *
 * Returns 0 on success, -EINVAL on any failure.
 */
static int
tcps_splice(struct sock *sk1, struct sock *sk2)
{
	struct tcps_info *info;
	struct tcps_conn *conn;

	/* both ends have to be TCP sockets */
	if (TCPS_SK_PROTOCOL(sk1) != IPPROTO_TCP
	    || TCPS_SK_PROTOCOL(sk2) != IPPROTO_TCP) {
		TCPS_DBG("tcps_splice: protocol mismatch\n");
		return -EINVAL;
	}

	info = tcps_info_get(sk2);
	if (info == NULL) {
		TCPS_DBG("tcps_splice: not found info: sk2=%p \n", sk2);
		return -EINVAL;
	}

	if ((info->flags & TCPS_INFO_F_SOURCE_FAKE) == 0) {
		TCPS_DBG("tcps_splice: info flags error: sk2=%p tcpsi=%p\n",
			sk2, info);
		tcps_info_put(sk2);
		return -EINVAL;
	}

	/* mark the server-side info block as spliced */
	spin_lock_bh(&info->lock);
	info->flags |= TCPS_INFO_F_SPLICED;
	spin_unlock_bh(&info->lock);
	tcps_info_put(sk2);

	conn = tcps_conn_preroute_get(TCPS_SK_DADDR(sk1), TCPS_SK_DPORT(sk1),
				      TCPS_SK_DADDR(sk2), TCPS_SK_DPORT(sk2));
	if (conn == NULL) {
		TCPS_DBG("tcps_splice: connection entry not found\n");
		TCPS_DBG("sk1=%08x:%d sk2=%08x:%d\n",
			 htonl(TCPS_SK_DADDR(sk1)), htons(TCPS_SK_DPORT(sk1)),
			 htonl(TCPS_SK_DADDR(sk2)), htons(TCPS_SK_DPORT(sk2)));
		return -EINVAL;
	}

	/* the address/port the client connected to */
	conn->vaddr = TCPS_SK_RCVADDR(sk1);
	conn->vport = TCPS_SK_SPORT(sk1);

	/* remember the TCP parameters of both legs */
	tcps_conn_init_tcpopt(conn, sk1, TCPS_CONN_TCPOPT_CL);
	tcps_conn_init_tcpopt(conn, sk2, TCPS_CONN_TCPOPT_RS);

	TCPS_DBG("splice: client: (seq %u/ack %u)  server: (seq %u/ack %u)\n",
		 conn->tcpopt[0].sseq,
		 conn->tcpopt[0].rseq,
		 conn->tcpopt[1].sseq,
		 conn->tcpopt[1].rseq);
	/* XXX TODO: other tcp options (sack, wscale, etc...) */

	conn->state |= TCPS_CONN_S_SPLICED;
	conn->csk = sk1;
	conn->rsk = sk2;
	tcps_conn_put(conn);
	return 0;
}

/*
 * setsockopt() handler registered through tcps_sockopts.
 *
 * sk:   socket the option is applied to
 * cmd:  one of
 *         TCPS_SO_SET_SOURCE_FAKE  user data is a struct sockaddr_in; its
 *                                  address/port are stored in the
 *                                  socket's tcps_info and
 *                                  TCPS_INFO_F_SOURCE_FAKE is set
 *         TCPS_SO_SPLICE           user data is an int file descriptor of
 *                                  the socket to splice `sk` with
 *         TCPS_SO_SET_TCPOPT       user data is a struct tcps_tcpopt;
 *                                  timestamps are suppressed when its
 *                                  tstamp_ok member is zero
 * user: user-space pointer to the option value
 * len:  size of the option value; must match the expected type exactly
 *
 * Returns 0 on success or a negative errno:
 *   -EINVAL   unknown command, size mismatch or splice failure
 *   -EFAULT   the option value could not be copied from user space
 *   -EBADF / -ENOTSOCK   (SPLICE) the descriptor is invalid / not a socket
 *   or the error reported by tcps_info_new().
 */
static int
tcps_setsockopt(struct sock *sk, int cmd, void *user, unsigned int len)
{
	struct sockaddr_in sin;
	struct tcps_tcpopt tcpopt;
	struct tcps_info *tcpsi;
	struct sock *sk2;
	int spfd;
	int ret = 0;

	switch (cmd) {
	case TCPS_SO_SET_SOURCE_FAKE:
		if (len != sizeof(sin)) {
			TCPS_DBG("tcps_setsockopt: size mismatch\n");
			ret = -EINVAL;
			break;
		}

		if (copy_from_user(&sin, user, len) != 0) {
			ret = -EFAULT;
			break;
		}

		/* use the socket's info block, creating it on first use */
		tcpsi = tcps_info_get(sk);
		if (!tcpsi) {
			tcpsi = tcps_info_new(sk);
			if (IS_ERR(tcpsi)) {
				ret = PTR_ERR(tcpsi);
				break;
			}
		}

		spin_lock_bh(&tcpsi->lock);
		tcpsi->saddr = sin.sin_addr.s_addr;
		tcpsi->sport = sin.sin_port;
		tcpsi->flags |= TCPS_INFO_F_SOURCE_FAKE;
		spin_unlock_bh(&tcpsi->lock);

		tcps_info_put(sk);
		break;

	case TCPS_SO_SPLICE:
		if (len != sizeof(int)) {
			TCPS_DBG("tcps_setsockopt(SPLICE): size mismatch\n");
			ret = -EINVAL;
			break;
		}

		if (copy_from_user(&spfd, user, len) != 0) {
			TCPS_DBG("tcps_setsockopt(SPLICE): copyin failed\n");
			ret = -EFAULT;
			break;
		}

		sk2 = tcps_sock_get(spfd, &ret);
		if (sk2 == NULL) {
			/*
			 * tcps_sock_get() has already stored the precise
			 * reason (-EBADF or -ENOTSOCK) in ret; report it
			 * instead of flattening it to -EINVAL.
			 */
			break;
		}

		/*
		 * Both sockets are referenced while spliced: sk2 was held by
		 * tcps_sock_get(), sk is held here.  On failure both
		 * references are dropped again.
		 */
		sock_hold(sk);
		ret = tcps_splice(sk, sk2);
		if (ret) {
			sock_put(sk);
			sock_put(sk2);
			break;
		}

		TCPS_DBG("tcps_setsockopt(SPLICE): splicing succeed : sk=%p sk2=%p\n",
			 sk, sk2);

		break;

	case TCPS_SO_SET_TCPOPT:
		if (len != sizeof(struct tcps_tcpopt)) {
			TCPS_DBG("tcps_setsockopt(SET_TCPOPT): size mismatch\n");
			ret = -EINVAL;
			break;
		}

		if (copy_from_user(&tcpopt, user, len) != 0) {
			TCPS_DBG("tcps_setsockopt(SET_TCPOPT): copyin failed\n");
			ret = -EFAULT;
			break;
		}

		/* use the socket's info block, creating it on first use */
		tcpsi = tcps_info_get(sk);
		if (!tcpsi) {
			tcpsi = tcps_info_new(sk);
			if (IS_ERR(tcpsi)) {
				ret = PTR_ERR(tcpsi);
				break;
			}
		}

		spin_lock_bh(&tcpsi->lock);
		tcpsi->suppress_tstamp = !tcpopt.tstamp_ok;
		spin_unlock_bh(&tcpsi->lock);

		tcps_info_put(sk);
		break;

	default:
		ret = -EINVAL;
	}

	return ret;
}

/*
 * getsockopt() handler registered through tcps_sockopts.
 *
 * sk:   socket being queried
 * cmd:  only TCPS_SO_GET_TCPOPT is supported; it reports whether the
 *       socket uses TCP timestamps (tcp_opt.tstamp_ok)
 * user: user-space buffer receiving a struct tcps_tcpopt
 * len:  size of the user buffer; must equal sizeof(struct tcps_tcpopt)
 *
 * Returns 0 on success, -EINVAL for an unknown command or a size
 * mismatch, -EFAULT if the result could not be copied to user space.
 */
static int
tcps_getsockopt(struct sock *sk, int cmd, void *user, int *len)
{
	struct tcps_tcpopt tcpopt;
	struct tcp_opt *tp;
	int ret = 0;

	switch (cmd) {
	case TCPS_SO_GET_TCPOPT:
		if (*len != sizeof(tcpopt)) {
			TCPS_DBG("tcps_getsockopt(GET_TCPOPT): size mismatch\n");
			ret = -EINVAL;
			break;
		}

		/*
		 * Clear the whole structure first: only one member is filled
		 * in below, and any other member or padding would otherwise
		 * hand uninitialised kernel stack contents to user space.
		 */
		memset(&tcpopt, 0, sizeof(tcpopt));

		tp = tcp_sk(sk);
		tcpopt.tstamp_ok = tp->tstamp_ok;

		if (copy_to_user(user, &tcpopt, sizeof(tcpopt))) {
			TCPS_DBG("tcps_getsockopt(GET_TCPOPT): copyout failed\n");
			ret = -EFAULT;
		}
		break;

	default:
		ret = -EINVAL;
		break;
	}

	return ret;
}

/*
 * Socket option registration record for PF_INET sockets.  It routes the
 * option range starting at TCPS_SO_BASE to tcps_setsockopt() and
 * tcps_getsockopt().  Registered in tcps_control_init() and removed in
 * tcps_control_fini().
 */
static struct nf_sockopt_ops tcps_sockopts = {
	.list       = {NULL, NULL},
	.pf         = PF_INET,
	.set_optmin = TCPS_SO_BASE,
	.set_optmax = TCPS_SO_SET_MAX + 1,	/* presumably an exclusive bound - confirm */
	.set        = tcps_setsockopt,
	.get_optmin = TCPS_SO_BASE,
	.get_optmax = TCPS_SO_GET_MAX + 1,	/* presumably an exclusive bound - confirm */
	.get        = tcps_getsockopt
};

/*
 * Register the module's socket option handlers.
 *
 * Returns the result of nf_register_sockopt(): 0 on success, non-zero on
 * failure (an error is logged in that case).
 */
static int
tcps_control_init(void)
{
	int err = nf_register_sockopt(&tcps_sockopts);

	if (err != 0) {
		TCPS_ERR("error during registering sockopt\n");
	}

	return err;
}

/*
 * Undo tcps_control_init(): unregister the module's socket option
 * handlers.
 */
static void
tcps_control_fini(void)
{
	nf_unregister_sockopt(&tcps_sockopts);
}

static unsigned int
tcps_out(unsigned int hooknum,
	 struct sk_buff **skbp,
	 const struct net_device *dev_in,
	 const struct net_device *dev_out,
	 int (*okfn)(struct sk_buff *))
{
	struct sk_buff *skb = *skbp;
	struct rtable *rt;
	struct sock *sk = skb->sk;
	struct iphdr *iph;
	struct tcps_info *tcpsi = NULL;
	union tcps_tphdr h;
	struct tcps_conn *tc = NULL;
	struct tcps_conn_tcpopt *ctcopt, *rtcopt;
	u32 tcpsi_state = 0;
	int ihl, datalen;
	int ret;

	/* No need to handle skbuffs w/o socks */
	if (sk == NULL) {
		return NF_ACCEPT;
	}

	sock_hold(sk);

	tcpsi = tcps_info_get(sk);
    if (!tcpsi) {
		ret = NF_ACCEPT;
		goto out;
	} else if (!(tcpsi->flags & TCPS_INFO_F_SOURCE_FAKE)) {
		tcps_info_put(sk);
		ret = NF_ACCEPT;
		goto out;
	}

	iph = skb->nh.iph;
	ihl = iph->ihl << 2;
	datalen = skb->len - ihl;
	h.raw = (char *)iph + ihl;

	if (iph->protocol != IPPROTO_TCP) {
		ret = NF_ACCEPT;
		goto out;
	}

	if (tcps_ip_route_output(&rt, iph->daddr, 0, RT_TOS(iph->tos), 0)) {
		TCPS_DBG("tcps_ip_route_output failed\n");
		ret = NF_DROP;
		goto out;
	}

	tc = tcps_conn_out_rs_get(tcpsi->saddr, tcpsi->sport,
				  iph->saddr, h.th->source);
				  
	if (tc == NULL) {
		if (! h.th->syn) {
			ret = NF_ACCEPT;
			goto out;
		}

		tc = tcps_conn_new(iph->saddr, h.th->source,
				   iph->daddr, h.th->dest, 
				   tcpsi->saddr, tcpsi->sport);
		if (tc == NULL) {
			ret = NF_DROP;
			goto out;
		}
		TCPS_DBG("tcps_out: new entry %x:%d -> %x:%d\n",
			 ntohl(tc->caddr), ntohs(tc->cport), 
			 ntohl(tc->raddr), ntohs(tc->rport));

		tcpsi_state = TCPS_INFO_S_SYN;
	} else {
		TCPS_DBG("tcps_out: found entry %x:%d -> %x:%d\n",
			 ntohl(tc->caddr), ntohs(tc->cport), 
			 ntohl(tc->raddr), ntohs(tc->rport));

		if (h.th->syn) {
			tcpsi_state = TCPS_INFO_S_SYN;
		} else if (h.th->fin) {
			tcpsi_state = TCPS_INFO_S_FIN;
		} else if (h.th->rst) {
			tcpsi_state = TCPS_INFO_S_RST;
		} else {
			tcpsi_state = TCPS_INFO_S_SENT;
		}
	}

	spin_lock_bh(&tcpsi->lock);
	tcps_info_set_state(tcpsi, tcpsi->state|tcpsi_state);
	spin_unlock_bh(&tcpsi->lock);

	spin_lock_bh(&tc->lock);
	ctcopt = &tc->tcpopt[TCPS_CONN_TCPOPT_CL];
	rtcopt = &tc->tcpopt[TCPS_CONN_TCPOPT_RS];

	if (tcps_skb_cow(skb, rt->u.dst.dev->hard_header_len, &iph, &h.raw)) {
		ret = NF_DROP;
		goto out;
	}

	if (h.th->syn && tcpsi->suppress_tstamp) {
		tcps_timestamps_to_nop(h.th);
	}

	if ((tc->state & TCPS_CONN_S_SPLICED) == TCPS_CONN_S_SPLICED) {

		if (!before(ntohl(h.th->seq), rtcopt->csseq)) {
			int plen;
			plen = (int)(ntohl(h.th->seq) - rtcopt->sseq) + 
			    datalen - (h.th->doff << 2);
			TCPS_DBG("tcps_out: seq=%u nseq=%u\n",
			    rtcopt->csseq, rtcopt->csseq + plen);

			h.th->seq = htonl(rtcopt->csseq);
			h.th->ack_seq = htonl(rtcopt->lack);
			if (plen == 0 && h.th->fin) {
				plen++;
			}
			rtcopt->csseq += plen;
			rtcopt->sseq += plen;
		}
	}

	/* virtualserver->realserver. mangle the source to client address. */
	iph->saddr = tc->caddr;
	h.th->source = tc->cport;

	spin_unlock_bh(&tc->lock);
	tcps_conn_put(tc);
	tcps_info_put(sk);
	sock_put(sk);

	h.th->check = 0;
 	h.th->check = csum_tcpudp_magic(iph->saddr, iph->daddr,
					datalen, iph->protocol,
					csum_partial(h.raw, datalen, 0));

	skb->ip_summed = CHECKSUM_UNNECESSARY;
	ip_send_check(iph);

	(*okfn)(skb);
	return NF_STOLEN;

 out:
	if (tcpsi != NULL) {
		tcps_info_put(sk);
	}

	if (tc != NULL) {
		spin_unlock_bh(&tc->lock);
		tcps_conn_put(tc);
	}

	sock_put(sk);

	return ret;
}

/*
 * Netfilter hook for packets about to be delivered locally (registered
 * on NF_IP_LOCAL_IN through tcps_in_ops).
 *
 * TCP packets that match a connection entry (tcps_conn_in_get()) are
 * forwarded to the real server instead of being delivered: destination
 * address/port become tc->raddr/tc->rport, seq/ack (and timestamps, when
 * the real-server side negotiated them) are translated from the client
 * side numbering to the real-server side numbering, the packet is
 * re-routed and sent with tcps_ip_send().
 *
 * Returns NF_ACCEPT for packets that are not ours, NF_DROP if no route
 * to the real server is found, NF_STOLEN once the packet has been sent.
 *
 * NOTE(review): unlike tcps_out(), the packet is modified in place
 * without calling tcps_skb_cow() first - confirm that this is safe here.
 */
static unsigned int
tcps_in(unsigned int hooknum,
	struct sk_buff **skbp,
	const struct net_device *dev_in,
	const struct net_device *dev_out,
	int (*okfn)(struct sk_buff *))
{
	struct sk_buff *skb = *skbp;
	struct rtable *rt;
	struct iphdr *iph;
	struct tcps_conn *tc;
	struct tcps_conn_tcpopt *ctcopt, *rtcopt;
	union tcps_tphdr h;
	u32 seq, ack;
	u32 tsval, tsecho, *ts;
	u32 state;
	u32 tcpsi_state = 0;
	int ihl, datalen;
	int have_ts = 0;

	iph = skb->nh.iph;
	ihl = iph->ihl << 2;
	/* datalen covers the TCP header plus payload */
	datalen = skb->len - ihl;
	h.raw = (char *)iph + ihl;

	if (iph->protocol != IPPROTO_TCP) {
		return NF_ACCEPT;
	}

	tc = tcps_conn_in_get(iph->saddr, h.th->source,
			      iph->daddr, h.th->dest);
	if (tc == NULL) {
		return NF_ACCEPT;
	}

	spin_lock_bh(&tc->lock);

	if (h.th->syn &&
	    (tc->state & (TCPS_CONN_S_CL_FIN|TCPS_CONN_S_RS_FIN)) != 0) {
		/* The client is re-using port number... */
		/* expire the stale entry and let the SYN through untouched */
		spin_unlock_bh(&tc->lock);
		tcps_conn_expire_now(tc);
		__tcps_conn_put(tc);
		return NF_ACCEPT;
	}

	/* derive the state bits to record from the TCP flags */
	if (h.th->fin) {
		state = TCPS_CONN_S_CL_FIN;
		tcpsi_state = TCPS_INFO_S_FIN;
	} else if (h.th->rst) {
		state = TCPS_CONN_S_CL_RST;
		tcpsi_state = TCPS_INFO_S_RST;
	} else {
		state = TCPS_CONN_S_CL_SENT;
	}

	ctcopt = &tc->tcpopt[TCPS_CONN_TCPOPT_CL];
	rtcopt = &tc->tcpopt[TCPS_CONN_TCPOPT_RS];

	TCPS_DBG("tcps_in: found entry %x:%d -> %x:%d\n",
		 ntohl(tc->caddr), ntohs(tc->cport), 
		 ntohl(tc->raddr), ntohs(tc->rport));

	if (tcps_ip_route_output(&rt, tc->raddr, 0, RT_TOS(iph->tos), 0)) {
		TCPS_DBG("tcps_ip_route_output failed\n");
		spin_unlock_bh(&tc->lock);
		tcps_conn_put(tc);
		return NF_DROP;
	}

	/* replace the packet's route with the one towards the real server */
	dst_release(skb->dst);
	skb->dst = &rt->u.dst;

	seq = ntohl(h.th->seq);
	ack = ntohl(h.th->ack_seq);

	/* mangle the destination */
	iph->daddr = tc->raddr;
	h.th->dest = tc->rport;

	/* shift seq/ack by the offset between the two legs' initial values */
	seq = rtcopt->sseq + (seq - ctcopt->rseq);
	ack = rtcopt->rseq + (ack - ctcopt->sseq);
	h.th->seq = htonl(seq);
	h.th->ack_seq = htonl(ack);
	tsval = tsecho = 0;
	/* translate the timestamp option the same way, if there is one */
	if (rtcopt->ts_ok && tcps_fast_parse_timestamps(h.th, &ts)) {
		tsval = ntohl(ts[0]) - ctcopt->tsv + rtcopt->tse;
		tsecho = ntohl(ts[1]) - ctcopt->tse + rtcopt->tsv;
		ts[0] = htonl(tsval);
		ts[1] = htonl(tsecho);
		have_ts = 1;
	}

	/* remember the highest sequence number forwarded so far */
	if (!before(seq + datalen - (h.th->doff << 2), rtcopt->csseq)) {
		rtcopt->csseq = seq + datalen - (h.th->doff << 2);
		if (have_ts) {
			rtcopt->ltsv = tsval;
			rtcopt->ltse = tsecho;
		}
		TCPS_DBG("tcps_in: seq=%u nseq=%u\n", seq, rtcopt->csseq);
	}

	/* remember the highest ack forwarded so far */
	if (after(ack, rtcopt->lack)) {
		rtcopt->lack = ack;
	}

	if (h.th->fin | h.th->rst) {
		/* FIN or RST from the client: the client-side socket is reset */
		if (tc->csk != NULL) {
			tcps_reset_sock(tc->csk);
			sock_put(tc->csk);
			tc->csk = NULL;
		}

		if (tc->rsk != NULL) {
			struct tcps_info *tcpsi;

			/* record FIN/RST in the server-side socket's info state */
			if ((tcpsi = tcps_info_get(tc->rsk)) != NULL) {
				spin_lock_bh(&tcpsi->lock);
				tcps_info_set_state(tcpsi, tcpsi->state|tcpsi_state);
				spin_unlock_bh(&tcpsi->lock);
				tcps_info_put(tcpsi->sk);
			}

			/* the server-side socket is reset on RST only */
			if (h.th->rst) {
				tcps_reset_sock(tc->rsk);
				sock_put(tc->rsk);
				tc->rsk = NULL;
			}
		}
	}

	tcps_conn_set_state(tc, tc->state|state);
	spin_unlock_bh(&tc->lock);
	tcps_conn_put(tc);

	/* header fields changed: recompute the TCP and IP checksums */
	h.th->check = 0;
	h.th->check = csum_tcpudp_magic(iph->saddr, iph->daddr,
					datalen, iph->protocol,
					csum_partial(h.raw, datalen, 0));

	skb->ip_summed = CHECKSUM_UNNECESSARY;
	ip_send_check(iph);

	tcps_ip_send(skb);
	return NF_STOLEN;
}

/*
 * Netfilter hook for received packets before routing (registered on
 * NF_IP_PRE_ROUTING through tcps_preroute_ops).
 *
 * Handles TCP packets that match a connection entry looked up with
 * tcps_conn_preroute_get(daddr, dest, saddr, source).  Two cases:
 *
 *  - "local": the connection is not spliced yet, or the packet is a
 *    pure ACK matching the real-server leg's recorded sequence numbers.
 *    The destination is rewritten to tc->laddr/tc->lport and the packet
 *    continues through okfn().
 *  - otherwise: the packet is relayed to the client.  Source becomes
 *    tc->vaddr/tc->vport, seq/ack/timestamps are translated to the
 *    client side numbering, the packet is re-routed and sent with
 *    tcps_ip_send().
 *
 * Returns NF_ACCEPT for packets that are not ours, NF_DROP when the local
 * socket is already reset or no route is found, NF_STOLEN otherwise.
 */
static unsigned int
tcps_preroute(unsigned int hooknum,
	      struct sk_buff **skbp,
	      const struct net_device *dev_in,
	      const struct net_device *dev_out,
	      int (*okfn)(struct sk_buff *))
{
	struct sk_buff *skb = *skbp;
	struct rtable *rt;
	struct iphdr *iph;
	struct tcps_conn *tc;
	struct tcps_conn_tcpopt *ctcopt, *rtcopt;
	union tcps_tphdr h;
	int ihl, datalen;
	u32 tsval, tsecho, *ts;
	u32 seq, ack;
	u32 state;
	u32 tcpsi_state = 0;
	int for_local = 0;
	int have_ts = 0;
	int is_spliced, is_already_rst;
	int (*sndfn)(struct sk_buff *);

	iph = skb->nh.iph;
	ihl = iph->ihl << 2;
	/* datalen covers the TCP header plus payload */
	datalen = skb->len - ihl;
	h.raw = (char *)iph + ihl;

	if (iph->protocol != IPPROTO_TCP) {
		return NF_ACCEPT;
	}

	tc = tcps_conn_preroute_get(iph->daddr, h.th->dest,
				    iph->saddr, h.th->source);
	if (tc == NULL) {
		return NF_ACCEPT;
	}
	    
	spin_lock_bh(&tc->lock);
	ctcopt = &tc->tcpopt[TCPS_CONN_TCPOPT_CL];
	rtcopt = &tc->tcpopt[TCPS_CONN_TCPOPT_RS];

	seq = ntohl(h.th->seq);
	ack = ntohl(h.th->ack_seq);

	/* derive the state bits to record from the TCP flags */
	if (h.th->fin) {
		state = TCPS_CONN_S_RS_FIN;
		tcpsi_state = TCPS_INFO_S_FIN;
	} else if (h.th->rst) {
		state = TCPS_CONN_S_RS_RST;
		tcpsi_state = TCPS_INFO_S_RST;
	} else if (h.th->syn) {
		state = TCPS_CONN_S_RS_SYN;
	} else {
		state = TCPS_CONN_S_RS_SENT;
	}

	is_spliced = tc->state & TCPS_CONN_S_SPLICED;
	if (!is_spliced) {
		/* before splicing everything goes to the local socket */
		for_local = 1;
	} else {
		/*
		 * Data-less ACK whose seq equals the recorded rseq and whose
		 * ack does not go beyond the recorded sseq of the real-server
		 * leg: treated as an ACK for the local socket.
		 */
		if (h.th->ack && datalen == h.th->doff * 4 &&
		    seq == rtcopt->rseq && !after(ack, rtcopt->sseq)) {
			TCPS_DBG("tcps_preroute: found ack for local sock\n");
			for_local = 1;
		}
		/* XXX should consider piggy-backed ack case also? */
	}

	if (for_local) {
		/* route to local socket w/ destination-address mangling */
		is_already_rst = tc->state
		    & (TCPS_CONN_S_RS_FIN|TCPS_CONN_S_RS_RST|TCPS_CONN_S_CL_RST);
		if (is_spliced && is_already_rst && tc->rsk == NULL) {
			/* Already reset, drop it. */
			TCPS_DBG("socket already reset, drop it\n");
			spin_unlock_bh(&tc->lock);
			tcps_conn_put(tc);
			return NF_DROP;
		}

		iph->daddr = tc->laddr;
		h.th->dest = tc->lport;
		/*
		 * NOTE(review): the translated seq/ack computed below are only
		 * used by the debug message; they are not written back to the
		 * TCP header on this path (only the timestamp echo is) -
		 * confirm that this is intended.
		 */
		seq = seq + ctcopt->sseq - ctcopt->csseq;
		ack = ack + ctcopt->rseq - ctcopt->lack;
		if (tcps_fast_parse_timestamps(h.th, &ts)) {
			tsecho = ntohl(ts[1]) + ctcopt->tse - ctcopt->ltse;
			ts[1] = htonl(tsecho);
		}

		TCPS_DBG("tcps_preroute/local: seq=%u ack=%u\n", seq, ack);
		/* continue normal processing of the (re-addressed) packet */
		sndfn = okfn;
	} else {
		/* route to client w/ source-address mangling */
		seq = ctcopt->sseq + (seq - rtcopt->rseq);
		ack = ctcopt->rseq + (ack - rtcopt->sseq);
		tsval = tsecho = 0;
		if (ctcopt->ts_ok && tcps_fast_parse_timestamps(h.th, &ts)) {
			tsval = ntohl(ts[0]) - rtcopt->tsv + ctcopt->tse;
			tsecho = ntohl(ts[1]) - rtcopt->tse + ctcopt->tsv;
			ts[0] = htonl(tsval);
			ts[1] = htonl(tsecho);
			have_ts = 1;
		}

		/* remember the highest sequence number relayed so far */
		if (!before(seq + datalen - (h.th->doff<<2), ctcopt->csseq)) {
			ctcopt->csseq = seq + datalen - (h.th->doff << 2);
			if (have_ts) {
				ctcopt->ltsv = tsval;
				ctcopt->ltse = tsecho;
			}
		}

		/* remember the highest ack relayed so far */
		if (after(ack, ctcopt->lack)) {
			ctcopt->lack = ack;
		}

		iph->saddr = tc->vaddr;
		h.th->source = tc->vport;
		h.th->seq = htonl(seq);
		h.th->ack_seq = htonl(ack);
		TCPS_DBG("tcps_preroute: seq=%u ack=%u tsval=%u tsecr=%u\n",
			 seq, ack, tsval, tsecho);

		if (tcps_ip_route_output(&rt, iph->daddr, 0, RT_TOS(iph->tos), 0)) {
			TCPS_DBG("tcps_ip_route_output failed\n");
			spin_unlock_bh(&tc->lock);
			tcps_conn_put(tc);
			return NF_DROP;
		}
		
		/* replace the packet's route and send it ourselves */
		dst_release(skb->dst);
		skb->dst = &rt->u.dst;
		sndfn = tcps_ip_send;
	}

	if (h.th->fin | h.th->rst) {

		/* FIN or RST from the real server: reset the server-side socket */
		if (tc->rsk != NULL) {
			struct tcps_info *tcpsi;

			if ((tcpsi = tcps_info_get(tc->rsk)) != NULL) {
				spin_lock_bh(&tcpsi->lock);
				tcps_info_set_state(tcpsi, tcpsi->state|tcpsi_state);
				spin_unlock_bh(&tcpsi->lock);
				tcps_info_put(tcpsi->sk);
			}

			tcps_reset_sock(tc->rsk);
			sock_put(tc->rsk);
			tc->rsk = NULL;
		}

		/* the client-side socket is reset on RST only */
		if (h.th->rst && tc->csk != NULL) {
			tcps_reset_sock(tc->csk);
			sock_put(tc->csk);
			tc->csk = NULL;
		}
	}

	tcps_conn_set_state(tc, tc->state|state);
	spin_unlock_bh(&tc->lock);
	tcps_conn_put(tc);

	/* header fields changed: recompute the TCP and IP checksums */
	h.th->check = 0;
	h.th->check =
		csum_tcpudp_magic(iph->saddr, iph->daddr,
				  datalen, iph->protocol,
				  csum_partial(h.raw, datalen, 0));

	skb->ip_summed = CHECKSUM_UNNECESSARY;
	ip_send_check(iph);

	(*sndfn)(skb);
	return NF_STOLEN;
}

/* Hook registration: tcps_out() on the IPv4 NF_IP_LOCAL_OUT hook. */
static struct nf_hook_ops tcps_out_ops = {
	.list     = { NULL, NULL },
	.hook     = tcps_out,
	/* the owner member is only set when building for kernel >= 2.5.0 */
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,5,0)
	.owner    = THIS_MODULE,
#endif
	.pf       = PF_INET,
	.hooknum  = NF_IP_LOCAL_OUT,
	.priority = 100
};

/* Hook registration: tcps_in() on the IPv4 NF_IP_LOCAL_IN hook. */
static struct nf_hook_ops tcps_in_ops = {
	.list     = { NULL, NULL },
	.hook     = tcps_in,
	/* the owner member is only set when building for kernel >= 2.5.0 */
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,5,0)
	.owner    = THIS_MODULE,
#endif
	.pf       = PF_INET,
	.hooknum  = NF_IP_LOCAL_IN,
	.priority = 100
};

/* Hook registration: tcps_preroute() on the IPv4 NF_IP_PRE_ROUTING hook. */
static struct nf_hook_ops tcps_preroute_ops = {
	.list     = { NULL, NULL },
	.hook     = tcps_preroute,
	/* the owner member is only set when building for kernel >= 2.5.0 */
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,5,0)
	.owner    = THIS_MODULE,
#endif
	.pf       = PF_INET,
	.hooknum  = NF_IP_PRE_ROUTING,
	.priority = 100
};

/*
 * Module entry point.
 *
 * Brings the subsystems up in order: info table, connection table,
 * socket option control, then the LOCAL_OUT, LOCAL_IN and PRE_ROUTING
 * netfilter hooks.  If any step fails, everything set up before it is
 * torn down in reverse order and the error of the failing step is
 * returned.
 */
static int __init
tcps_init(void)
{
	int err;

	err = tcps_info_init();
	if (err) {
		return err;
	}

	err = tcps_conn_init();
	if (err) {
		goto undo_info;
	}

	err = tcps_control_init();
	if (err) {
		goto undo_conn;
	}

	err = nf_register_hook(&tcps_out_ops);
	if (err) {
		goto undo_control;
	}

	err = nf_register_hook(&tcps_in_ops);
	if (err) {
		goto undo_out_hook;
	}

	err = nf_register_hook(&tcps_preroute_ops);
	if (err) {
		goto undo_in_hook;
	}

	TCPS_INF("tcps module loaded.\n");
	return err;

	/* unwind in reverse order of setup */
 undo_in_hook:
	nf_unregister_hook(&tcps_in_ops);
 undo_out_hook:
	nf_unregister_hook(&tcps_out_ops);
 undo_control:
	tcps_control_fini();
 undo_conn:
	tcps_conn_fini();
 undo_info:
	tcps_info_fini();
	return err;
}

module_init(tcps_init);

/*
 * Module exit point: unregister the three netfilter hooks, then shut
 * down socket option control, the connection table and the info table.
 */
static void __exit
tcps_fini(void)
{
	struct nf_hook_ops *hooks[] = {
		&tcps_in_ops,
		&tcps_out_ops,
		&tcps_preroute_ops,
	};
	unsigned int i;

	/* XXX should flush the hash table! */
	for (i = 0; i < sizeof(hooks) / sizeof(hooks[0]); i++) {
		nf_unregister_hook(hooks[i]);
	}

	tcps_control_fini();
	tcps_conn_fini();
	tcps_info_fini();

	TCPS_INF("tcps module unloaded.\n");
}

module_exit(tcps_fini);

MODULE_LICENSE("GPL");
