Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit e813e0e

Browse files
committed
Add optimized functions for linear search within byte arrays
In a similar vein to b6ef167, add pg_lfind8() and pg_lfind8_le() to search for bytes equal to, or less than or equal to, a given byte, respectively. To abstract away platform details, add helper functions and typedefs to simd.h. John Naylor and Nathan Bossart, per suggestion from Andres Freund. Discussion: https://www.postgresql.org/message-id/CAFBsxsGzaaGLF%3DNuq61iRXTyspbO9rOjhSqFN%3DV6ozzmta5mXg%40mail.gmail.com
1 parent bcc8b14 commit e813e0e

File tree

6 files changed

+358
-10
lines changed

6 files changed

+358
-10
lines changed

src/include/port/pg_lfind.h

+66-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
/*-------------------------------------------------------------------------
22
*
33
* pg_lfind.h
4-
* Optimized linear search routines.
4+
* Optimized linear search routines using SIMD intrinsics where
5+
* available.
56
*
67
* Copyright (c) 2022, PostgreSQL Global Development Group
78
*
@@ -15,6 +16,70 @@
1516

1617
#include "port/simd.h"
1718

19+
/*
20+
* pg_lfind8
21+
*
22+
* Return true if there is an element in 'base' that equals 'key', otherwise
23+
* return false.
24+
*/
25+
static inline bool
26+
pg_lfind8(uint8 key, uint8 *base, uint32 nelem)
27+
{
28+
uint32 i;
29+
30+
/* round down to multiple of vector length */
31+
uint32 tail_idx = nelem & ~(sizeof(Vector8) - 1);
32+
Vector8 chunk;
33+
34+
for (i = 0; i < tail_idx; i += sizeof(Vector8))
35+
{
36+
vector8_load(&chunk, &base[i]);
37+
if (vector8_has(chunk, key))
38+
return true;
39+
}
40+
41+
/* Process the remaining elements one at a time. */
42+
for (; i < nelem; i++)
43+
{
44+
if (key == base[i])
45+
return true;
46+
}
47+
48+
return false;
49+
}
50+
51+
/*
52+
* pg_lfind8_le
53+
*
54+
* Return true if there is an element in 'base' that is less than or equal to
55+
* 'key', otherwise return false.
56+
*/
57+
static inline bool
58+
pg_lfind8_le(uint8 key, uint8 *base, uint32 nelem)
59+
{
60+
uint32 i;
61+
62+
/* round down to multiple of vector length */
63+
uint32 tail_idx = nelem & ~(sizeof(Vector8) - 1);
64+
Vector8 chunk;
65+
66+
for (i = 0; i < tail_idx; i += sizeof(Vector8))
67+
{
68+
vector8_load(&chunk, &base[i]);
69+
if (vector8_has_le(chunk, key))
70+
return true;
71+
}
72+
73+
/* Process the remaining elements one at a time. */
74+
for (; i < nelem; i++)
75+
{
76+
if (base[i] <= key)
77+
return true;
78+
}
79+
80+
return false;
81+
}
82+
1883
/*
1984
* pg_lfind32
2085
*
@@ -26,7 +91,6 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
2691
{
2792
uint32 i = 0;
2893

29-
/* Use SIMD intrinsics where available. */
3094
#ifdef USE_SSE2
3195

3296
/*

src/include/port/simd.h

+167-1
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88
*
99
* src/include/port/simd.h
1010
*
11+
* NOTES
12+
* - VectorN in this file refers to a register where the element operands
13+
* are N bits wide. The vector width is platform-specific, so users that care
14+
* about that will need to inspect "sizeof(VectorN)".
15+
*
1116
*-------------------------------------------------------------------------
1217
*/
1318
#ifndef SIMD_H
1419
#define SIMD_H
1520

