*** a/src/pl/plpgsql/src/pl_gram.y --- b/src/pl/plpgsql/src/pl_gram.y *************** *** 22,27 **** --- 22,28 ---- #include "parser/scanner.h" #include "parser/scansup.h" #include "utils/builtins.h" + #include "nodes/nodefuncs.h" /* Location tracking support --- simpler than bison's default */ *************** *** 97,103 **** static PLpgSQL_row *make_scalar_list1(char *initial_name, PLpgSQL_datum *initial_datum, int lineno, int location); static void check_sql_expr(const char *stmt, int location, ! int leaderlen); static void plpgsql_sql_error_callback(void *arg); static PLpgSQL_type *parse_datatype(const char *string, int location); static void check_labels(const char *start_label, --- 98,104 ---- PLpgSQL_datum *initial_datum, int lineno, int location); static void check_sql_expr(const char *stmt, int location, ! int leaderlen, PLpgSQL_row *check_row); static void plpgsql_sql_error_callback(void *arg); static PLpgSQL_type *parse_datatype(const char *string, int location); static void check_labels(const char *start_label, *************** *** 1408,1414 **** for_control : for_variable K_IN PLpgSQL_stmt_fori *new; /* Check first expression is well-formed */ ! check_sql_expr(expr1->query, expr1loc, 7); /* Read and check the second one */ expr2 = read_sql_expression2(K_LOOP, K_BY, --- 1409,1415 ---- PLpgSQL_stmt_fori *new; /* Check first expression is well-formed */ ! check_sql_expr(expr1->query, expr1loc, 7, NULL); /* Read and check the second one */ expr2 = read_sql_expression2(K_LOOP, K_BY, *************** *** 1470,1476 **** for_control : for_variable K_IN pfree(expr1->query); expr1->query = tmp_query; ! check_sql_expr(expr1->query, expr1loc, 0); new = palloc0(sizeof(PLpgSQL_stmt_fors)); new->cmd_type = PLPGSQL_STMT_FORS; --- 1471,1477 ---- pfree(expr1->query); expr1->query = tmp_query; ! check_sql_expr(expr1->query, expr1loc, 0, NULL); new = palloc0(sizeof(PLpgSQL_stmt_fors)); new->cmd_type = PLPGSQL_STMT_FORS; *************** *** 2562,2568 **** read_sql_construct(int until, pfree(ds.data); if (valid_sql) ! check_sql_expr(expr->query, startlocation, strlen(sqlstart)); return expr; } --- 2563,2569 ---- pfree(ds.data); if (valid_sql) ! check_sql_expr(expr->query, startlocation, strlen(sqlstart), NULL); return expr; } *************** *** 2785,2791 **** make_execsql_stmt(int firsttoken, int location) expr->ns = plpgsql_ns_top(); pfree(ds.data); ! check_sql_expr(expr->query, location, 0); execsql = palloc(sizeof(PLpgSQL_stmt_execsql)); execsql->cmd_type = PLPGSQL_STMT_EXECSQL; --- 2786,2792 ---- expr->ns = plpgsql_ns_top(); pfree(ds.data); ! check_sql_expr(expr->query, location, 0, have_into ? row : NULL); execsql = palloc(sizeof(PLpgSQL_stmt_execsql)); execsql->cmd_type = PLPGSQL_STMT_EXECSQL; *************** *** 3357,3362 **** make_scalar_list1(char *initial_name, --- 3358,3366 ---- return row; } + bool + targetlist_magic_count(Node *stmt, int *count); + /* * When the PL/pgSQL parser expects to see a SQL statement, it is very * liberal in what it accepts; for example, we often assume an *************** *** 3381,3391 **** make_scalar_list1(char *initial_name, * If no error cursor is provided, we'll just point at "location". */ static void ! check_sql_expr(const char *stmt, int location, int leaderlen) { sql_error_callback_arg cbarg; ErrorContextCallback syntax_errcontext; MemoryContext oldCxt; if (!plpgsql_check_syntax) return; --- 3385,3396 ---- * If no error cursor is provided, we'll just point at "location". */ static void ! check_sql_expr(const char *stmt, int location, int leaderlen, PLpgSQL_row *check_row) { sql_error_callback_arg cbarg; ErrorContextCallback syntax_errcontext; MemoryContext oldCxt; + List *raw_parsetree_list; if (!plpgsql_check_syntax) return; *************** *** 3399,3411 **** check_sql_expr(const char *stmt, int location, int leaderlen) error_context_stack = &syntax_errcontext; oldCxt = MemoryContextSwitchTo(compile_tmp_cxt); ! (void) raw_parser(stmt); MemoryContextSwitchTo(oldCxt); /* Restore former ereport callback */ error_context_stack = syntax_errcontext.previous; } static void plpgsql_sql_error_callback(void *arg) { --- 3404,3487 ---- error_context_stack = &syntax_errcontext; oldCxt = MemoryContextSwitchTo(compile_tmp_cxt); ! raw_parsetree_list = raw_parser(stmt); ! if (check_row != NULL) ! { ! Node *raw_parse_tree; ! int ncols; ! int fnum; ! int expected_ncols = 0; ! ! for (fnum = 0; fnum < check_row->nfields; fnum++) ! { ! if (check_row->varnos[fnum] < 0) ! continue; ! expected_ncols++; ! } ! ! raw_parse_tree = linitial(raw_parsetree_list); ! if (targetlist_magic_count(raw_parse_tree, &ncols) && ! ncols != expected_ncols) ! elog(ERROR, "expected %d, got %d", expected_ncols, ncols); ! } MemoryContextSwitchTo(oldCxt); /* Restore former ereport callback */ error_context_stack = syntax_errcontext.previous; } + static bool + find_a_star_walker(Node *node, void *context) + { + if (node == NULL) + return false; + if (IsA(node, A_Star)) + return true; + if (IsA(node, ColumnRef)) + { + ColumnRef *ref = (ColumnRef *) node; + /* A_Star can only be the last element */ + if (IsA(llast(ref->fields), A_Star)) + return true; + } + return raw_expression_tree_walker((Node *) node, + find_a_star_walker, + context); + } + + /* + * Find the number of columns in a raw statement's targetList (if SELECT) or + * returningList (if INSERT, UPDATE or DELETE). Returns false if the number of + * columns could not be determined because of an A_Star. + */ + bool + targetlist_magic_count(Node *stmt, int *count) + { + List *tlist; + + if (IsA(stmt, SelectStmt)) + tlist = ((SelectStmt *) stmt)->targetList; + else if (IsA(stmt, InsertStmt)) + tlist = ((InsertStmt *) stmt)->returningList; + else if (IsA(stmt, UpdateStmt)) + tlist = ((UpdateStmt *) stmt)->returningList; + else if (IsA(stmt, DeleteStmt)) + tlist = ((DeleteStmt *) stmt)->returningList; + else + elog(ERROR, "unknown nodeTag %d", nodeTag(stmt)); + + if (!tlist) + { + *count = 0; + return true; + } + + if (raw_expression_tree_walker((Node *) tlist, find_a_star_walker, NULL)) + return false; + *count = list_length(tlist); + return true; + } + static void plpgsql_sql_error_callback(void *arg) {