From fcbb68b7d2bb9df63c92bc773240873e1e27a5a8 Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nathandbossart@gmail.com>
Date: Fri, 16 Sep 2022 20:44:03 -0700
Subject: [PATCH v1 1/1] introduce pg_lfind8_idx and pg_lfind8_ge_idx

---
 src/include/port/pg_lfind.h                   |  72 +++++++++++++
 src/include/port/simd.h                       | 100 ++++++++++++++++++
 .../test_lfind/expected/test_lfind.out        |  12 +++
 .../modules/test_lfind/sql/test_lfind.sql     |   2 +
 .../modules/test_lfind/test_lfind--1.0.sql    |   8 ++
 src/test/modules/test_lfind/test_lfind.c      |  81 +++++++++++++-
 6 files changed, 274 insertions(+), 1 deletion(-)

diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h
index 0625cac6b5..34cf30e591 100644
--- a/src/include/port/pg_lfind.h
+++ b/src/include/port/pg_lfind.h
@@ -48,6 +48,42 @@ pg_lfind8(uint8 key, uint8 *base, uint32 nelem)
 	return false;
 }
 
+/*
+ * pg_lfind8_idx
+ *
+ * Return index of the first element among the 'nelem' bytes at 'base' that
+ * equals 'key'.  Return -1 if there is no such element.
+ */
+static inline int
+pg_lfind8_idx(uint8 key, uint8 *base, int nelem)
+{
+	int			i = 0;			/* next index to examine; shared by both loops */
+
+#ifndef USE_NO_SIMD
+	/* round down to multiple of vector length; the mask needs a power of two */
+	int			tail_idx = nelem & ~(sizeof(Vector8) - 1);
+	Vector8		chunk;
+
+	for (; i < tail_idx; i += sizeof(Vector8))
+	{
+		int			idx;		/* match position within the current chunk */
+
+		vector8_load(&chunk, &base[i]);
+		if ((idx = vector8_find(chunk, key)) != -1)
+			return i + idx;		/* convert chunk offset to array index */
+	}
+#endif
+
+	/* Process the remaining elements (all of them if USE_NO_SIMD) one at a time. */
+	for (; i < nelem; i++)
+	{
+		if (key == base[i])
+			return i;
+	}
+
+	return -1;
+}
+
 /*
  * pg_lfind8_le
  *
@@ -80,6 +116,42 @@ pg_lfind8_le(uint8 key, uint8 *base, uint32 nelem)
 	return false;
 }
 
+/*
+ * pg_lfind8_ge_idx
+ *
+ * Return index of the first element among the 'nelem' bytes at 'base' that is
+ * greater than or equal to 'key'.  Return -1 if there is no such element.
+ */
+static inline int
+pg_lfind8_ge_idx(uint8 key, uint8 *base, int nelem)
+{
+	int			i = 0;			/* next index to examine; shared by both loops */
+
+#ifndef USE_NO_SIMD
+	/* round down to multiple of vector length; the mask needs a power of two */
+	int			tail_idx = nelem & ~(sizeof(Vector8) - 1);
+	Vector8		chunk;
+
+	for (; i < tail_idx; i += sizeof(Vector8))
+	{
+		int			idx;		/* match position within the current chunk */
+
+		vector8_load(&chunk, &base[i]);
+		if ((idx = vector8_find_ge(chunk, key)) != -1)
+			return i + idx;		/* convert chunk offset to array index */
+	}
+#endif
+
+	/* Process the remaining elements (all of them if USE_NO_SIMD) one at a time. */
+	for (; i < nelem; i++)
+	{
+		if (base[i] >= key)
+			return i;
+	}
+
+	return -1;
+}
+
 /*
  * pg_lfind32
  *
diff --git a/src/include/port/simd.h b/src/include/port/simd.h
index 61ae4ecf60..e79d2ad5e4 100644
--- a/src/include/port/simd.h
+++ b/src/include/port/simd.h
@@ -60,6 +60,15 @@ typedef uint32x4_t Vector32;
 typedef uint64 Vector8;
 #endif
 
+/*
+ * Some of the functions with SIMD implementations use bitwise operations
+ * available in pg_bitutils.h.  There are currently no non-SIMD implementations
+ * that require these bitwise operations.
+ */
+#ifndef USE_NO_SIMD
+#include "port/pg_bitutils.h"
+#endif
+
 /* load/store operations */
 static inline void vector8_load(Vector8 *v, const uint8 *s);
 #ifndef USE_NO_SIMD
