From a77855103cc40b4f5d1669006c8633fda53c37db Mon Sep 17 00:00:00 2001
From: Alexander Pyhalov <a.pyhalov@postgrespro.ru>
Date: Tue, 26 Dec 2023 10:07:13 +0300
Subject: [PATCH 4/6] Compare converted whole row vars in
 search_indexed_tlist_for_non_var() correctly and fix test results

---
 src/backend/nodes/nodeFuncs.c                 | 49 +++++++++++++
 src/backend/optimizer/plan/setrefs.c          |  2 +-
 src/backend/optimizer/util/tlist.c            | 23 ++++++
 src/include/nodes/nodeFuncs.h                 |  1 +
 src/include/optimizer/tlist.h                 |  1 +
 .../regress/expected/partition_aggregate.out  | 39 ++++++----
 src/test/regress/expected/partition_join.out  | 72 +++++++++++--------
 src/test/regress/sql/partition_aggregate.sql  |  2 +-
 src/test/regress/sql/partition_join.sql       |  3 +-
 9 files changed, 144 insertions(+), 48 deletions(-)

diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c
index 4ce0230aad5..41e2acaf1bf 100644
--- a/src/backend/nodes/nodeFuncs.c
+++ b/src/backend/nodes/nodeFuncs.c
@@ -4825,3 +4825,52 @@ is_converted_whole_row_reference(Node *node)
 
 	return false;
 }
