#!/usr/bin/env python

import sys, socket, struct

# V3 primitives

def send_startup(conn, db, user):
    """Send a V3 StartupMessage asking to connect to ``db`` as ``user``.

    Layout: Int32 total length (counting itself), Int32 protocol version
    196608 (== 3 << 16, i.e. protocol 3.0), then NUL-terminated
    name/value pairs closed by one extra NUL.  Unlike the messages built by
    ``send``, the startup message carries no leading type byte.
    """
    paramValues = 'user\0%s\0database\0%s\0\0' % (user, db)
    msgLen = 8 + len(paramValues)  # 4-byte length + 4-byte version + params
    msg = struct.pack('>II', msgLen, 196608) + paramValues
    # socket.send() may transmit only part of the buffer; sendall() retries
    # until everything is written (or raises).
    conn.sendall(msg)

def send(conn, type, data):
    """Send one typed frontend message: type byte, Int32 length, payload.

    The length field counts itself (4 bytes) plus ``data`` but not the
    leading type byte.
    """
    msg = type + struct.pack('>i', len(data) + 4) + data
    # socket.send() may transmit only part of the buffer; sendall() keeps
    # writing until the whole message is on the wire (or raises).
    conn.sendall(msg)

def send_parse(conn, statement, query, oids):
    """Log and send a Parse message that prepares ``query`` as ``statement``.

    Body: NUL-terminated statement name, NUL-terminated query text, an
    Int16 count of pre-declared parameter type OIDs, then one Int32 per OID.
    """
    sys.stdout.write('<= Parse(%s,%s,%s)\n' % (repr(statement), repr(query), repr(oids)))
    pieces = [statement, '\0', query, '\0', struct.pack('>H', len(oids))]
    pieces.extend(struct.pack('>I', type_oid) for type_oid in oids)
    send(conn, 'P', ''.join(pieces))

def send_bind(conn, portal, statement, params):
    """Log and send a Bind message that binds ``params`` to ``statement``.

    The resulting portal is named ``portal``.  Both format-code lists are
    sent empty, so every parameter and result column uses the default
    (text) format.

    A ``None`` parameter is sent as SQL NULL (length -1, no value bytes).
    Any other parameter is sent as exactly ``len(param)`` bytes.  The V3
    Bind format does not NUL-terminate parameter values.
    """
    sys.stdout.write('<= Bind(%s,%s,%s)\n' % (repr(portal), repr(statement), repr(params)))
    msg = portal + '\0'
    msg += statement + '\0'
    msg += struct.pack('>H', 0)             # param format-code count: all text
    msg += struct.pack('>H', len(params))   # param value count
    for param in params:
        if param is None:
            msg += struct.pack('>i', -1)    # NULL: length -1, no value bytes
        else:
            # Length covers only the value bytes; no terminator follows.
            msg += struct.pack('>i', len(param))
            msg += param
    msg += struct.pack('>H', 0)             # result format-code count: all text
    send(conn, 'B', msg)

def send_execute(conn, portal, count):
    """Log and send an Execute message for ``portal``.

    ``count`` is the maximum number of rows to return; 0 means "no limit".
    Previously this argument was logged but a literal 0 was always sent.
    """
    sys.stdout.write('<= Execute(%s,%d)\n' % (repr(portal), count))
    msg = portal + '\0'
    msg += struct.pack('>i', count)  # max rows; 0 = no limit
    send(conn, 'E', msg)

def send_sync(conn):
    """Log and send an empty-bodied Sync ('S') message.

    Callers follow it with ``process_results``, which reads replies up to
    ReadyForQuery.
    """
    log_line = '<= Sync\n'
    sys.stdout.write(log_line)
    send(conn, 'S', data='')

#
# Connection setup.
#

