package daruma.sql.test;

// XXX: should not depend on MySQL
import daruma.sql.DatabaseConnection;
import daruma.sql.DatabaseConnectionFactory;
import daruma.sql.DatabaseConnectionException;
import daruma.sql.MySQLDatabaseConnectionFactory;
import daruma.sql.DatabaseConnection.QueryResult;
import daruma.global_switch.ImplementationSwitches;
import daruma.util.FatalException;

import java.sql.ResultSet;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Blob;
import javax.sql.rowset.serial.SerialBlob;

import org.junit.Test;
import org.junit.Before;
import org.junit.After;
import junit.framework.TestCase;


import java.util.List;


public class DatabaseConnectionTest extends TestCase
{
	static	final	String	TEST_DATABASE_NAME = "unit_test";
	static	final	String	TEST_USER          = "tester";
	static	final	String	TEST_PASSWD        = "";


	private	DatabaseConnection	c = null;
	private	DatabaseConnection	cNotConnected = null;

	@Before
	protected	void	setUp() throws DatabaseConnectionException,
					       FatalException
	{
		ImplementationSwitches.createInstance();

		DatabaseConnectionFactory	factory;

		try
		{
			factory = new MySQLDatabaseConnectionFactory();
		}
		catch( DatabaseConnectionException
		       .ClassLoadFailedDatabaseConnectionException  e )
		{
			fail( e.getMessage() );
			throw e;
		}

		this.c = factory.create();
		assertNotNull( this.c );

		this.cNotConnected = factory.create();
		assertNotNull( this.cNotConnected );

		try
		{
			c.connect( TEST_DATABASE_NAME ,
				   TEST_USER ,
				   TEST_PASSWD );
		}
		catch( DatabaseConnectionException  e )
		{
			fail( e.getMessage() );
			throw e;
		}
	}

	@After
	protected	void	tearDown() throws DatabaseConnectionException
	{
		assertNotNull( this.c );
		assertNotNull( this.cNotConnected );

		try
		{
			this.c.close();
			this.cNotConnected.close();
		}
		catch( DatabaseConnectionException  e )
		{
			throw e;
		}
	}


	private	void	dropAllTables() throws DatabaseConnectionException
	{
		try
		{
			this.cNotConnected.dropAllTables();

			fail();
		}
		catch( DatabaseConnectionException  e )
		{
			// OK, expected
			assertNotNull( e );
		}


		try
		{
			this.c.dropAllTables();
			assertTrue( true );
		}
		catch( DatabaseConnectionException  e )
		{
			throw e;
		}
	}


	@Test
	public	void	testGetCurrentTime()
				throws DatabaseConnectionException
	{
		try
		{
			// java.util.Date t =
			this.cNotConnected.getCurrentTime();

			fail();
		}
		catch( DatabaseConnectionException  e )
		{
			// OK, expected
			assertNotNull( e );
		}


		try
		{
			java.util.Date	t = this.c.getCurrentTime();
			assertNotNull( t );

			java.util.Date	systemCurrentTime;
			systemCurrentTime = new java.util.Date
						( System.currentTimeMillis() );

			assertTrue( t.before( systemCurrentTime ) );

			assertTrue( t.getTime() - systemCurrentTime.getTime()
				    <= 10 * 1000 /* 10 seconds */ );
		}
		catch( DatabaseConnectionException  e )
		{
			throw e;
		}
	}


	@Test
	public	void	testPreparedStatement()
				throws DatabaseConnectionException ,
				       SQLException
	{
		final String tableName = "test_blob_encoding";
		final int intValue = 123;
		final String stringValue = "abc\"'\ndef";

		this.c.dropTableIfExists( tableName );

		this.c.executeUpdate( "CREATE TABLE " + tableName
				      + " (count INTEGER, str VARCHAR(10))" );


		PreparedStatement	st;
		st = this.c.prepareStatement
			    ( "INSERT INTO " + tableName
			      + " (count, str) VALUES (?, ?)" );

		st.setInt( 1 , intValue );
		st.setString( 2 , stringValue );

		c.executeUpdate( st );


		QueryResult	r = this.c.executeQuery
				     ( "SELECT str,count FROM " + tableName
				       + " WHERE count=" + intValue );

		int	resultCount = 0;

		while( r.next() )
		{
			String	dbStringValue = r.getString(/*column*/ 1);
			int	dbIntValue    = r.getInt(/*column*/ 2);

			assertEquals( intValue    , dbIntValue );
			assertEquals( stringValue , dbStringValue );

			++ resultCount;
		}

		assertEquals( 1 , resultCount );
	}


