#!/usr/bin/env python

import psycopg2
import time

def log(msg):
   """Write ``msg`` to stdout, followed by a newline.

   The single-argument call form of print is used because it behaves the
   same on Python 2 (where the parentheses are plain grouping) and on
   Python 3, whereas the ``print msg`` statement is a SyntaxError on 3.
   """
   print(msg)

def GetConn():
   """(Re)open the module-level database connection, retrying on failure.

   Any previous ``conn`` is closed first. ExecSql() calls this after every
   failed statement, so without the close each retry would leave an extra
   session open on the server, possibly still holding an uncommitted
   transaction. The global ``seq`` (a PgSeq for ``seqName``) is created the
   first time a connection succeeds.

   Up to 20 attempts are made, sleeping 1 second after each failure.

   Raises:
      AssertionError: if every attempt fails. It is raised explicitly
         rather than via ``assert`` so it still fires under ``python -O``.
   """
   global conn
   global seq
   retryCount = 20
   retrySleep = 1 # sec
   # Release the old connection (likely broken, or stuck in an aborted
   # transaction) before replacing it; closing rolls back whatever it held.
   if conn is not None:
      try:
         conn.close()
      except psycopg2.Error:
         pass  # best effort: the connection may already be dead
      conn = None
   while retryCount > 0:
      try:
         conn = psycopg2.connect("dbname='ioltas' user='ioltas' host='localhost'")
         if seq is None:
            seq = PgSeq(seqName)
         return
      # OperationalError is a subclass of DatabaseError, so this single
      # handler covers both connection-level and other database errors.
      except psycopg2.DatabaseError as e:
         log('Error connecting due to: %s, retrying' % (e))
         retryCount -= 1
         time.sleep(retrySleep)
   raise AssertionError('GetConn retry failed')

def ExecSql(sql, returnOutput=True):
   """Execute ``sql`` on the global connection, reconnecting and retrying on error.

   Args:
      sql: complete SQL text, executed as-is (no bind parameters are passed).
      returnOutput: if true, return ``cur.fetchall()``; otherwise return None.

   Nothing is committed here. After a psycopg2 OperationalError or
   DatabaseError the connection is replaced via GetConn() and the statement
   is retried, for at most 5 attempts in total.

   Raises:
      AssertionError: immediately, without retrying, when a non-operational
         DatabaseError reports a duplicate key value; or once all attempts
         have failed. Both are raised explicitly instead of with ``assert``
         so they are not stripped under ``python -O``.
   """
   global conn
   retryCount = 5
   while retryCount > 0:
      try:
         with conn.cursor() as cur:
            cur.execute(sql)
            if returnOutput:
               result = cur.fetchall()
            else:
               result = None
            # Deliberately no conn.commit(): leaving the transaction open
            # keeps every statement on this connection in one transaction.
            return result
      except psycopg2.OperationalError as e:
         # Connection-level failure: reconnect and retry.
         retryCount -= 1
         log('Error executing %s due to: %s, retrying' % (sql, e))
         GetConn()
      except psycopg2.DatabaseError as e:
         # A duplicate key is fatal rather than retried: the keys inserted
         # by this script come from the sequence, so it would indicate a
         # reused value.
         if 'duplicate key value' in str(e):
            raise AssertionError(str(e))
         retryCount -= 1
         log('Error executing %s due to: %s, retrying' % (sql, e))
         GetConn()
   raise AssertionError('ExecSql retry failed for %s' % sql)

class PgSeq(object):
   """Wrapper around a PostgreSQL sequence that fetches values with NEXTVAL
   and fails loudly if a value is not strictly greater than the previous
   one fetched through this object."""

   def __init__(self, name):
      # Sequence name; interpolated directly into the NEXTVAL query text,
      # so it must be a trusted identifier.
      self.name = name
      # Last value returned by GetNext(); -1 until the first fetch.
      self.currentKey = -1

   def GetNext(self):
      """Fetch, log and return the sequence's next value.

      Raises:
         AssertionError: if the database returns a value that is not
            strictly greater than the previously fetched one, i.e. a value
            that was already handed out or one that went backwards. Raised
            explicitly so the check survives ``python -O``.
      """
      prevKey = self.currentKey
      self.currentKey = ExecSql("SELECT NEXTVAL('%s');" % self.name)[0][0]
      # NEXTVAL on a non-cycling sequence never repeats a value, so an
      # equal value is already a reuse; only a strict increase is valid.
      if self.currentKey <= prevKey:
         raise AssertionError('DB returned reused sequence '
                              'value prev=%d curr=%d' %
                              (prevKey, self.currentKey))
      log('Fetched %s. prev=%d nextval=%d' %
          (self.name, prevKey, self.currentKey))
      return self.currentKey

# Names of the database objects this script (re)creates and then exercises.
seqName = 'testseq'
tblName = 'tbl'

# DDL run by InitSchema(). The DROPs use IF EXISTS so the script can be
# re-run against a database that already has these objects; the trailing
# COMMIT is needed because ExecSql() itself never commits.
schemaSql = [
   'DROP SEQUENCE IF EXISTS %s;' % seqName,
   'CREATE SEQUENCE %s INCREMENT BY 1;' % seqName,
   'DROP TABLE IF EXISTS %s;' % tblName,
   'CREATE TABLE %s (id NUMERIC(38) PRIMARY KEY);' % tblName,
   'COMMIT;',
]

# Shared mutable state, assigned by GetConn(): the live psycopg2
# connection and the PgSeq wrapper for seqName.
conn = None
seq = None

def InitSchema():
   """(Re)create the test sequence and table described by ``schemaSql``.

   All statements are sent in a single ExecSql() call. ExecSql() retries a
   failed statement on a *new* connection, and nothing before the trailing
   'COMMIT;' is committed. Sending the statements one at a time therefore let
   a reconnect part-way through silently discard the DDL that had already
   run, while the remaining statements still committed. As one script,
   every retry re-runs the whole setup, which is idempotent thanks to the
   IF EXISTS drops.
   """
   ExecSql('\n'.join(schemaSql), False)

def RunWorkload():
   """Insert sequence values into the table, forever.

   Each iteration takes the next value from the global ``seq`` and inserts
   it into ``tblName``. A reused value can surface as a duplicate-key error
   on insert, which ExecSql() treats as fatal instead of retrying. The loop
   has no exit; it ends only when an exception propagates.
   """
   insertTemplate = 'INSERT INTO %s VALUES(%d);'
   while True:
      nextId = seq.GetNext()
      ExecSql(insertTemplate % (tblName, nextId), False)
      log('Inserted %d into %s' % (nextId, tblName))

def main():
   """Script entry point: connect, rebuild the schema, then run the workload."""
   for step in (GetConn, InitSchema, RunWorkload):
      step()

# Start the (never-ending) workload only when run as a script, not on import.
if __name__ == "__main__":
   main()
