From 94b54408ecd4678f2aa855f02fc3fd7bb8716446 Mon Sep 17 00:00:00 2001
From: Thomas Krennwallner <tk+pgsql@postsubmeta.net>
Date: Sun, 20 Aug 2017 18:00:21 +0200
Subject: [PATCH] connection info refactoring

Use PQconninfoParse() in class connInfo to add support for parsing
keyword/value and URI connection strings as command-line option and
stored in column jstconnstr of table pgagent.pga_jobstep.

We additionally benefit from having an implementation for all parameter
keywords supported by libpq, and fixing a previously misspelled
connect_timeout keyword (was connection_timeout).
---
 connection.cpp       | 301 ++++++++++++++++++++-------------------------------
 include/connection.h |  41 +++----
 job.cpp              |  14 ++-
 3 files changed, 148 insertions(+), 208 deletions(-)

diff --git a/connection.cpp b/connection.cpp
index 2b96294..c491b98 100644
--- a/connection.cpp
+++ b/connection.cpp
@@ -15,37 +15,34 @@
 #include <wx/tokenzr.h>
 
 DBconn *DBconn::primaryConn = NULL;
-wxString DBconn::basicConnectString;
+connInfo *DBconn::basicConnInfo = NULL;
 static wxMutex s_PoolLock;
 
 
-DBconn::DBconn(const wxString &connectString, const wxString &db)
+DBconn::DBconn(const wxString &connStr)
+{
+	connInfo cnInfo(connStr);
+	Connect(cnInfo);
+}
+
+
+DBconn::DBconn(const connInfo &cnInfo)
+{
+	Connect(cnInfo);
+}
+
+
+bool DBconn::Connect(const connInfo &cnInfo)
 {
 	inUse = false;
 	next = 0;
 	prev = 0;
 	majorVersion = 0;
 	minorVersion = 0;
-	dbname = db;
-	connStr = connectString;
+	connStr = cnInfo.getConnectionString();
 
-	if (connectString.IsEmpty())
-	{
-		// This is a sql call to a local database.
-		// No connection string found. Use basicConnectString.
-		Connect(basicConnectString  + wxT(" dbname=") + dbname);
-	}
-	else
-	{
-		Connect(connectString);
-	}
-}
-
-
-bool DBconn::Connect(const wxString &connectString)
-{
-	LogMessage(wxString::Format(_("Creating DB connection: %s"), connectString.c_str()), LOG_DEBUG);
-	wxCharBuffer cstrUTF = connectString.mb_str(wxConvUTF8);
+	LogMessage(wxString::Format(_("Creating DB connection: '%s'"), connStr.c_str()), LOG_DEBUG);
+	wxCharBuffer cstrUTF = connStr.mb_str(wxConvUTF8);
 	conn = PQconnectdb(cstrUTF);
 	if (PQstatus(conn) != CONNECTION_OK)
 	{
@@ -53,6 +50,11 @@ bool DBconn::Connect(const wxString &connectString)
 		PQfinish(conn);
 		conn = 0;
 	}
+	else
+	{
+		// get dbname from the connection
+		dbname = wxString::FromAscii(PQdb(conn));
+	}
 	return IsValid();
 }
 
@@ -105,36 +107,48 @@ DBconn *DBconn::InitConnection(const wxString &connectString)
 {
 	wxMutexLocker lock(s_PoolLock);
 
-	basicConnectString = connectString;
-	wxString dbname;
+	if (basicConnInfo)
+	{
+		delete basicConnInfo;
+		basicConnInfo = NULL;
+	}
 
-	connInfo cnInfo = connInfo::getConnectionInfo(connectString);
-	if (cnInfo.isValid)
+	basicConnInfo = new connInfo(connectString);
+	if (basicConnInfo->IsValid())
 	{
-		dbname = cnInfo.dbname;
-		basicConnectString = cnInfo.getConnectionString();
-		primaryConn = new DBconn(cnInfo.getConnectionString(), dbname);
+		primaryConn = new DBconn(*basicConnInfo);
 
 		if (!primaryConn)
-			LogMessage(_("Failed to create primary connection!"), LOG_ERROR);
-		primaryConn->dbname = dbname;
+			LogMessage(wxString::Format("Failed to create primary connection with '%s'!", connectString.c_str()), LOG_ERROR);
 		primaryConn->inUse = true;
 	}
 	else
 	{
 		primaryConn = NULL;
-		LogMessage(wxT("Primary connection string is not valid!"), LOG_ERROR);
+		delete basicConnInfo;
+		basicConnInfo = NULL;
+		LogMessage(wxString::Format("Primary connection string '%s' is not valid: %s", connectString.c_str(), basicConnInfo->error.c_str()), LOG_ERROR);
 	}
 
 	return primaryConn;
 }
 
 