	@Test
	public	void	testStringEncoding()
				throws DatabaseConnectionException ,
				       SQLException
	{
		this.dropAllTables();

		try
		{
			this.c.startTransaction();

			this.c.executeUpdate( "CREATE TABLE "
					      + "test_blob_encoding "
					      + "(data blob,"
					      + " count integer)" );

			int	count = 0;

			this.checkBlob( "abc".getBytes() , count ++ );
			this.checkBlob( "abc\"def".getBytes() , count ++ );
			this.checkBlob( "abc'def".getBytes() , count ++ );
			this.checkBlob( "abc\0def".getBytes() , count ++ );
			this.checkBlob( "abc\\def".getBytes() , count ++ );
			this.checkBlob( "abc\rdef".getBytes() , count ++ );
			this.checkBlob( "abc\ndef".getBytes() , count ++ );
			this.checkBlob( "abc\r\ndef".getBytes() , count ++ );
			this.checkBlob( ("abc\0x1a" + "def").getBytes() ,
					count ++ );

//			this.checkBlob( "".getBytes() , count ++ );
			this.checkBlob( "a".getBytes() , count ++ );
			this.checkBlob( "\\".getBytes() , count ++ );


			byte[]	buf = new byte[(int)Byte.MAX_VALUE
					       - (int)Byte.MIN_VALUE];

			byte	ch = Byte.MIN_VALUE;

			for ( int  i = 0  ;  i < buf.length  ;  ++ i )
			{
				buf[i] = ch;

				if ( ch == Byte.MAX_VALUE )
				{
					break;
				}

				++ ch;
			}

			this.checkBlob( buf , count ++ );

			this.c.commit();
		}
		catch( DatabaseConnectionException  e )
		{
			this.c.rollback();

			throw e;
		}
	}


	private	void	checkBlob( byte[]  value ,  int  count )
					throws DatabaseConnectionException ,
					       SQLException
	{
		PreparedStatement	st;
		st = this.c.prepareStatement
			    ( "INSERT INTO test_blob_encoding"
			      + "(data, count) VALUES (?, ?)" );

		Blob	input = new SerialBlob( value );

		st.setBlob( 1 , input );
		st.setInt( 2 , count );

		c.executeUpdate( st );

		String	countString = Integer.toString( count );

		QueryResult	r = this.c.executeQuery
				     ( "SELECT data FROM test_blob_encoding"
				       + " WHERE count=" + countString );

		int	resultCount = 0;

		while( r.next() )
		{
			Blob	dbValue = r.getBlob(1); // column 1

			assertEquals( input.length() , dbValue.length() );

			assert dbValue.length() <= Integer.MAX_VALUE;

			byte[]	dbBytes = dbValue.getBytes
						  (1L , (int)dbValue.length());

			for ( int  i = 0  ;  i < value.length  ;  ++ i )
			{
				assertEquals( value[i] , dbBytes[i] );
			}


			++ resultCount;
		}

		assertEquals( 1 , resultCount );
	}



	@Test
	public	void	testShowTables() throws DatabaseConnectionException
	{
		try
		{
			this.cNotConnected.showTables();

			fail();
		}
		catch( DatabaseConnectionException  e )
		{
			// OK, expected
			assertNotNull( e );
		}


		try
		{
			List<String>	tables = this.c.showTables();

			assertNotNull( tables );
		}
		catch( DatabaseConnectionException  e )
		{
			throw e;
		}
	}


	@Test
	public	void	testGetSingleLongValue()
				    throws DatabaseConnectionException
	{
		try
		{
			this.cNotConnected.showTables();

			fail();
		}
		catch( DatabaseConnectionException  e )
		{
			// OK, expected
			assertNotNull( e );
		}


		final	String	tableName = "testGetSingleLongValue";
		final	String	columnName = "foo";

		try
		{
			this.c.executeUpdate
				( "CREATE TABLE " + tableName
				  + " (" + columnName + " integer)" );

			assertEquals( 0L ,
				      this.c.getMaxLongValueFromTable
				      ( tableName , columnName ) );

			this.c.executeUpdate
				( "INSERT INTO " + tableName
				  + " (" + columnName + ") VALUES(1)" );

			this.c.executeUpdate
				( "INSERT INTO " + tableName
				  + " (" + columnName + ") VALUES(2)" );

			this.c.executeUpdate
				( "INSERT INTO " + tableName
				  + " (" + columnName + ") VALUES(3)" );


			long	maxValue = this.c.getMaxLongValueFromTable
						  ( tableName , columnName );

			assertEquals( 3L , maxValue );
		}
		catch( DatabaseConnectionException  e )
		{
			throw e;
		}
	}
}