+
+/*
+ * is_equal_converted_whole_row_references
+ *		Determine if both nodes are equivalent ConvertRowtypeExprs
+ *		over the same var.
+ *		It differs from equal(), because we ignore varnullingrels.
+ */
+bool
+is_equal_converted_whole_row_references(Node *node1, Node *node2)
+{
+	ConvertRowtypeExpr *convexpr1;
+	ConvertRowtypeExpr *convexpr2;
+
+	if (!node1 || !IsA(node1, ConvertRowtypeExpr))
+		return false;
+
+	if (!node2 || !IsA(node2, ConvertRowtypeExpr))
+		return false;
+
+	convexpr1 = castNode(ConvertRowtypeExpr, node1);
+	convexpr2 = castNode(ConvertRowtypeExpr, node2);
+
+	while (convexpr1->convertformat == COERCE_IMPLICIT_CAST &&
+		 convexpr2->convertformat == COERCE_IMPLICIT_CAST &&
+		 convexpr1->resulttype == convexpr2->resulttype &&
+		 IsA(convexpr1->arg, ConvertRowtypeExpr) &&
+		 IsA(convexpr2->arg, ConvertRowtypeExpr))
+	{
+		convexpr1 = castNode(ConvertRowtypeExpr, convexpr1->arg);
+		convexpr2 = castNode(ConvertRowtypeExpr, convexpr2->arg);
+	}
+
+	if (IsA(convexpr1->arg, Var) && IsA(convexpr2->arg, Var))
+	{
+		Var		*var1 = castNode(Var, convexpr1->arg);
+		Var		*var2 = castNode(Var, convexpr2->arg);
+
+		if ((var1->varno == var2->varno) &&
+			(var1->varattno == var2->varattno) &&
+			(var1->varlevelsup == var2->varlevelsup) &&
+			(var1->vartype == var2->vartype))
+		{
+			/* TODO: Can we state that both varattnos is 0? */
+			return true;
+		}
+	}
+
+	return false;
+}
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index 09314fac672..82e8872b4a9 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -2924,7 +2924,7 @@ search_indexed_tlist_for_non_var(Expr *node,
 	if (IsA(node, Const))
 		return NULL;
 
-	tle = tlist_member(node, itlist->tlist);
+	tle = tlist_member_match_converted_whole_row(node, itlist->tlist);
 	if (tle)
 	{
 		/* Found a matching subplan output expression */
diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c
index 7ef7f34d8b5..f3157f0f649 100644
--- a/src/backend/optimizer/util/tlist.c
+++ b/src/backend/optimizer/util/tlist.c
@@ -119,6 +119,29 @@ tlist_member_match_var(Var *var, List *targetlist)
 	return NULL;
 }
 
+/*
+ * tlist_member_match_converted_whole_row
+ * 	tlist_member() variant, which compares whole var references
+ * 	based on their varno/varattno
+ */
+TargetEntry *
+tlist_member_match_converted_whole_row(Expr *node, List *targetlist)
+{
+	ListCell   *temp;
+
+	foreach(temp, targetlist)
+	{
+		TargetEntry *tlentry = (TargetEntry *) lfirst(temp);
+
+		if (equal(node, tlentry->expr))
+			return tlentry;
+
+		if (is_equal_converted_whole_row_references((Node *)node, (Node *)tlentry->expr))
+			return tlentry;
+	}
+	return NULL;
+}
+
 /*
  * add_to_flat_tlist
  *		Add more items to a flattened tlist (if they're not already in it)
diff --git a/src/include/nodes/nodeFuncs.h b/src/include/nodes/nodeFuncs.h
index 8d30d6cbac5..cda3fff315d 100644
--- a/src/include/nodes/nodeFuncs.h
+++ b/src/include/nodes/nodeFuncs.h
@@ -220,4 +220,5 @@ extern bool planstate_tree_walker_impl(struct PlanState *planstate,
 									   void *context);
 
 extern bool is_converted_whole_row_reference(Node *node);
+extern bool is_equal_converted_whole_row_references(Node *node1, Node *node2);
 #endif							/* NODEFUNCS_H */
diff --git a/src/include/optimizer/tlist.h b/src/include/optimizer/tlist.h
index 15f8f4a4b00..9ab6adc01b8 100644
--- a/src/include/optimizer/tlist.h
+++ b/src/include/optimizer/tlist.h
@@ -18,6 +18,7 @@
 
 
 extern TargetEntry *tlist_member(Expr *node, List *targetlist);
+extern TargetEntry *tlist_member_match_converted_whole_row(Expr *node, List *targetlist);
 
 extern List *add_to_flat_tlist(List *tlist, List *exprs);
 
diff --git a/src/test/regress/expected/partition_aggregate.out b/src/test/regress/expected/partition_aggregate.out
index 5f2c0cf5786..ba9245e4011 100644
--- a/src/test/regress/expected/partition_aggregate.out
+++ b/src/test/regress/expected/partition_aggregate.out
@@ -453,27 +453,36 @@ SELECT t1.x, sum(t1.y), count(*) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2
  24 |  900 |   100
 (5 rows)
 
--- Check with whole-row reference; partitionwise aggregation does not apply
+-- Check with whole-row reference
 EXPLAIN (COSTS OFF)
 SELECT t1.x, sum(t1.y), count(t1) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2.y GROUP BY t1.x ORDER BY 1, 2, 3;
                          QUERY PLAN                          
 -------------------------------------------------------------
  Sort
    Sort Key: t1.x, (sum(t1.y)), (count(((t1.*)::pagg_tab1)))
-   ->  HashAggregate
-         Group Key: t1.x
-         ->  Hash Join
-               Hash Cond: (t1.x = t2.y)
-               ->  Append
-                     ->  Seq Scan on pagg_tab1_p1 t1_1
-                     ->  Seq Scan on pagg_tab1_p2 t1_2
-                     ->  Seq Scan on pagg_tab1_p3 t1_3
-               ->  Hash
-                     ->  Append
-                           ->  Seq Scan on pagg_tab2_p1 t2_1
-                           ->  Seq Scan on pagg_tab2_p2 t2_2
-                           ->  Seq Scan on pagg_tab2_p3 t2_3
-(15 rows)
+   ->  Append
+         ->  HashAggregate
+               Group Key: t1.x
+               ->  Hash Join
+                     Hash Cond: (t1.x = t2.y)
+                     ->  Seq Scan on pagg_tab1_p1 t1
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab2_p1 t2
+         ->  HashAggregate
+               Group Key: t1_1.x
+               ->  Hash Join
+                     Hash Cond: (t1_1.x = t2_1.y)
+                     ->  Seq Scan on pagg_tab1_p2 t1_1
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab2_p2 t2_1
+         ->  HashAggregate
+               Group Key: t1_2.x
+               ->  Hash Join
+                     Hash Cond: (t2_2.y = t1_2.x)
+                     ->  Seq Scan on pagg_tab2_p3 t2_2
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab1_p3 t1_2
+(24 rows)
 
 SELECT t1.x, sum(t1.y), count(t1) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2.y GROUP BY t1.x ORDER BY 1, 2, 3;
  x  | sum  | count 
diff --git a/src/test/regress/expected/partition_join.out b/src/test/regress/expected/partition_join.out
index 53591a4f2d5..c4500290694 100644
--- a/src/test/regress/expected/partition_join.out
+++ b/src/test/regress/expected/partition_join.out
@@ -108,28 +108,33 @@ SELECT COUNT(*) FROM prt1 t1
    300
 (1 row)
 
--- left outer join, with whole-row reference; partitionwise join does not apply
+-- left outer join, with whole-row reference
 EXPLAIN (COSTS OFF)
 SELECT t1, t2 FROM prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b WHERE t1.b = 0 ORDER BY t1.a, t2.b;
                     QUERY PLAN                    
 --------------------------------------------------
  Sort
    Sort Key: t1.a, t2.b
-   ->  Hash Right Join
-         Hash Cond: (t2.b = t1.a)
-         ->  Append
+   ->  Append
+         ->  Hash Right Join
+               Hash Cond: (t2_1.b = t1_1.a)
                ->  Seq Scan on prt2_p1 t2_1
-               ->  Seq Scan on prt2_p2 t2_2
-               ->  Seq Scan on prt2_p3 t2_3
-         ->  Hash
-               ->  Append
+               ->  Hash
                      ->  Seq Scan on prt1_p1 t1_1
                            Filter: (b = 0)
+         ->  Hash Right Join
+               Hash Cond: (t2_2.b = t1_2.a)
+               ->  Seq Scan on prt2_p2 t2_2
+               ->  Hash
                      ->  Seq Scan on prt1_p2 t1_2
                            Filter: (b = 0)
+         ->  Hash Right Join
+               Hash Cond: (t2_3.b = t1_3.a)
+               ->  Seq Scan on prt2_p3 t2_3
+               ->  Hash
                      ->  Seq Scan on prt1_p3 t1_3
                            Filter: (b = 0)
-(16 rows)
+(21 rows)
 
 SELECT t1, t2 FROM prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b WHERE t1.b = 0 ORDER BY t1.a, t2.b;
       t1      |      t2      
@@ -1347,28 +1352,37 @@ SELECT t1.a, t2.b FROM (SELECT * FROM prt1 WHERE a < 450) t1 LEFT JOIN (SELECT *
 (9 rows)
 
 -- merge join when expression with whole-row reference needs to be sorted;
--- partitionwise join does not apply
 EXPLAIN (COSTS OFF)
 SELECT t1.a, t2.b FROM prt1 t1, prt2 t2 WHERE t1::text = t2::text AND t1.a = t2.b ORDER BY t1.a;
-                                       QUERY PLAN                                        
------------------------------------------------------------------------------------------
- Merge Join
-   Merge Cond: ((t1.a = t2.b) AND (((((t1.*)::prt1))::text) = ((((t2.*)::prt2))::text)))
-   ->  Sort
-         Sort Key: t1.a, ((((t1.*)::prt1))::text)
-         ->  Result
-               ->  Append
-                     ->  Seq Scan on prt1_p1 t1_1
-                     ->  Seq Scan on prt1_p2 t1_2
-                     ->  Seq Scan on prt1_p3 t1_3
-   ->  Sort
-         Sort Key: t2.b, ((((t2.*)::prt2))::text)
-         ->  Result
-               ->  Append
-                     ->  Seq Scan on prt2_p1 t2_1
-                     ->  Seq Scan on prt2_p2 t2_2
-                     ->  Seq Scan on prt2_p3 t2_3
-(16 rows)
+                                    QUERY PLAN                                     
+-----------------------------------------------------------------------------------
+ Merge Append
+   Sort Key: t1.a
+   ->  Merge Join
+         Merge Cond: ((t1_1.a = t2_1.b) AND (((t1_1.*)::text) = ((t2_1.*)::text)))
+         ->  Sort
+               Sort Key: t1_1.a, ((t1_1.*)::text)
+               ->  Seq Scan on prt1_p1 t1_1
+         ->  Sort
+               Sort Key: t2_1.b, ((t2_1.*)::text)
+               ->  Seq Scan on prt2_p1 t2_1
+   ->  Merge Join
+         Merge Cond: ((t1_2.a = t2_2.b) AND (((t1_2.*)::text) = ((t2_2.*)::text)))
+         ->  Sort
+               Sort Key: t1_2.a, ((t1_2.*)::text)
+               ->  Seq Scan on prt1_p2 t1_2
+         ->  Sort
+               Sort Key: t2_2.b, ((t2_2.*)::text)
+               ->  Seq Scan on prt2_p2 t2_2
+   ->  Merge Join
+         Merge Cond: ((t1_3.a = t2_3.b) AND (((t1_3.*)::text) = ((t2_3.*)::text)))
+         ->  Sort
+               Sort Key: t1_3.a, ((t1_3.*)::text)
+               ->  Seq Scan on prt1_p3 t1_3
+         ->  Sort
+               Sort Key: t2_3.b, ((t2_3.*)::text)
+               ->  Seq Scan on prt2_p3 t2_3
+(26 rows)
 
 SELECT t1.a, t2.b FROM prt1 t1, prt2 t2 WHERE t1::text = t2::text AND t1.a = t2.b ORDER BY t1.a;
  a  | b  
diff --git a/src/test/regress/sql/partition_aggregate.sql b/src/test/regress/sql/partition_aggregate.sql
index ab070fee244..a763228e6c0 100644
--- a/src/test/regress/sql/partition_aggregate.sql
+++ b/src/test/regress/sql/partition_aggregate.sql
@@ -116,7 +116,7 @@ EXPLAIN (COSTS OFF)
 SELECT t1.x, sum(t1.y), count(*) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2.y GROUP BY t1.x ORDER BY 1, 2, 3;
 SELECT t1.x, sum(t1.y), count(*) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2.y GROUP BY t1.x ORDER BY 1, 2, 3;
 
--- Check with whole-row reference; partitionwise aggregation does not apply
+-- Check with whole-row reference
 EXPLAIN (COSTS OFF)
 SELECT t1.x, sum(t1.y), count(t1) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2.y GROUP BY t1.x ORDER BY 1, 2, 3;
 SELECT t1.x, sum(t1.y), count(t1) FROM pagg_tab1 t1, pagg_tab2 t2 WHERE t1.x = t2.y GROUP BY t1.x ORDER BY 1, 2, 3;
diff --git a/src/test/regress/sql/partition_join.sql b/src/test/regress/sql/partition_join.sql
index 128ce8376e6..87135dc84d8 100644
--- a/src/test/regress/sql/partition_join.sql
+++ b/src/test/regress/sql/partition_join.sql
@@ -43,7 +43,7 @@ SELECT COUNT(*) FROM prt1 t1
   LEFT JOIN prt1 t2 ON t1.a = t2.a
   LEFT JOIN prt1 t3 ON t2.a = t3.a;
 
--- left outer join, with whole-row reference; partitionwise join does not apply
+-- left outer join, with whole-row reference
 EXPLAIN (COSTS OFF)
 SELECT t1, t2 FROM prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b WHERE t1.b = 0 ORDER BY t1.a, t2.b;
 SELECT t1, t2 FROM prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b WHERE t1.b = 0 ORDER BY t1.a, t2.b;
@@ -229,7 +229,6 @@ SELECT t1.a, t2.b FROM (SELECT * FROM prt1 WHERE a < 450) t1 LEFT JOIN (SELECT *
 SELECT t1.a, t2.b FROM (SELECT * FROM prt1 WHERE a < 450) t1 LEFT JOIN (SELECT * FROM prt2 WHERE b > 250) t2 ON t1.a = t2.b WHERE t1.b = 0 ORDER BY t1.a, t2.b;
 
 -- merge join when expression with whole-row reference needs to be sorted;
--- partitionwise join does not apply
 EXPLAIN (COSTS OFF)
 SELECT t1.a, t2.b FROM prt1 t1, prt2 t2 WHERE t1::text = t2::text AND t1.a = t2.b ORDER BY t1.a;
 SELECT t1.a, t2.b FROM prt1 t1, prt2 t2 WHERE t1::text = t2::text AND t1.a = t2.b ORDER BY t1.a;
-- 
2.34.1