def connect(host, port, db, user):
    """Open a TCP connection to ``host``:``port`` and complete V3 startup.

    The function sends the StartupMessage and expects an 'R' reply whose
    Int32 status is 0.  No authentication method is implemented.  It then
    drains the remaining startup messages via ``process_results`` and
    returns the connected socket.

    Raises RuntimeError in these cases:
      * the first reply is not 'R';
      * the server requests authentication;
      * the server closes the connection mid-message.
    The socket is closed before any exception propagates.
    """
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)

    def recv_exact(n):
        # socket.recv() may return fewer bytes than requested; keep reading
        # until all n bytes arrive or the peer closes the connection.
        chunks = []
        while n > 0:
            chunk = conn.recv(n)
            if not chunk:
                raise RuntimeError('EOF seen')
            chunks.append(chunk)
            n -= len(chunk)
        return ''.join(chunks)

    try:
        conn.connect((host, port))
        send_startup(conn, db, user)

        msg_type = recv_exact(1)
        if msg_type != 'R':
            raise RuntimeError('bad startup type: ' + msg_type)

        # Length counts itself but not the type byte.
        length, = struct.unpack('>i', recv_exact(4))
        data = recv_exact(length - 4)
        res, = struct.unpack('>i', data[:4])
        if res != 0:
            raise RuntimeError("auth was needed but we don't do it")

        process_results(conn)
    except BaseException:
        # Don't leak the socket when startup fails for any reason.
        conn.close()
        raise
    return conn

#
# Normal connection processing
#

def make_log(type):
    """Build a handler that only announces the arrival of a ``type`` message.

    The returned callable takes (conn, data) like every other backend-message
    handler here, ignores both arguments, and writes one log line.
    """
    def log_arrival(conn, data):
        return sys.stdout.write(''.join(('  => ', type, '\n')))
    return log_arrival

def make_server_response(type):
    """Build a handler that logs an ErrorResponse/NoticeResponse-style payload.

    The payload is a sequence of fields: a one-byte field code followed by a
    NUL-terminated string value.  A lone NUL ends the sequence.  The handler
    writes a header line naming ``type``, then one line per field.

    Truncated payloads are logged as far as they go instead of raising
    IndexError.  This covers an empty payload and one that is missing the
    final NUL.
    """
    def server_response(conn, data, type=type):
        sys.stdout.write('  => ' + type + '\n')
        i = 0
        # Stop at the closing NUL, at the end of data, or after a field whose
        # value ran off the end without a terminator (signalled by i == -1).
        while 0 <= i < len(data) and data[i] != '\0':
            code = data[i]
            end = data.find('\0', i + 1)
            if end == -1:
                value = data[i + 1:]
                i = -1
            else:
                value = data[i + 1:end]
                i = end + 1
            sys.stdout.write('  =>  ' + code + ': ' + value + '\n')

    return server_response

def command_complete(conn, data):
    """Log a CommandComplete message.

    ``data`` is the raw payload (the command tag).  It is written verbatim,
    including any trailing NUL terminator.
    """
    line = ''.join(['  => CommandComplete: ', data, '\n'])
    sys.stdout.write(line)

def data_row(conn, data):
    """Log a DataRow message as a comma-separated list of column values.

    Payload: an Int16 column count, then for each column an Int32 length
    followed by that many value bytes.  A length of -1 marks NULL and has no
    value bytes.
    """
    sys.stdout.write('  => DataRow ')
    ncols, = struct.unpack('>H', data[:2])
    pos = 2
    for _ in xrange(ncols):
        collen, = struct.unpack('>i', data[pos:pos + 4])
        pos += 4
        if collen == -1:
            sys.stdout.write('NULL,')
            continue
        sys.stdout.write(data[pos:pos + collen] + ',')
        pos += collen
    sys.stdout.write('\n')

# Dispatch table used by process_results: backend message type byte ->
# handler(conn, data), where data is the payload after the length word.
# NoData, ParameterDescription, RowDescription and PortalSuspended are
# included so that Describe, or Execute with a nonzero row limit, does not
# abort with "Unhandled message type".
handlers = {
    'K': make_log('BackendKeyData'),
    '2': make_log('BindComplete'),
    '3': make_log('CloseComplete'),
    'C': command_complete,
    'D': data_row,
    'I': make_log('EmptyQuery'),
    'E': make_server_response('ErrorResponse'),
    'n': make_log('NoData'),
    'N': make_server_response('NoticeResponse'),
    't': make_log('ParameterDescription'),
    'S': make_log('ParameterStatus'),
    '1': make_log('ParseComplete'),
    's': make_log('PortalSuspended'),
    'Z': make_log('ReadyForQuery'),
    'T': make_log('RowDescription'),
    }
        