@@ -79,6 +88,8 @@ static inline bool vector8_has_le(const Vector8 v, const uint8 c);
 static inline bool vector8_is_highbit_set(const Vector8 v);
 #ifndef USE_NO_SIMD
 static inline bool vector32_is_highbit_set(const Vector32 v);
+static inline int vector8_find(const Vector8 v, const uint8 c);
+static inline int vector8_find_ge(const Vector8 v, const uint8 c);
 #endif
 
 /* arithmetic operations */
@@ -299,6 +310,95 @@ vector32_is_highbit_set(const Vector32 v)
 }
 #endif							/* ! USE_NO_SIMD */
 
+/*
+ * Return the 0-based index of the first element of 'v', in memory order, that
+ * is equal to 'c'.  Return -1 if there is no such element.
+ */
+#ifndef USE_NO_SIMD
+static inline int
+vector8_find(const Vector8 v, const uint8 c)
+{
+	Vector8		cmp;
+	int			result = -1;	/* stays -1 when no byte matches */
+#if defined(USE_SSE2)
+	uint32		mask;
+#elif defined(USE_NEON)
+	uint64		mask;
+#endif
+
+	/* pre-compute the result for assert checking */
+#ifdef USE_ASSERT_CHECKING
+	int			assert_result = -1;
+
+	for (Size i = 0; i < sizeof(Vector8); i++)
+	{
+		if (((const uint8 *) &v)[i] == c)
+		{
+			assert_result = i;
+			break;
+		}
+	}
+#endif							/* USE_ASSERT_CHECKING */
+
+	cmp = vector8_eq(v, vector8_broadcast(c));
+
+#if defined(USE_SSE2)
+	mask = _mm_movemask_epi8(cmp);	/* one mask bit per byte */
+	if (mask)
+		result = pg_rightmost_one_pos32(mask);	/* lowest set bit = first match */
+#elif defined(USE_NEON)
+	/*
+	 * Narrowing yields 4 mask bits per byte, hence ">> 2" below.  Adapted from
+	 * https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
+	 */
+	mask = vget_lane_u64((uint64x1_t) vshrn_n_u16((uint16x8_t) cmp, 4), 0);
+	if (mask)
+		result = pg_rightmost_one_pos64(mask) >> 2;
+#endif
+
+	Assert(assert_result == result);	/* cross-check against the scalar scan */
+	return result;
+}
+#endif							/* ! USE_NO_SIMD */
+
+/*
+ * Return index of the first element in the vector that is greater than or
+ * equal to the given scalar.  Return -1 if there is no such element.
+ */
+#ifndef USE_NO_SIMD
+static inline int
+vector8_find_ge(const Vector8 v, const uint8 c)
+{
+	Vector8		sub;
+	int			result;
+
+	/* pre-compute the result for assert checking */
+#ifdef USE_ASSERT_CHECKING
+	int			assert_result = -1;
+
+	for (Size i = 0; i < sizeof(Vector8); i++)
+	{
+		if (((const uint8 *) &v)[i] >= c)
+		{
+			assert_result = i;
+			break;
+		}
+	}
+#endif							/* USE_ASSERT_CHECKING */
+
+	/*
+	 * Use saturating subtraction to find bytes >= c, which will present as
+	 * NUL bytes.  This approach is a workaround for the lack of unsigned
+	 * comparison instructions on some architectures.
+	 */
+	sub = vector8_ssub(vector8_broadcast(c), v);
+	result = vector8_find(sub, 0);	/* first NUL byte = first element >= c */
+
+	Assert(assert_result == result);	/* cross-check against the scalar scan */
+	return result;
+}
+#endif							/* ! USE_NO_SIMD */
+
 /*
  * Return the bitwise OR of the inputs
  */