-DBconn *DBconn::Get(const wxString &connStr, const wxString &db)
+DBconn *DBconn::Get(const connInfo &cnInfo, const wxString &db)
 {
+	wxString connStr;
+
+	if (!cnInfo.IsValid())
+	{
+		LogMessage(_("Cannot allocate connection - invalid connection string specified!"), LOG_WARNING);
+		return NULL;
+	}
+
+	connStr = cnInfo.getConnectionString();
+
 	if (db.IsEmpty() && connStr.IsEmpty())
 	{
-		LogMessage(_("Cannot allocate connection - no database or connection string specified!"), LOG_WARNING);
+		LogMessage(_("Cannot allocate connection - no database and no connection string specified!"), LOG_WARNING);
 		return NULL;
 	}
 
@@ -147,7 +161,7 @@ DBconn *DBconn::Get(const wxString &connStr, const wxString &db)
 	{
 		if (thisConn && ((!db.IsEmpty() && db == thisConn->dbname && connStr.IsEmpty()) || (!connStr.IsEmpty() && connStr == thisConn->connStr)) && !thisConn->inUse)
 		{
-			LogMessage(wxString::Format(_("Allocating existing connection to database %s"), thisConn->dbname.c_str()), LOG_DEBUG);
+			LogMessage(wxString::Format(_("Allocating existing connection '%s' to database %s"), thisConn->connStr.c_str(), thisConn->dbname.c_str()), LOG_DEBUG);
 			thisConn->inUse = true;
 			return thisConn;
 		}
@@ -161,11 +175,21 @@ DBconn *DBconn::Get(const wxString &connStr, const wxString &db)
 
 	// No suitable connection was found, so create a new one.
 	DBconn *newConn = NULL;
-	newConn = new DBconn(connStr, db);
+	if (connStr.IsEmpty())
+	{
+		// This is a sql call to a local database.
+		// No connection string found. Use basicConnInfo.
+		connStr = basicConnInfo->getConnectionString() + wxT(" dbname=") + db;
+		newConn = new DBconn(connStr);
+	}
+	else
+	{
+		newConn = new DBconn(cnInfo);
+	}
 
 	if (newConn->conn)
 	{
-		LogMessage(wxString::Format(_("Allocating new connection to database %s"), newConn->dbname.c_str()), LOG_DEBUG);
+		LogMessage(wxString::Format(_("Allocating new connection '%s' to database %s"), newConn->connStr.c_str(), newConn->dbname.c_str()), LOG_DEBUG);
 		newConn->inUse = true;
 		newConn->prev = lastConn;
 		lastConn->next = newConn;
@@ -174,10 +198,10 @@ DBconn *DBconn::Get(const wxString &connStr, const wxString &db)
 	{
 		wxString warnMsg;
 		if (connStr.IsEmpty())
-			warnMsg = wxString::Format(_("Failed to create new connection to database '%s':'%s'"),
+			warnMsg = wxString::Format(_("Failed to create new connection to database '%s': %s"),
 			                           db.c_str(), newConn->GetLastError().c_str());
 		else
-			warnMsg = wxString::Format(_("Failed to create new connection for connection string '%s':%s"),
+			warnMsg = wxString::Format(_("Failed to create new connection for connection string '%s': %s"),
 			                           connStr.c_str(), newConn->GetLastError().c_str());
 		LogMessage(warnMsg, LOG_STARTUP);
 		return NULL;
@@ -195,7 +219,7 @@ void DBconn::Return()
 	this->ExecuteVoid(wxT("RESET ALL"));
 	this->lastError.Empty();
 
-	LogMessage(wxString::Format(_("Returning connection to database %s"), dbname.c_str()), LOG_DEBUG);
+	LogMessage(wxString::Format(_("Returning connection '%s' to database %s"), connStr.c_str(), dbname.c_str()), LOG_DEBUG);
 	inUse = false;
 }
 
@@ -378,172 +402,87 @@ wxString DBresult::GetString(const wxString &colname) const
 
 ///////////////////////////////////////////////////////7
 
-bool connInfo::IsValidIP()
+connInfo::connInfo(const wxString &connStr)
 {
-	if (host.IsEmpty())
-		return false;
-
-	// check for IPv4 format
-	wxStringTokenizer tkip4(host, wxT("."));
-	int count = 0;
+	isValid = false;
 
-	while (tkip4.HasMoreTokens())
-	{
-		long val = 0;
-		if (!tkip4.GetNextToken().ToLong(&val))
-			break;
-		if (count == 0 || count == 3)
-			if (val > 0 && val < 255)
-				count++;
-			else
-				break;
-		else if (val >= 0 && val < 255)
-			count++;
-		else
-			break;
-	}
-
-	if (count == 4)
-		return true;
-
-	// check for IPv6 format
-	wxStringTokenizer tkip6(host, wxT(":"));
-	count = 0;
-
-	while (tkip6.HasMoreTokens())
-	{
-		unsigned long val = 0;
-		wxString strVal = tkip6.GetNextToken();
-		if (strVal.Length() > 4 || !strVal.ToULong(&val, 16))
-			return false;
-		count++;
-	}
-	if (count <= 8)
-		return true;
-
-	// TODO:: We're not supporting mix mode (x:x:x:x:x:x:d.d.d.d)
-	//        i.e. ::ffff:12.34.56.78
-	return false;
+	Init(connStr);
 }
 
 
-wxString connInfo::getConnectionString()
+wxString connInfo::getConnectionString() const
 {
-	wxString connStr;
-
-	// Check if it has valid connection info
-	if (!isValid)
-		return connStr;
-
-	// User
-	connStr = wxT("user=") + user;
-
-	// Port
-	if (port != 0)
-	{
-		wxString portStr;
-		portStr.Printf(wxT("%ld"), port);
-		connStr += wxT(" port=") + portStr;
-	}
-
-	// host or hostaddr
-	if (!host.IsEmpty())
-	{
-		if (IsValidIP())
-			connStr += wxT(" hostaddr=") + host;
-		else
-			connStr += wxT(" host=") + host;
-	}
-
-	// connection timeout
-	if (connection_timeout != 0)
-	{
-		wxString val;
-		val.Printf(wxT("%ld"), connection_timeout);
-		connStr += wxT(" connection_timeout=") + val;
-	}
+	return connectionString;
+}
 
-	// password
-	if (!password.IsEmpty())
-		connStr += wxT(" password=") + password;
 
-	if (!dbname.IsEmpty())
-		connStr += wxT(" dbname=") + dbname;
+wxString connInfo::getError() const
+{
+	return error;
+}
 
-	LogMessage(wxString::Format(_("Connection Information:")), LOG_DEBUG);
-	LogMessage(wxString::Format(_("     user         : %s"), user.c_str()), LOG_DEBUG);
-	LogMessage(wxString::Format(_("     port         : %ld"), port), LOG_DEBUG);
-	LogMessage(wxString::Format(_("     host         : %s"), host.c_str()), LOG_DEBUG);
-	LogMessage(wxString::Format(_("     dbname       : %s"), dbname.c_str()), LOG_DEBUG);
-	LogMessage(wxString::Format(_("     password     : %s"), password.c_str()), LOG_DEBUG);
-	LogMessage(wxString::Format(_("     conn timeout : %ld"), connection_timeout), LOG_DEBUG);
 
-	return connStr;
+bool connInfo::IsValid() const
+{
+	return isValid;
 }
 
 
-connInfo connInfo::getConnectionInfo(wxString connStr)
+void connInfo::Init(const wxString &connStr)
 {
-	connInfo cnInfo;
+	PQconninfoOption *opts, *opt;
+	char *errmsg = NULL;
 
-	wxRegEx propertyExp;
-
-	// Remove the white-space(s) to match the following format
-	// i.e. prop=value
-	bool res = propertyExp.Compile(wxT("(([ ]*[\t]*)+)="));
+	// parse Keyword/Value Connection Strings and Connection URIs
+	opts = PQconninfoParse(connStr.c_str(), &errmsg);
+	if (opts == NULL)
+	{
+		if (errmsg)
+		{
+			error = wxString::FromAscii(errmsg);
+			PQfreemem(errmsg);
+		}
+		return;
+	}
 
-	propertyExp.ReplaceAll(&connStr, wxT("="));
+	isValid = true;
 
-	res = propertyExp.Compile(wxT("=(([ ]*[\t]*)+)"));
-	propertyExp.ReplaceAll(&connStr, wxT("="));
+	if (connStr.IsEmpty())
+		return;
 
-	// Seperate all the prop=value patterns
-	wxArrayString tokens = wxStringTokenize(connStr, wxT("\t \n\r"));
+	LogMessage(wxString::Format(wxString::Format(_("Parsing connection information: %s"), connStr.c_str())), LOG_DEBUG);
 
-	unsigned int index = 0;
-	while (index < tokens.Count())
+	// iterate over all options
+	for (opt = opts; opt->keyword; opt++)
 	{
-		wxString prop, value;
-
-		wxArrayString pairs = wxStringTokenize(tokens[index++], wxT("="));
-
-		if (pairs.GetCount() != 2)
-			return cnInfo;
+		if (opt->val == NULL)
+			continue;
 
-		prop = pairs[0];
-		value = pairs[1];
-
-		if (prop.CmpNoCase(wxT("user")) == 0)
-			cnInfo.user = value;
-		else if (prop.CmpNoCase(wxT("host")) == 0 || prop.CmpNoCase(wxT("hostAddr")) == 0)
-			cnInfo.host = value;
-		else if (prop.CmpNoCase(wxT("port")) == 0)
-		{
-			if (!value.ToULong(&cnInfo.port))
-				// port must be an unsigned integer
-				return cnInfo;
-		}
-		else if (prop.CmpNoCase(wxT("password")) == 0)
-			cnInfo.password = value;
-		else if (prop.CmpNoCase(wxT("connection_timeout")) == 0)
+		switch (opt->dispchar[0])
 		{
-			if (!value.ToULong(&cnInfo.connection_timeout))
-				// connection timeout must be an unsigned interger
-				return cnInfo;
+		case 'D': // debug option
+			continue;
+		case '*': // password field
+			LogMessage(wxString::Format(_("\t%s: *****"), opt->label), LOG_DEBUG);
+			break;
+		default:
+			LogMessage(wxString::Format(_("\t%s: %s"), opt->label, opt->val), LOG_DEBUG);
+			break;
 		}
-		else if (prop.CmpNoCase(wxT("dbname")) == 0)
-			cnInfo.dbname = value;
-		else
-			// Not valid property found
-			return cnInfo;
-	}
 
-	// If user, dbname & host all are blank than we will consider this an invalid connection string
-	if (cnInfo.user.IsEmpty() && cnInfo.dbname.IsEmpty() && cnInfo.host.IsEmpty())
-		cnInfo.isValid = false;
-	else
-		cnInfo.isValid = true;
+		// create plain keyword=value connection string.  used
+		// to find pooled connections in DBconn::Get() and to
+		// open the connection in DBconn::Connect. this works
+		// because PQconninfoParse() always returns the
+		// connection info options in the same order.
+		if (!connectionString.IsEmpty())
+			connectionString.Append(' ');
+		connectionString.Append(opt->keyword);
+		connectionString.Append('=');
+		connectionString.Append(opt->val);
+	}
 
-	return cnInfo;
+	PQconninfoFree(opts);
 }
 
+
diff --git a/include/connection.h b/include/connection.h
index b2693f7..008a8da 100644
--- a/include/connection.h
+++ b/include/connection.h
@@ -16,12 +16,14 @@
 #include <libpq-fe.h>
 
 class DBresult;
+class connInfo;
 
 
 class DBconn
 {
 protected:
-	DBconn(const wxString &, const wxString &);
+	DBconn(const wxString &);
+	DBconn(const connInfo &);
 	~DBconn();
 
 public:
@@ -29,17 +31,13 @@ public:
 
 	bool BackendMinimumVersion(int major, int minor);
 
-	static DBconn *Get(const wxString &connStr, const wxString &db);
+	static DBconn *Get(const connInfo &cnInfo, const wxString &db);
 	static DBconn *InitConnection(const wxString &connectString);
 
 	static void ClearConnections(bool allIncludingPrimary = false);
-	static void SetBasicConnectString(const wxString &bcs)
+	static const connInfo *GetBasicConnInfo()
 	{
-		basicConnectString = bcs;
-	}
-	static const wxString &GetBasicConnectString()
-	{
-		return basicConnectString;
+		return basicConnInfo;
 	}
 
 	wxString GetLastError();
@@ -91,12 +89,12 @@ public:
 	void Return();
 
 private:
-	bool Connect(const wxString &connectString);
+	bool Connect(const connInfo &cnInfo);
 
 	int minorVersion, majorVersion;
 
 protected:
-	static wxString basicConnectString;
+	static connInfo *basicConnInfo;
 	static DBconn *primaryConn;
 
 	wxString dbname, lastError, connStr;
@@ -149,27 +147,20 @@ protected:
 class connInfo
 {
 public:
-	connInfo()
-	{
-		isValid = false;
-		connection_timeout = 0;
-		port = 0;
-	}
+	connInfo(const wxString &connStr);
+
+	wxString getConnectionString() const;
+	wxString getError() const;
+	bool IsValid() const;
 
 private:
-	wxString      user;
-	unsigned long port;
-	wxString      host;
-	wxString      dbname;
-	unsigned long connection_timeout;
-	wxString      password;
+	wxString      connectionString;
+	wxString      error;
 	bool          isValid;
 
-	wxString getConnectionString();
-	static connInfo getConnectionInfo(wxString connStr);
+	void Init(const wxString &connStr);
 
 protected:
-	bool IsValidIP();
 	friend class DBconn;
 };
 
diff --git a/job.cpp b/job.cpp
index 2a94dd7..6d7757d 100644
--- a/job.cpp
+++ b/job.cpp
@@ -134,7 +134,16 @@ int Job::Execute()
 				wxString jstdbname = steps->GetString(wxT("jstdbname"));
 				wxString jstconnstr = steps->GetString(wxT("jstconnstr"));
 
-				stepConn = DBconn::Get(jstconnstr, jstdbname);
+				connInfo cnInfo(jstconnstr);
+				if (!cnInfo.IsValid())
+				{
+					LogMessage(wxString::Format(_("Connection string '%s' in SQL step %s (part of job %s) is invalid: %s"), jstconnstr, stepid.c_str(), jobid.c_str(), cnInfo.getError()), LOG_WARNING);
+					output = wxString::Format(_("Connection string is invalid!"));
+					succeeded = false;
+					break;
+				}
+
+				stepConn = DBconn::Get(cnInfo, jstdbname);
 				if (stepConn)
 				{
 					LogMessage(wxString::Format(_("Executing SQL step %s (part of job %s)"), stepid.c_str(), jobid.c_str()), LOG_DEBUG);
@@ -398,7 +407,8 @@ JobThread::JobThread(const wxString &jid)
 	runnable = false;
 	jobid = jid;
 
-	DBconn *threadConn = DBconn::Get(DBconn::GetBasicConnectString(), serviceDBname);
+	const connInfo *basicConnInfo = DBconn::GetBasicConnInfo();
+	DBconn *threadConn = DBconn::Get(*basicConnInfo, serviceDBname);
 	if (threadConn)
 	{
 		job = new Job(threadConn, jobid);
-- 
2.14.1