21+
#if (defined(__x86_64__) || defined(_M_AMD64))
1622
/*
1723
* SSE2 instructions are part of the spec for the 64-bit x86 ISA. We assume
1824
* that compilers targeting this architecture understand SSE2 intrinsics.
@@ -22,9 +28,169 @@
2228
* will allow the use of intrinsics that haven't been enabled at compile
2329
* time.
2430
*/
25-
#if (defined(__x86_64__) || defined(_M_AMD64))
2631
#include <emmintrin.h>
2732
#define USE_SSE2
33+
typedef __m128i Vector8;
34+
35+
#else
36+
/*
37+
* If no SIMD instructions are available, we can in some cases emulate vector
38+
* operations using bitwise operations on unsigned integers.
39+
*/
40+
#define USE_NO_SIMD
41+
typedef uint64 Vector8;
42+
#endif
43+
44+
45+
/* load/store operations */
46+
static inline void vector8_load(Vector8 *v, const uint8 *s);
47+
48+
/* assignment operations */
49+
static inline Vector8 vector8_broadcast(const uint8 c);
50+
51+
/* element-wise comparisons to a scalar */
52+
static inline bool vector8_has(const Vector8 v, const uint8 c);
53+
static inline bool vector8_has_zero(const Vector8 v);
54+
static inline bool vector8_has_le(const Vector8 v, const uint8 c);
55+
56+
57+
/*
58+
* Load a chunk of memory into the given vector.
59+
*/
60+
static inline void
61+
vector8_load(Vector8 *v, const uint8 *s)
62+
{
63+
#if defined(USE_SSE2)
64+
*v = _mm_loadu_si128((const __m128i *) s);
65+
#else
66+
memcpy(v, s, sizeof(Vector8));
2867
#endif
68+
}
69+
70+
71+
/*
72+
* Create a vector with all elements set to the same value.
73+
*/
74+
static inline Vector8
75+
vector8_broadcast(const uint8 c)
76+
{
77+
#if defined(USE_SSE2)
78+
return _mm_set1_epi8(c);
79+
#else
80+
return ~UINT64CONST(0) / 0xFF * c;
81+
#endif
82+
}
83+
84+
/*
85+
* Return true if any elements in the vector are equal to the given scalar.
86+
*/
87+
static inline bool
88+
vector8_has(const Vector8 v, const uint8 c)
89+
{
90+
bool result;
91+
92+
/* pre-compute the result for assert checking */
93+
#ifdef USE_ASSERT_CHECKING
94+
bool assert_result = false;
95+
96+
for (int i = 0; i < sizeof(Vector8); i++)
97+
{
98+
if (((const uint8 *) &v)[i] == c)
99+
{
100+
assert_result = true;
101+
break;
102+
}
103+
}
104+
#endif /* USE_ASSERT_CHECKING */
105+
106+
#if defined(USE_NO_SIMD)
107+
/* any bytes in v equal to c will evaluate to zero via XOR */
108+
result = vector8_has_zero(v ^ vector8_broadcast(c));
109+
#elif defined(USE_SSE2)
110+
result = _mm_movemask_epi8(_mm_cmpeq_epi8(v, vector8_broadcast(c)));
111+
#endif
112+
113+
Assert(assert_result == result);
114+
return result;
115+
}
116+
117+
/*
118+
* Convenience function equivalent to vector8_has(v, 0)
119+
*/
120+
static inline bool
121+
vector8_has_zero(const Vector8 v)
122+
{
123+
#if defined(USE_NO_SIMD)
124+
/*
125+
* We cannot call vector8_has() here, because that would lead to a circular
126+
* definition.
127+
*/
128+
return vector8_has_le(v, 0);
129+
#elif defined(USE_SSE2)
130+
return vector8_has(v, 0);
131+
#endif
132+
}
133+
134+
/*
135+
* Return true if any elements in the vector are less than or equal to the
136+
* given scalar.
137+
*/
138+
static inline bool
139+
vector8_has_le(const Vector8 v, const uint8 c)
140+
{
141+
bool result = false;
142+
#if defined(USE_SSE2)
143+
__m128i sub;
144+
#endif
145+
146+
/* pre-compute the result for assert checking */
147+
#ifdef USE_ASSERT_CHECKING
148+
bool assert_result = false;
149+
150+
for (int i = 0; i < sizeof(Vector8); i++)
151+
{
152+
if (((const uint8 *) &v)[i] <= c)
153+
{
154+
assert_result = true;
155+
break;
156+
}
157+
}
158+
#endif /* USE_ASSERT_CHECKING */
159+
160+
#if defined(USE_NO_SIMD)
161+
162+
/*
163+
* To find bytes <= c, we can use bitwise operations to find bytes < c+1,
164+
* but it only works if c+1 <= 128 and if the highest bit in v is not set.
165+
* Adapted from
166+
* https://graphics.stanford.edu/~seander/bithacks.html#HasLessInWord
167+
*/
168+
if ((int64) v >= 0 && c < 0x80)
169+
result = (v - vector8_broadcast(c + 1)) & ~v & vector8_broadcast(0x80);
170+
else
171+
{
172+
/* one byte at a time */
173+
for (int i = 0; i < sizeof(Vector8); i++)
174+
{
175+
if (((const uint8 *) &v)[i] <= c)
176+
{
177+
result = true;
178+
break;
179+
}
180+
}
181+
}
182+
#elif defined(USE_SSE2)
183+
184+
/*
185+
* Use saturating subtraction to find bytes <= c, which will present as
186+
* NUL bytes in 'sub'.
187+
*/
188+
sub = _mm_subs_epu8(v, vector8_broadcast(c));
189+
result = vector8_has_zero(sub);
190+
#endif
191+
192+
Assert(assert_result == result);
193+
return result;
194+
}
29195

30196
#endif /* SIMD_H */

src/test/modules/test_lfind/expected/test_lfind.out

+15-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,21 @@ CREATE EXTENSION test_lfind;
44
-- the operations complete without crashing or hanging and that none of their
55
-- internal sanity tests fail.
66
--
7-
SELECT test_lfind();
8-
test_lfind
9-
------------
7+
SELECT test_lfind8();
8+
test_lfind8
9+
-------------
10+
11+
(1 row)
12+
13+
SELECT test_lfind8_le();
14+
test_lfind8_le
15+
----------------
16+
17+
(1 row)
18+
19+
SELECT test_lfind32();
20+
test_lfind32
21+
--------------
1022

1123
(1 row)
1224

src/test/modules/test_lfind/sql/test_lfind.sql

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ CREATE EXTENSION test_lfind;
55
-- the operations complete without crashing or hanging and that none of their
66
-- internal sanity tests fail.
77
--
8-
SELECT test_lfind();
8+
SELECT test_lfind8();
9+
SELECT test_lfind8_le();
10+
SELECT test_lfind32();

src/test/modules/test_lfind/test_lfind--1.0.sql

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
44
\echo Use "CREATE EXTENSION test_lfind" to load this file. \quit
55

6-
CREATE FUNCTION test_lfind()
6+
CREATE FUNCTION test_lfind32()
7+
RETURNS pg_catalog.void
8+
AS 'MODULE_PATHNAME' LANGUAGE C;
9+
10+
CREATE FUNCTION test_lfind8()
11+
RETURNS pg_catalog.void
12+
AS 'MODULE_PATHNAME' LANGUAGE C;
13+
14+
CREATE FUNCTION test_lfind8_le()
715
RETURNS pg_catalog.void
816
AS 'MODULE_PATHNAME' LANGUAGE C;

0 commit comments

Comments
 (0)