From 18ce8dc53b7ae8da0a7e181caf4b2a3d3e82b49c Mon Sep 17 00:00:00 2001
From: Dmitrii Dolgov <9erthalion6@gmail.com>
Date: Wed, 8 Dec 2021 09:57:37 +0100
Subject: [PATCH] Prevent nested UNIONs in CTE with CYCLE clause

The SEARCH and CYCLE implementation in 3696a600e2 mentions that nested
UNIONs are not supported, and verifies that in parse_cte to avoid
queries like "foo UNION bar UNION baz". A similar situation, where the
recursive part has a subquery with a UNION inside, is not handled and
leads to a crash. Add more validation to turn the crash into a
FEATURE_NOT_SUPPORTED error.
---
 src/backend/parser/parse_cte.c     | 18 ++++++++++++++++
 src/test/regress/expected/with.out | 33 ++++++++++++++++++++++++++++++
 src/test/regress/sql/with.sql      | 29 ++++++++++++++++++++++++++
 3 files changed, 80 insertions(+)

diff --git a/src/backend/parser/parse_cte.c b/src/backend/parser/parse_cte.c
index 2f51caf76c..0daa2a0d20 100644
--- a/src/backend/parser/parse_cte.c
+++ b/src/backend/parser/parse_cte.c
@@ -80,6 +80,7 @@ typedef struct CteState
 	/* working state for checkWellFormedRecursion walk only: */
 	int			selfrefcount;	/* number of self-references detected */
 	RecursionContext context;	/* context to allow or disallow self-ref */
+	bool 		union_op; 		/* CTE contains a UNION */
 } CteState;
 
 
@@ -852,6 +853,8 @@ checkWellFormedRecursion(CteState *cstate)
 							cte->ctename),
 					 parser_errposition(cstate->pstate, cte->location)));
 
+		cstate->union_op = true;
+
 		/* The left-hand operand mustn't contain self-reference at all */
 		cstate->curitem = i;
 		cstate->innerwiths = NIL;
@@ -1116,7 +1119,22 @@ checkWellFormedSelectStmt(SelectStmt *stmt, CteState *cstate)
 		switch (stmt->op)
 		{
 			case SETOP_NONE:
+				raw_expression_tree_walker((Node *) stmt,
+										   checkWellFormedRecursionWalker,
+										   (void *) cstate);
+				break;
 			case SETOP_UNION:
+				if(cstate->union_op && stmt->rarg->withClause == NULL)
+				{
+					CommonTableExpr *cte = cstate->items[cstate->curitem].cte;
+					if (cte->cycle_clause)
+						ereport(ERROR,
+								(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								 errmsg("with a CYCLE clause, the left side of "
+									 	"the UNION must not contain nested UNIONs"),
+						parser_errposition(cstate->pstate, cte->location)));
+				}
+
 				raw_expression_tree_walker((Node *) stmt,
 										   checkWellFormedRecursionWalker,
 										   (void *) cstate);
diff --git a/src/test/regress/expected/with.out b/src/test/regress/expected/with.out
index a3a2e383e3..217c7a8cb7 100644
--- a/src/test/regress/expected/with.out
+++ b/src/test/regress/expected/with.out
@@ -1282,6 +1282,39 @@ select * from search_graph;
 ERROR:  CYCLE types boolean and integer cannot be matched
 LINE 7: ) cycle f, t set is_cycle to true default 55 using path
                                                   ^
+with recursive x ( x ) as
+  ( select 1
+    union all
+    select x from
+    (
+      select 4 as x
+      union all
+      select x from x
+    ) as x
+  )
+cycle x set b using v
+select * from x;
+ERROR:  with a CYCLE clause, the left side of the UNION must not contain nested UNIONs
+LINE 1: with recursive x ( x ) as
+                       ^
+with recursive x ( x ) as
+  ( select 1
+    union all
+    select x from
+    (
+      select x from
+      (
+        select 4 as x
+        union all
+        select x from x
+      ) as x
+    ) as x
+  )
+cycle x set b using v
+select * from x;
+ERROR:  with a CYCLE clause, the left side of the UNION must not contain nested UNIONs
+LINE 1: with recursive x ( x ) as
+                       ^
 with recursive search_graph(f, t, label) as (
 	select * from graph g
 	union all
diff --git a/src/test/regress/sql/with.sql b/src/test/regress/sql/with.sql
index 46668a903e..9a0c9a6541 100644
--- a/src/test/regress/sql/with.sql
+++ b/src/test/regress/sql/with.sql
@@ -618,6 +618,35 @@ with recursive search_graph(f, t, label) as (
 ) cycle f, t set is_cycle to true default 55 using path
 select * from search_graph;
 
+with recursive x ( x ) as
+  ( select 1
+    union all
+    select x from
+    (
+      select 4 as x
+      union all
+      select x from x
+    ) as x
+  )
+cycle x set b using v
+select * from x;
+
+with recursive x ( x ) as
+  ( select 1
+    union all
+    select x from
+    (
+      select x from
+      (
+        select 4 as x
+        union all
+        select x from x
+      ) as x
+    ) as x
+  )
+cycle x set b using v
+select * from x;
+
 with recursive search_graph(f, t, label) as (
 	select * from graph g
 	union all
-- 
2.26.3