diff --git a/src/test/modules/test_lfind/expected/test_lfind.out b/src/test/modules/test_lfind/expected/test_lfind.out
index 1d4b14e703..30ecad4e9e 100644
--- a/src/test/modules/test_lfind/expected/test_lfind.out
+++ b/src/test/modules/test_lfind/expected/test_lfind.out
@@ -22,3 +22,15 @@ SELECT test_lfind32();
  
 (1 row)
 
+SELECT test_lfind8_idx();
+ test_lfind8_idx 
+-----------------
+ 
+(1 row)
+
+SELECT test_lfind8_ge_idx();
+ test_lfind8_ge_idx 
+--------------------
+ 
+(1 row)
+
diff --git a/src/test/modules/test_lfind/sql/test_lfind.sql b/src/test/modules/test_lfind/sql/test_lfind.sql
index 766c640831..0c01497aef 100644
--- a/src/test/modules/test_lfind/sql/test_lfind.sql
+++ b/src/test/modules/test_lfind/sql/test_lfind.sql
@@ -8,3 +8,5 @@ CREATE EXTENSION test_lfind;
 SELECT test_lfind8();
 SELECT test_lfind8_le();
 SELECT test_lfind32();
+SELECT test_lfind8_idx();
+SELECT test_lfind8_ge_idx();
diff --git a/src/test/modules/test_lfind/test_lfind--1.0.sql b/src/test/modules/test_lfind/test_lfind--1.0.sql
index 81801926ae..50b635794d 100644
--- a/src/test/modules/test_lfind/test_lfind--1.0.sql
+++ b/src/test/modules/test_lfind/test_lfind--1.0.sql
@@ -14,3 +14,11 @@ CREATE FUNCTION test_lfind8()
 CREATE FUNCTION test_lfind8_le()
 	RETURNS pg_catalog.void
 	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION test_lfind8_idx()	-- C test entry point for pg_lfind8_idx()
+	RETURNS pg_catalog.void
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION test_lfind8_ge_idx()	-- C test entry point for pg_lfind8_ge_idx()
+	RETURNS pg_catalog.void
+	AS 'MODULE_PATHNAME' LANGUAGE C;
diff --git a/src/test/modules/test_lfind/test_lfind.c b/src/test/modules/test_lfind/test_lfind.c
index 82673d54c6..6aa33edb3b 100644
--- a/src/test/modules/test_lfind/test_lfind.c
+++ b/src/test/modules/test_lfind/test_lfind.c
@@ -115,11 +115,90 @@ test_lfind8_le(PG_FUNCTION_ARGS)
 	PG_RETURN_VOID();
 }
 
