// -*- tab-width: 4; indent-tabs-mode: t -*-
package org.postgresql.test.jdbc2;

import org.postgresql.test.TestUtil;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import java.sql.*;
import java.util.Map;

/**
 * Round-trip tests for PreparedStatement.setArray() and ResultSet.getArray().
 * <p>
 * Each test case inserts a java.sql.Array wrapping {@code testIn} into a
 * single-column table whose column type is {@code sqlTypeName}. It then reads
 * the value back with getArray() and compares the elements against
 * {@code testOut}, or against {@code testIn} when no distinct expected output
 * is given. {@link #suite()} registers every case twice: once with
 * server-side prepare and once without.
 */
public class ArrayTest extends TestCase {
	private Connection con;
	private Statement stmt;

	/** Whether the INSERT statement uses server-side prepare. */
	private final boolean useSP;
	private final String name;
	/** Java array handed to setArray(). */
	private final Object testIn;
	/** Java array expected back from getArray(); defaults to testIn. */
	private final Object testOut;
	/** java.sql.Types code reported by the wrapper as the array's base type. */
	private final int type;
	/** SQL type of the test table's "data" column, e.g. "integer[]". */
	private final String sqlTypeName;

	/**
	 * Minimal java.sql.Array that exposes an existing Java array through
	 * getArray() and getBaseType(). Every other accessor is unsupported.
	 * Declared static because it needs no access to the enclosing test.
	 */
	private static class WrappedArray implements Array {
		private final Object wrappedArray;
		private final int sqlType;

		WrappedArray(Object wrappedArray, int sqlType) {
			if (wrappedArray == null)
				throw new IllegalArgumentException("wrappedArray must not be null");
			this.wrappedArray = wrappedArray;
			this.sqlType = sqlType;
		}

		public Object getArray() { return wrappedArray; }
		public Object getArray(Map map) { throw new UnsupportedOperationException(); }
		public Object getArray(long index, int count) { throw new UnsupportedOperationException(); }
		public Object getArray(long index, int count, Map map) { throw new UnsupportedOperationException(); }
		public int getBaseType() { return sqlType; }
		public String getBaseTypeName() { throw new UnsupportedOperationException(); }
		public ResultSet getResultSet() { throw new UnsupportedOperationException(); }
		public ResultSet getResultSet(long index, int count) { throw new UnsupportedOperationException(); }
		public ResultSet getResultSet(long index, int count, Map map) { throw new UnsupportedOperationException(); }
		public ResultSet getResultSet(Map map) { throw new UnsupportedOperationException(); }

		// Declared by java.sql.Array since JDBC 4 (Java 6). The wrapper holds
		// no resources, so there is nothing to release.
		public void free() { }
	}

	/**
	 * Asserts that two Java arrays (of any component type, primitive or not)
	 * have the same length and pairwise-equal elements.
	 *
	 * @param msg failure message prefix
	 * @param a1  expected array, may be null
	 * @param a2  actual array, may be null
	 */
	private static void assertArrayEquals(String msg, Object a1, Object a2) {
		if (a1 == a2)
			return;

		// Exactly one side is null at this point, which is a mismatch.
		if (a1 == null || a2 == null)
			fail(msg + " expected:<" + a1 + "> but was:<" + a2 + ">");

		int length1 = java.lang.reflect.Array.getLength(a1);
		int length2 = java.lang.reflect.Array.getLength(a2);
		assertEquals(msg + " arrays have differing lengths", length1, length2);

		// reflect.Array.get boxes primitive elements, so int[] and Integer[]
		// (or boolean[] and Boolean[]) compare element-wise as equal.
		for (int i = 0; i < length1; ++i) {
			Object o1 = java.lang.reflect.Array.get(a1, i);
			Object o2 = java.lang.reflect.Array.get(a2, i);
			assertEquals(msg + " element " + i + ":", o1, o2);
		}
	}

	/**
	 * @param name        human-readable case description
	 * @param useSP       whether to use server-side prepare for the INSERT
	 * @param testIn      array passed to setArray()
	 * @param testOut     array expected from getArray(), or null to expect testIn
	 * @param type        java.sql.Types base type reported by the wrapper
	 * @param sqlTypeName SQL type of the test column
	 */
	ArrayTest(String name, boolean useSP,
			  Object testIn, Object testOut, int type, String sqlTypeName)
	{
		// Tag the server-prepared variant so the two runs are distinguishable.
		super(name + (useSP ? "; serverPrepare" : ""));
		this.useSP = useSP;
		this.name = name;
		this.testIn = testIn;
		this.testOut = (testOut == null ? testIn : testOut);
		this.type = type;
		this.sqlTypeName = sqlTypeName;
	}