def process_results(conn):
    """Read and dispatch backend messages until ReadyForQuery ('Z').

    Each message is a one-byte type code plus an Int32 length.  The length
    counts itself but not the type byte.  The payload goes to the matching
    entry in ``handlers``.

    Raises RuntimeError if the server closes the connection ('EOF seen') or
    sends a message type with no registered handler.
    """
    def recv_exact(n):
        # socket.recv() may return fewer bytes than requested; keep reading
        # until all n bytes arrive or the peer closes the connection.
        chunks = []
        while n > 0:
            chunk = conn.recv(n)
            if not chunk:
                raise RuntimeError('EOF seen')
            chunks.append(chunk)
            n -= len(chunk)
        return ''.join(chunks)

    msg_type = None
    while msg_type != 'Z':
        msg_type = recv_exact(1)
        length, = struct.unpack('>i', recv_exact(4))
        data = recv_exact(length - 4)

        handler = handlers.get(msg_type)
        if handler is None:
            raise RuntimeError('Unhandled message type ' + msg_type)
        handler(conn, data)

statement_number = 1  # numeric suffix for the next named statement, s_<N>


def combos(conn, query, oids, params, params_2):
    """Run ``query`` through the extended-query protocol in four rounds.

    First the unnamed statement is parsed twice in a row.  It is then bound
    and executed with ``params``, and again with ``params_2``.  Next a fresh
    named statement ``s_<N>`` is parsed and run with each parameter set the
    same way.  Every round ends with Sync and is drained up to ReadyForQuery.
    Each call bumps the global ``statement_number``.
    """
    global statement_number

    def bind_and_run(stmt_name, values):
        send_bind(conn, '', stmt_name, values)
        send_execute(conn, '', 0)
        send_sync(conn)
        process_results(conn)

    # Unnamed prepared statement.  NOTE(review): the repeated Parse
    # presumably checks that re-parsing replaces the unnamed statement; confirm.
    send_parse(conn, '', query, oids)
    send_parse(conn, '', query, oids)
    bind_and_run('', params)
    bind_and_run('', params_2)

    # Named prepared statement.
    stmt = 's_%d' % statement_number
    statement_number += 1
    send_parse(conn, stmt, query, oids)
    bind_and_run(stmt, params)
    bind_and_run(stmt, params_2)
    
def tests(conn):
    """Drive every query scenario below through ``combos``, in order.

    Each case is (description, query, parameter type OIDs, first parameter
    set, second parameter set).  The OID 23 is presumably PostgreSQL's
    int4; verify against the target server.
    """
    cases = [
        ('Empty query',
         '', (), (), ()),
        ('Simple SELECT',
         'SELECT 1', (), (), ()),
        ('Simple parameterized SELECT',
         'SELECT $1', (23,), ('42',), ('24',)),
        ('Parameterized SELECT that calls a function that can be preevaluated',
         'SELECT abs($1)', (23,), ('42',), ('-24',)),
        ("Parameterized SELECT that calls a function that can't be preevaluated",
         'SELECT abs($1 + random())', (23,), ('42',), ('-24',)),
    ]
    for _description, query, oids, first, second in cases:
        combos(conn=conn, query=query, oids=oids, params=first, params_2=second)

if __name__ == '__main__':
    # Usage: script host port db user
    if len(sys.argv) != 5:
        sys.stderr.write('usage: %s host port db user\n' % sys.argv[0])
        sys.exit(2)
    host, port, db, user = sys.argv[1:]
    conn = connect(host, int(port), db, user)
    try:
        tests(conn)
        send(conn, 'X', '')  # 'X' = Terminate
    finally:
        # Close the socket even if a test raises.
        conn.close()
    
