import psycopg2
import select
import config
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError

ASYNC_OK = 1
ASYNC_READ_TIMEOUT = 2
ASYNC_WRITE_TIMEOUT = 3
ASYNC_TIMEOUT = 0.2


def wait(conn):
    """Block until the pending asynchronous operation on *conn* completes.

    Repeatedly calls ``conn.poll()``. While the driver reports that it needs
    the socket to become readable or writable, waits (without a timeout) on
    ``conn.fileno()`` via ``select.select``.

    :param conn: an asynchronous psycopg2 connection.
    :raises psycopg2.OperationalError: if ``poll()`` returns a state other
        than POLL_OK, POLL_READ or POLL_WRITE. Exceptions raised by
        ``conn.poll()`` itself propagate unchanged.
    """
    exts = psycopg2.extensions
    while True:
        status = conn.poll()
        if status == exts.POLL_OK:
            return
        if status == exts.POLL_READ:
            select.select([conn.fileno()], [], [])
        elif status == exts.POLL_WRITE:
            select.select([], [conn.fileno()], [])
        else:
            raise psycopg2.OperationalError(
                "poll() returned %s from _wait function" % status)

def wait_timeout(conn):
    """Poll *conn* once, waiting at most ASYNC_TIMEOUT seconds for the socket.

    Calls ``conn.poll()`` in a loop (printing a trace line before and after
    each call). If the driver needs to read or write, waits on
    ``conn.fileno()`` for up to ``ASYNC_TIMEOUT`` seconds. If the socket does
    not become ready in time, returns a timeout code so the caller can decide
    whether to keep waiting.

    :param conn: an asynchronous psycopg2 connection.
    :returns: ASYNC_OK when ``poll()`` reports POLL_OK; ASYNC_WRITE_TIMEOUT or
        ASYNC_READ_TIMEOUT when the corresponding ``select`` timed out.
    :raises psycopg2.OperationalError: for an unexpected ``poll()`` state.
        Exceptions raised by ``conn.poll()`` itself propagate unchanged.
    """
    extensions = psycopg2.extensions
    while True:
        print("Before polling the connection...")
        state = conn.poll()
        print("After polling the connection...")

        if state == extensions.POLL_OK:
            return ASYNC_OK

        if state == extensions.POLL_WRITE:
            readers, writers, on_timeout = (
                [], [conn.fileno()], ASYNC_WRITE_TIMEOUT
            )
        elif state == extensions.POLL_READ:
            readers, writers, on_timeout = (
                [conn.fileno()], [], ASYNC_READ_TIMEOUT
            )
        else:
            raise psycopg2.OperationalError(
                "poll() returned %s from _wait_timeout function" % state
            )

        ready = select.select(readers, writers, [], ASYNC_TIMEOUT)
        # select() hands back three empty lists when the timeout expired
        # before the socket became ready.
        if not any(ready):
            return on_timeout

# Open an SSH tunnel whose remote end is the database server; the database is
# then reached through the tunnel's local end (tunnel_object.local_bind_port).
tunnel_object = None
try:
    tunnel_object = SSHTunnelForwarder(
        (config.SSH_TUNNEL_HOST, int(config.SSH_TUNNEL_PORT)),
        ssh_username=config.SSH_TUNNEL_USERNAME,
        ssh_password=config.SSH_TUNNEL_PASSWORD,
        # Convert the port like the SSH port above, in case config holds it
        # as a string.
        remote_bind_address=(
            config.DATABASE_SERVER_IP,
            int(config.DATABASE_SERVER_PORT),
        ),
    )
    tunnel_object.start()
except BaseSSHTunnelForwarderError as e:
    print("Failed to create the SSH tunnel.Error: {0}".format(str(e)))

if tunnel_object is None or not tunnel_object.is_active:
    print("Failed to create the SSH tunnel")
    # Everything after this point needs a running tunnel (the database
    # connection uses its local_bind_port), so stop with a failure status
    # rather than crash later on a None / inactive tunnel.
    raise SystemExit(1)

# Open an asynchronous connection to the database through the SSH tunnel.
pg_conn = None
try:
    # The tunnel listens on a port of *this* machine (no local_bind_address
    # was given, so sshtunnel binds locally on a random port) and forwards it
    # to the database server. The client must therefore connect to the
    # loopback address at local_bind_port, not to the SSH host.
    pg_conn = psycopg2.connect(
        host="127.0.0.1",
        hostaddr="127.0.0.1",
        port=tunnel_object.local_bind_port,
        database=config.DATABASE_NAME,
        user=config.DATABASE_SERVER_USER_NAME,
        password=config.DATABASE_SERVER_PASSWORD,
        sslmode='disable',
        async_=1,
    )

    # The connection is asynchronous, so block until it is ready to use.
    wait(pg_conn)
except psycopg2.Error as e:
    if e.pgerror:
        print(e.pgerror)
    elif e.diag.message_detail:
        print(e.diag.message_detail)
    else:
        print(str(e))
    # Release whatever was opened so far and exit with a failure status
    # (quit() would report success with exit status 0).
    if pg_conn is not None:
        pg_conn.close()
    if tunnel_object is not None:
        tunnel_object.stop()
    raise SystemExit(1)


# Run one valid and one invalid query over the asynchronous connection. The
# finally clause closes the connection and the SSH tunnel even if a query
# raises.
if pg_conn is not None:
    try:
        cur = pg_conn.cursor()
        print("************************************")
        print("Executing correct query SELECT version();")
        cur.execute("SELECT version()")
        # wait_timeout() returns a timeout code while the result is not ready
        # yet; keep polling until it reports ASYNC_OK.
        while wait_timeout(cur.connection) != ASYNC_OK:
            pass

        print("Result:- ")
        print(cur.fetchone())
        print("Successfully executed the CORRECT query.")
        print("************************************")

        print("Executing WRONG query SELECT12 version();")
        cur.execute("SELECT12 version()")
        try:
            # On an asynchronous connection the server-side error surfaces
            # from conn.poll() inside wait_timeout(), not from execute().
            while wait_timeout(cur.connection) != ASYNC_OK:
                pass
        except psycopg2.Error as e:
            print("WRONG query failed with error:")
            if e.pgerror:
                print(e.pgerror)
            elif e.diag.message_detail:
                print(e.diag.message_detail)
            else:
                print(str(e))
        else:
            print("Successfully executed the WRONG query.")
        print("************************************")
    finally:
        pg_conn.close()
        if tunnel_object is not None:
            tunnel_object.stop()

    quit()