+/* workhorse for test_lfind8_idx: probes key-1, key, key+1 around a planted key */
+static void
+test_lfind8_idx_internal(uint8 key)
+{
+	uint8		charbuf[LEN_WITH_TAIL(Vector8)];
+	const int	len_no_tail = LEN_NO_TAIL(Vector8);
+	const int	len_with_tail = LEN_WITH_TAIL(Vector8);
+	int			keypos;			/* index at which 'key' is planted */
+
+	memset(charbuf, 0xFF, len_with_tail);	/* 0xFF filler explains the key guards below */
+	/* search tail to test one-byte-at-a-time path (key is the last element) */
+	keypos = len_with_tail - 1;
+	charbuf[keypos] = key;
+	if (key > 0x00 && (pg_lfind8_idx(key - 1, charbuf, len_with_tail) != -1))
+		elog(ERROR, "pg_lfind8_idx() found nonexistent element '0x%x'", key - 1);
+	if (key < 0xFF && (pg_lfind8_idx(key, charbuf, len_with_tail) != keypos))
+		elog(ERROR, "pg_lfind8_idx() did not find existing element '0x%x'", key);
+	if (key < 0xFE && (pg_lfind8_idx(key + 1, charbuf, len_with_tail) != -1))
+		elog(ERROR, "pg_lfind8_idx() found nonexistent element '0x%x'", key + 1);
+
+	memset(charbuf, 0xFF, len_with_tail);	/* reset filler, dropping the earlier key */
+	/* search with vector operations (length presumably has no scalar tail) */
+	keypos = len_no_tail - 1;
+	charbuf[keypos] = key;
+	if (key > 0x00 && (pg_lfind8_idx(key - 1, charbuf, len_no_tail) != -1))
+		elog(ERROR, "pg_lfind8_idx() found nonexistent element '0x%x'", key - 1);
+	if (key < 0xFF && (pg_lfind8_idx(key, charbuf, len_no_tail) != keypos))
+		elog(ERROR, "pg_lfind8_idx() did not find existing element '0x%x'", key);
+	if (key < 0xFE && (pg_lfind8_idx(key + 1, charbuf, len_no_tail) != -1))
+		elog(ERROR, "pg_lfind8_idx() found nonexistent element '0x%x'", key + 1);
+}
+
+PG_FUNCTION_INFO_V1(test_lfind8_idx);	/* SQL-callable test of pg_lfind8_idx() */
+Datum
+test_lfind8_idx(PG_FUNCTION_ARGS)
+{
+	test_lfind8_idx_internal(0);	/* lowest byte value */
+	test_lfind8_idx_internal(1);
+	test_lfind8_idx_internal(0x7F);	/* just below the high bit */
+	test_lfind8_idx_internal(0x80);	/* high bit only */
+	test_lfind8_idx_internal(0x81);
+	test_lfind8_idx_internal(0xFD);
+	test_lfind8_idx_internal(0xFE);	/* key + 1 equals the 0xFF filler */
+	test_lfind8_idx_internal(0xFF);	/* key equals the 0xFF filler */
+
+	PG_RETURN_VOID();
+}
+
+PG_FUNCTION_INFO_V1(test_lfind8_ge_idx);	/* SQL-callable test of pg_lfind8_ge_idx() */
+Datum
+test_lfind8_ge_idx(PG_FUNCTION_ARGS)
+{
+#define TEST_ARRAY_SIZE 135		/* also relied on by test_lfind32() below */
+	uint8		test_array[TEST_ARRAY_SIZE] = {0};
+
+	test_array[8] = 1;			/* increasing markers: 1, 3, 5 */
+	test_array[64] = 3;
+	test_array[TEST_ARRAY_SIZE - 1] = 5;
+
+	if (pg_lfind8_ge_idx(1, test_array, 4) != -1)	/* [0,4) is all zeros */
+		elog(ERROR, "pg_lfind8_ge_idx found nonexistent element");
+	if (pg_lfind8_ge_idx(1, test_array, TEST_ARRAY_SIZE) != 8)
+		elog(ERROR, "pg_lfind8_ge_idx did not find existing element");
+
+	if (pg_lfind8_ge_idx(2, test_array, 32) != -1)	/* max of [0,32) is 1 */
+		elog(ERROR, "pg_lfind8_ge_idx found nonexistent element");
+	if (pg_lfind8_ge_idx(2, test_array, TEST_ARRAY_SIZE) != 64)
+		elog(ERROR, "pg_lfind8_ge_idx did not find existing element");
+
+	if (pg_lfind8_ge_idx(4, test_array, 96) != -1)	/* max of [0,96) is 3 */
+		elog(ERROR, "pg_lfind8_ge_idx found nonexistent element");
+	if (pg_lfind8_ge_idx(4, test_array, TEST_ARRAY_SIZE) != TEST_ARRAY_SIZE - 1)
+		elog(ERROR, "pg_lfind8_ge_idx did not find existing element");
+
+	if (pg_lfind8_ge_idx(6, test_array, TEST_ARRAY_SIZE) != -1)	/* max is 5 */
+		elog(ERROR, "pg_lfind8_ge_idx found nonexistent element");
+
+	PG_RETURN_VOID();
+}
+
 PG_FUNCTION_INFO_V1(test_lfind32);
 Datum
 test_lfind32(PG_FUNCTION_ARGS)
 {
-#define TEST_ARRAY_SIZE 135
 	uint32		test_array[TEST_ARRAY_SIZE] = {0};
 
 	test_array[8] = 1;
-- 
2.25.1