	/** Opens a connection and creates the single-column "testarray" table. */
	protected void setUp() throws Exception {
		con = TestUtil.openDB();
		con.setAutoCommit(true);
		stmt = con.createStatement();
		TestUtil.createTable(con, "testarray", "data " + sqlTypeName);
	}

	/** Closes the statement, drops the table, and always closes the connection. */
	protected void tearDown() throws Exception {
		if (con == null)
			return;
		try {
			if (stmt != null)
				stmt.close();
			TestUtil.dropTable(con, "testarray");
		} finally {
			TestUtil.closeDB(con);
		}
	}

	/**
	 * Inserts testIn via setArray(), reads the row back, and compares the
	 * result of getArray() against testOut.
	 */
	public void runTest() throws SQLException {
		PreparedStatement prepareInsert = con.prepareStatement("INSERT INTO testarray(data) VALUES(?)");
		try {
			((org.postgresql.PGStatement)prepareInsert).setUseServerPrepare(useSP);
			prepareInsert.setArray(1, new WrappedArray(testIn, type));
			prepareInsert.executeUpdate();
		} finally {
			prepareInsert.close();
		}

		ResultSet result = stmt.executeQuery("SELECT data FROM testarray");
		try {
			assertTrue("result row expected", result.next());
			Array array = result.getArray(1);
			assertNotNull("result array expected", array);
			// testOut already falls back to testIn (see constructor).
			assertArrayEquals("result array mismatch", testOut, array.getArray());
		} finally {
			result.close();
		}
	}

	/** Registers one case both without and with server-side prepare. */
	private static void addTests(TestSuite suite, String name, Object testIn, Object testOut, int type, String sqlTypeName) {
		suite.addTest(new ArrayTest(name, false, testIn, testOut, type, sqlTypeName));
		suite.addTest(new ArrayTest(name, true, testIn, testOut, type, sqlTypeName));
	}

	/** Builds the suite of array round-trip cases. */
	public static Test suite() {
		TestSuite ts = new TestSuite();

		//
		// integers
		//

		addTests(ts, "empty Object[] -> int[]",
				 new Object[0], null, Types.INTEGER,
				 "integer[]");

		addTests(ts, "Integer[] -> int[]",
				 new Integer[] { new Integer(1), new Integer(2) }, null, Types.INTEGER,
				 "integer[]");

		addTests(ts, "String[] -> int[]",
				 new String[] { "1", "2" }, new int[] { 1,2 }, Types.INTEGER,
				 "integer[]");

		addTests(ts, "int[] -> int[]",
				 new int[] { 1,2 }, null, Types.INTEGER,
				 "integer[]");

		//
		// strings
		//

		addTests(ts, "empty Object[] -> text[]",
				 new Object[0], null, Types.VARCHAR,
				 "text[]");

		addTests(ts, "Integer[] -> text[]",
				 new Integer[] { new Integer(1), new Integer(2) }, new Object[] { "1", "2" }, Types.VARCHAR,
				 "text[]");

		addTests(ts, "String[] -> text[]",
				 new String[] { "abcd", "'", "\"", "\\" }, null, Types.VARCHAR,
				 "text[]");

		//
		// booleans
		//

		addTests(ts, "empty Object[] -> bool[]",
				 new Object[0], null, Types.BIT,
				 "bool[]");

		addTests(ts, "Boolean[] -> bool[]",
				 new Boolean[] { Boolean.TRUE, Boolean.FALSE }, null, Types.BIT,
				 "bool[]");

		addTests(ts, "Integer[] -> bool[]",
				 new Integer[] { new Integer(1), new Integer(0) }, new boolean[] { true, false }, Types.BIT,
				 "bool[]");

		addTests(ts, "boolean[] -> bool[]",
				 new boolean[] { true, false }, null, Types.BIT,
				 "bool[]");

		//
		// dates
		//
		// Date.valueOf yields local midnight, so these cases do not depend on
		// the JVM time zone. Hard-coded epoch millis would only match local
		// midnight in a single zone.

		addTests(ts, "empty Object[] -> date[]",
				 new Object[0], null, Types.DATE,
				 "date[]");

		addTests(ts, "Date[] -> date[]",
				 new Date[] { Date.valueOf("1978-01-01"), Date.valueOf("2005-12-31") }, null, Types.DATE,
				 "date[]");

		addTests(ts, "String[] -> date[]",
				 new String[] { "1978-01-01", "2005-12-31" }, new Date[] { Date.valueOf("1978-01-01"), Date.valueOf("2005-12-31") }, Types.DATE,
				 "date[]");

		//
		// that's probably enough to catch most of the corner cases.
		//

		return ts;
	}
}
