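/*
 * PL/pgSQL SQL checker plugin.  Registered as a PL/pgSQL plugin, it walks the
 * statement tree of each function when the function starts executing and
 * prepares a plan for every embedded SQL expression, so that invalid SQL is
 * reported up front instead of when the statement is first reached.
 *
 * NOTE: the version tag below was copied from contrib/citext and does not
 * identify this file.
 * $PostgreSQL: pgsql/contrib/citext/citext.c,v 1.2 2009/06/11 14:48:50 momjian Exp $
 */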
#include "postgres.h"

#include "executor/spi_priv.h"
#include "utils/plancache.h"

#include "plpgsql.h"

/*
 * Module magic block required for loadable modules; guarded so the file still
 * compiles against server headers that do not define the macro.
 */
#ifdef PG_MODULE_MAGIC
PG_MODULE_MAGIC;
#endif

/*
 * State carried through the statement-tree walk.  esql_checker fills it in and
 * passes its address as the walker's context argument.
 */
typedef struct
{
	PLpgSQL_execstate *estate;	/* execution state of the function being checked */
	PLpgSQL_function *func;		/* the function itself; not read by the visible code */
} runtime_context;

/* Forward declarations of file-local functions. */
static void checker_func_beg( PLpgSQL_execstate * estate, PLpgSQL_function * func );
static bool plpgsql_stmt_tree_walker(PLpgSQL_stmt *stmt, bool (*walker)(), void *context);
static void esql_checker(PLpgSQL_execstate * estate, PLpgSQL_function * func);
static void exec_prepare_plan(PLpgSQL_execstate *estate, PLpgSQL_expr *expr, int cursorOptions);

/*
 * Callback table installed by _PG_init.  Only the second slot is filled
 * (checker_func_beg); all other slots are NULL.
 * NOTE(review): positional initializer -- assumes the second member of
 * PLpgSQL_plugin is the function-begin hook; confirm against the plpgsql.h of
 * the targeted PostgreSQL version.
 */
static PLpgSQL_plugin plugin_funcs = { NULL, checker_func_beg, NULL, NULL, NULL};

/*
 * _PG_init -- module load hook
 *
 * Looks up the "PLpgSQL_plugin" rendezvous variable and stores a pointer to
 * this module's plugin_funcs table in it.  Any pointer previously stored there
 * is overwritten, not chained.
 *
 * The explicit prototype avoids a -Wmissing-prototypes warning for this
 * externally visible function.
 */
void		_PG_init(void);

void
_PG_init(void)
{
	PLpgSQL_plugin **plugin_ptr;

	plugin_ptr = (PLpgSQL_plugin **) find_rendezvous_variable("PLpgSQL_plugin");
	*plugin_ptr = &plugin_funcs;
}

/*
 * checker_func_beg -- function-begin callback stored in plugin_funcs
 *
 * Forwards the execution state and the function unchanged to esql_checker.
 */
static void
checker_func_beg(PLpgSQL_execstate * estate, PLpgSQL_function * func)
{
	PLpgSQL_execstate *state = estate;

	esql_checker(state, func);
}

/*
 * plpgsql_stmts_walker -- apply walker to every statement in a List
 *
 * Calls walker(stmt, context) for each element of stmts, in list order, and
 * stops at the first call that returns true.  An empty list makes no calls.
 *
 * Returns true if some walker call returned true, otherwise false.
 */
static bool
plpgsql_stmts_walker(List *stmts, bool (*walker)(), void *context)
{
	ListCell   *cell;

	foreach(cell, stmts)
	{
		PLpgSQL_stmt *child = (PLpgSQL_stmt *) lfirst(cell);

		if (!walker(child, context))
			continue;
		return true;
	}
	return false;
}


/*
 * plpgsql_stmt_tree_walker -- visit the statement lists nested in one statement
 *
 * For statements that contain bodies (blocks and their exception handlers,
 * IF, CASE, and the loop forms), each nested list is handed to
 * plpgsql_stmts_walker with the given walker and context.  The walk stops as
 * soon as a walker call returns true.  Statements without nested bodies make
 * no calls.  A NULL stmt is accepted; an unknown cmd_type raises
 * elog(ERROR).
 *
 * Returns true if any walker call returned true, otherwise false.
 */
static bool
plpgsql_stmt_tree_walker(PLpgSQL_stmt *stmt, bool (*walker)(), void *context)
{
	ListCell   *cell;

	if (stmt == NULL)
		return false;

	switch ((enum PLpgSQL_stmt_types) stmt->cmd_type)
	{
		case PLPGSQL_STMT_BLOCK:
			{
				PLpgSQL_stmt_block *blk = (PLpgSQL_stmt_block *) stmt;

				if (plpgsql_stmts_walker(blk->body, walker, context))
					return true;
				if (blk->exceptions == NULL)
					return false;
				foreach(cell, blk->exceptions->exc_list)
				{
					PLpgSQL_exception *handler = (PLpgSQL_exception *) lfirst(cell);

					if (plpgsql_stmts_walker(handler->action, walker, context))
						return true;
				}
				return false;
			}

		case PLPGSQL_STMT_IF:
			{
				PLpgSQL_stmt_if *ifs = (PLpgSQL_stmt_if *) stmt;

				/* short-circuit: the ELSE part is skipped once THEN says stop */
				return plpgsql_stmts_walker(ifs->true_body, walker, context) ||
					plpgsql_stmts_walker(ifs->false_body, walker, context);
			}

		case PLPGSQL_STMT_CASE:
			{
				PLpgSQL_stmt_case *cs = (PLpgSQL_stmt_case *) stmt;

				foreach(cell, cs->case_when_list)
				{
					PLpgSQL_case_when *when = (PLpgSQL_case_when *) lfirst(cell);

					if (plpgsql_stmts_walker(when->stmts, walker, context))
						return true;
				}
				return plpgsql_stmts_walker(cs->else_stmts, walker, context);
			}

		case PLPGSQL_STMT_LOOP:
			return plpgsql_stmts_walker(((PLpgSQL_stmt_loop *) stmt)->body,
										walker, context);

		case PLPGSQL_STMT_WHILE:
			return plpgsql_stmts_walker(((PLpgSQL_stmt_while *) stmt)->body,
										walker, context);

		case PLPGSQL_STMT_FORI:
			return plpgsql_stmts_walker(((PLpgSQL_stmt_fori *) stmt)->body,
										walker, context);

		case PLPGSQL_STMT_FORS:
		case PLPGSQL_STMT_FORC:
		case PLPGSQL_STMT_DYNFORS:
			/* all query-driven FOR loops share the forq body layout */
			return plpgsql_stmts_walker(((PLpgSQL_stmt_forq *) stmt)->body,
										walker, context);

		case PLPGSQL_STMT_ASSIGN:
		case PLPGSQL_STMT_EXIT:
		case PLPGSQL_STMT_RETURN:
		case PLPGSQL_STMT_RETURN_NEXT:
		case PLPGSQL_STMT_RETURN_QUERY:
		case PLPGSQL_STMT_RAISE:
		case PLPGSQL_STMT_EXECSQL:
		case PLPGSQL_STMT_DYNEXECUTE:
		case PLPGSQL_STMT_GETDIAG:
		case PLPGSQL_STMT_OPEN:
		case PLPGSQL_STMT_FETCH:
		case PLPGSQL_STMT_CLOSE:
		case PLPGSQL_STMT_PERFORM:
			/* leaf statements: nothing nested to visit */
			return false;

		default:
			elog(ERROR, "unrecognized cmd_type: %d", stmt->cmd_type);
	}

	return false;
}

/*
 * check_expr -- validate one embedded SQL expression by preparing its plan
 *
 * If expr has no plan yet, exec_prepare_plan() builds and saves one.  Any
 * error is raised by ereport/elog(ERROR) inside it.  While it runs,
 * estate->err_stmt points at stmt and err_text is cleared, so an error report
 * identifies the offending statement.  Both fields are restored afterwards so
 * the walk leaves the execution state as it found it.
 *
 * stmt          statement that owns the expression (error context only)
 * expr          expression to check; NULL is accepted and ignored
 * cursorOptions cursor options passed through to exec_prepare_plan
 * context       walker state supplying the execution state
 *
 * Always returns false ("continue walking"); problems are reported by
 * raising errors, not through the return value.
 */
static bool
check_expr(PLpgSQL_stmt *stmt, PLpgSQL_expr *expr, int cursorOptions, runtime_context *context)
{
	PLpgSQL_execstate *estate = context->estate;
	const char *save_err_text;
	PLpgSQL_stmt *save_err_stmt;

	/* nothing to do for absent expressions or ones already planned */
	if (expr == NULL || expr->plan != NULL)
		return false;

	save_err_text = estate->err_text;
	save_err_stmt = estate->err_stmt;

	estate->err_text = NULL;
	estate->err_stmt = stmt;

	exec_prepare_plan(estate, expr, cursorOptions);

	/* previously only err_text was restored; err_stmt leaked the last stmt */
	estate->err_text = save_err_text;
	estate->err_stmt = save_err_stmt;

	return false;
}

static bool
esql_checker_walker(PLpgSQL_stmt *stmt,
							runtime_context *context)
{
	ListCell *l;

	if (stmt == NULL)
		return false;

	switch ((enum PLpgSQL_stmt_types) stmt->cmd_type)
	{
		case PLPGSQL_STMT_BLOCK:
			{
				/* check a initialization of local variables */
			}
			break;
		case PLPGSQL_STMT_ASSIGN:
			if (check_expr(stmt, ((PLpgSQL_stmt_assign *) stmt)->expr, 0, context))
				return true;
			break;
		case PLPGSQL_STMT_IF:
			if (check_expr(stmt, ((PLpgSQL_stmt_if *) stmt)->cond, 0, context))
				return true;
			break;
		case PLPGSQL_STMT_CASE:
			{
				PLpgSQL_stmt_case *stmt_case = (PLpgSQL_stmt_case *) stmt;

				if (stmt_case->t_expr != NULL && 
							    check_expr(stmt, stmt_case->t_expr, 0, context))
					return true;
				foreach(l, stmt_case->case_when_list)
				{
					PLpgSQL_case_when *cwt = (PLpgSQL_case_when *) lfirst(l);

					if (check_expr(stmt, cwt->expr, 0, context))
						return true;
				}
			}
			break;
		case PLPGSQL_STMT_LOOP:
		case PLPGSQL_STMT_WHILE:
			break;
		case PLPGSQL_STMT_FORI:
			{
				PLpgSQL_stmt_fori *stmt_fori = (PLpgSQL_stmt_fori *) stmt;

				if (check_expr(stmt, stmt_fori->lower, 0, context))
					return true;
				if (check_expr(stmt, stmt_fori->upper, 0, context))
					return true;
				if (stmt_fori->step != NULL && check_expr(stmt, stmt_fori->step, 0, context))
					return true;
			}
			break;
		case PLPGSQL_STMT_FORS:
			if (check_expr(stmt, ((PLpgSQL_stmt_fors *) stmt)->query, 0, context))
				return true;
			break;
		case PLPGSQL_STMT_FORC:
			{
				PLpgSQL_stmt_forc *stmt_forc = (PLpgSQL_stmt_forc *) stmt;
				PLpgSQL_var *curvar;
				curvar = (PLpgSQL_var *) (context->estate->datums[stmt_forc->curvar]);

				if (stmt_forc->argquery != NULL &&
						check_expr(stmt, stmt_forc->argquery, 
								    curvar->cursor_options, context))
					return true;
			}
			break;
		case PLPGSQL_STMT_OPEN:
			{
				PLpgSQL_stmt_open *stmt_open = (PLpgSQL_stmt_open *) stmt;
				PLpgSQL_var *curvar;

				curvar = (PLpgSQL_var *) (context->estate->datums[stmt_open->curvar]);

				if (stmt_open->query != NULL &&
						check_expr(stmt, stmt_open->query,
									    stmt_open->cursor_options, context))
				{
					return true;
				}
				else if (stmt_open->dynquery != NULL)
				{
					if (check_expr(stmt, stmt_open->dynquery, 0, context))
						return true;
					foreach(l, stmt_open->params)
					{
						if (check_expr(stmt, (PLpgSQL_expr *) lfirst(l), 0, context))
							return true;
					}
				}
				else
				{
					PLpgSQL_var *curvar;

					curvar = (PLpgSQL_var *) (context->estate->datums[stmt_open->curvar]);
					if (stmt_open->argquery != NULL &&
							check_expr(stmt, stmt_open->argquery, 0, context))
						return true;
					if (curvar->cursor_explicit_expr != NULL &&
							check_expr(stmt, curvar->cursor_explicit_expr,
										    curvar->cursor_options, context))
						return true;
				}
			}
			break;
		case PLPGSQL_STMT_FETCH:
		case PLPGSQL_STMT_CLOSE:
			break;
		case PLPGSQL_STMT_PERFORM:
			if (check_expr(stmt, ((PLpgSQL_stmt_perform *) stmt)->expr, 0, context))
				return true;
			break;
		case PLPGSQL_STMT_EXIT:
			{
				PLpgSQL_stmt_exit *stmt_exit = (PLpgSQL_stmt_exit *) stmt;

				if (stmt_exit->cond != NULL && check_expr(stmt, stmt_exit->cond, 0, context))
					return true;
			}
			break;
		case PLPGSQL_STMT_RETURN:
			{
				PLpgSQL_stmt_return *stmt_return = (PLpgSQL_stmt_return *) stmt;

				if (stmt_return->expr != NULL && 
						check_expr(stmt, stmt_return->expr, 0, context))
					return true;
			}
			break;
		case PLPGSQL_STMT_RETURN_NEXT:
			{
				PLpgSQL_stmt_return_next *stmt_rn = (PLpgSQL_stmt_return_next *) stmt;

				if (stmt_rn->expr != NULL)
					return check_expr(stmt, stmt_rn->expr, 0, context);
			}
			break;
		case PLPGSQL_STMT_RETURN_QUERY:
			{
				PLpgSQL_stmt_return_query *stmt_rq = (PLpgSQL_stmt_return_query *) stmt;

				if (stmt_rq->query != NULL)
				{
					if (check_expr(stmt, stmt_rq->query, 0, context))
						return true;
				}
				else
				{
					if (check_expr(stmt, stmt_rq->dynquery, 0, context))
						return true;
					foreach(l, stmt_rq->params)
					{
						if (check_expr(stmt, (PLpgSQL_expr *) lfirst(l), 0, context))
							return true;
					}
				}
			}
			break;
		case PLPGSQL_STMT_RAISE:
			{
				PLpgSQL_stmt_raise *stmt_raise = (PLpgSQL_stmt_raise *) stmt;

				foreach(l, stmt_raise->params)
				{
					if (check_expr(stmt, (PLpgSQL_expr *) lfirst(l), 0, context))
						return true;
				}
				foreach(l, stmt_raise->options)
				{
					if (check_expr(stmt, (PLpgSQL_expr *) lfirst(l), 0, context))
						return true;
				}
			}
			break;
		case PLPGSQL_STMT_EXECSQL:
			if (check_expr(stmt, ((PLpgSQL_stmt_execsql *) stmt)->sqlstmt, 0, context))
				return true;
			break;
		case PLPGSQL_STMT_DYNEXECUTE:
			{
				PLpgSQL_stmt_dynexecute *stmt_dynexecute = (PLpgSQL_stmt_dynexecute *) stmt;

				if (check_expr(stmt, stmt_dynexecute->query, 0, context))
					return true;
				foreach(l, stmt_dynexecute->params)
				{
					if (check_expr(stmt, (PLpgSQL_expr *) lfirst(l), 0, context))
						return true;
				}
			}
			break;
		case PLPGSQL_STMT_DYNFORS:
			{
				PLpgSQL_stmt_dynfors * stmt_dynfors = (PLpgSQL_stmt_dynfors *) stmt;

				if (check_expr(stmt, stmt_dynfors->query, 0, context))
					return true;
				foreach(l, stmt_dynfors->params)
				{
					if (check_expr(stmt, (PLpgSQL_expr *) lfirst(l), 0, context))
						return true;
				}
			}
			break;
		case PLPGSQL_STMT_GETDIAG:
			break;
	}

	return plpgsql_stmt_tree_walker(stmt, esql_checker_walker, context);
}

/*
 * esql_checker -- check every embedded SQL expression of a function
 *
 * Builds the walker context from estate and func, then walks the function's
 * top-level block (func->action) with esql_checker_walker.  Errors in the
 * checked SQL surface as raised errors; the walker's boolean result is not
 * used.
 */
static void
esql_checker(PLpgSQL_execstate * estate, PLpgSQL_function * func)
{
	runtime_context ctx;
	PLpgSQL_stmt *top = (PLpgSQL_stmt *) func->action;

	ctx.func = func;
	ctx.estate = estate;

	(void) esql_checker_walker(top, &ctx);
}

/*
 * The following code is copied from PL/pgSQL's pl_exec.c, because the
 * original functions are static there and cannot be called from a plugin.
 */

/* ----------
 * Generate a prepared plan
 *
 * Prepares expr->query with SPI_prepare_params, passing plpgsql_parser_setup
 * as the parser setup hook and expr as its argument.  On success the plan is
 * saved into expr->plan and the temporary plan is freed.  On failure an
 * error is raised with ereport/elog(ERROR), so the function does not return
 * normally.
 *
 * cursorOptions is passed unchanged to SPI_prepare_params.
 *
 * NOTE(review): the pl_exec.c original also runs a simple-expression check
 * (exec_simple_check_plan, static there) after saving the plan.  That step is
 * not reproduced here, so expressions planned by this checker may miss
 * PL/pgSQL's simple-expression fast path -- confirm against the targeted
 * PostgreSQL version.
 * ----------
 */
static void
exec_prepare_plan(PLpgSQL_execstate *estate,
				  PLpgSQL_expr *expr, int cursorOptions)
{
	SPIPlanPtr	plan;

	/*
	 * The grammar can't conveniently set expr->func while building the parse
	 * tree, so make sure it's set before parser hooks need it.
	 */
	expr->func = estate->func;

	/*
	 * Generate and save the plan
	 */
	plan = SPI_prepare_params(expr->query,
							  (ParserSetupHook) plpgsql_parser_setup,
							  (void *) expr,
							  cursorOptions);
	if (plan == NULL)
	{
		/* Some SPI errors deserve specific error messages */
		/* each ereport/elog(ERROR) below does not return, so no breaks */
		switch (SPI_result)
		{
			case SPI_ERROR_COPY:
				ereport(ERROR,
						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						 errmsg("cannot COPY to/from client in PL/pgSQL")));
			case SPI_ERROR_TRANSACTION:
				ereport(ERROR,
						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						 errmsg("cannot begin/end transactions in PL/pgSQL"),
						 errhint("Use a BEGIN block with an EXCEPTION clause instead.")));
			default:
				elog(ERROR, "SPI_prepare_params failed for \"%s\": %s",
					 expr->query, SPI_result_code_string(SPI_result));
		}
	}
	/* keep a saved copy of the plan in expr, then release the temporary one */
	expr->plan = SPI_saveplan(plan);
	SPI_freeplan(plan);
}
