diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 759566a..3b9715b 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -100,6 +100,7 @@ typedef struct bool allow_restricted; } has_parallel_hazard_arg; +static List *get_opargs(const Expr *clause); static bool aggregates_allow_partial_walker(Node *node, partial_agg_context *context); static bool contain_agg_clause_walker(Node *node, void *context); @@ -197,6 +198,19 @@ make_opclause(Oid opno, Oid opresulttype, bool opretset, return (Expr *) expr; } +static List * +get_opargs(const Expr *clause) +{ + if (IsA(clause, OpExpr)) + return ((OpExpr *) clause)->args; + + if (IsA(clause, ScalarArrayOpExpr)) + return ((ScalarArrayOpExpr *) clause)->args; + + elog(ERROR, "unrecognized node type: %d", + (int) nodeTag(clause)); +} + /* * get_leftop * @@ -206,10 +220,10 @@ make_opclause(Oid opno, Oid opresulttype, bool opretset, Node * get_leftop(const Expr *clause) { - const OpExpr *expr = (const OpExpr *) clause; + const List *args = get_opargs(clause); - if (expr->args != NIL) - return linitial(expr->args); + if (args != NIL) + return linitial(args); else return NULL; } @@ -223,10 +237,10 @@ get_leftop(const Expr *clause) Node * get_rightop(const Expr *clause) { - const OpExpr *expr = (const OpExpr *) clause; + const List *args = get_opargs(clause); - if (list_length(expr->args) >= 2) - return lsecond(expr->args); + if (list_length(args) >= 2) + return lsecond(args); else return NULL; }