Commit 8dc28d7

Optimise numeric multiplication using base-NBASE^2 arithmetic.

Currently mul_var() uses the schoolbook multiplication algorithm, which is O(n^2) in the number of NBASE digits. To improve performance for large inputs, convert the inputs to base NBASE^2 before multiplying, which effectively halves the number of digits in each input, theoretically speeding up the computation by a factor of 4. In practice, the actual speedup for large inputs varies between around 3 and 6 times, depending on the system and compiler used. In turn, this significantly reduces the runtime of the numeric_big regression test.

For this to work, 64-bit integers are required for the products of base-NBASE^2 digits, so this works best on 64-bit machines, on which it is faster whenever the shorter input has more than 4 or 5 NBASE digits.

On 32-bit machines, the additional overheads, especially during carry propagation and the final conversion back to base-NBASE, are significantly higher, and it is only faster when the shorter input has more than around 50 NBASE digits. When the shorter input has more than 6 NBASE digits (so that mul_var_short() cannot be used), but fewer than around 50 NBASE digits, there may be a noticeable slowdown on 32-bit machines. That seems to be an acceptable tradeoff, given the performance gains for other inputs, and the effort that would be required to maintain code specifically targeting 32-bit machines.

Joel Jacobson and Dean Rasheed.

Discussion: https://postgr.es/m/9d8a4a42-c354-41f3-bbf3-199e1957db97%40app.fastmail.com

1 parent c4e4422 commit 8dc28d7
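
The idea in the commit message can be illustrated outside PostgreSQL with a small standalone sketch. It is not the committed code: it assumes NBASE is 10000 (the value is not shown in this diff), caps inputs at a hypothetical MAX_DIGITS, ignores weights, signs and rscale truncation, and uses the made-up helper names pack_pairs(), mul_pairs() and mul_schoolbook(). Digits are stored most significant first, and an odd-length input is shifted up by one NBASE digit, as the commit describes.

```c
#include <assert.h>
#include <stdint.h>

#define NBASE       10000       /* assumed value; not shown in this diff */
#define NBASE_SQR   (NBASE * NBASE)
#define MAX_DIGITS  64          /* arbitrary cap for this sketch */
#define MAX_PAIRS   (MAX_DIGITS / 2)

/*
 * Pack n base-NBASE digits (most significant first) into (n + 1) / 2
 * base-NBASE^2 digits.  An odd-length input gets a zero low digit in its
 * last pair, which is the same as multiplying the value by NBASE.
 */
static int
pack_pairs(const int16_t *digits, int n, uint32_t *pairs)
{
    int npairs = (n + 1) / 2;

    for (int i = 0; i < npairs; i++)
    {
        uint32_t hi = (uint32_t) digits[2 * i];
        uint32_t lo = (2 * i + 1 < n) ? (uint32_t) digits[2 * i + 1] : 0;

        pairs[i] = hi * NBASE + lo;
    }
    return npairs;
}

/*
 * Multiply a[0..na-1] by b[0..nb-1] using base-NBASE^2 arithmetic.
 * Writes 2 * (pa + pb) base-NBASE digits to res and returns that count,
 * where pa and pb are the digit-pair counts of the two inputs.
 */
static int
mul_pairs(const int16_t *a, int na, const int16_t *b, int nb, int16_t *res)
{
    uint32_t ap[MAX_PAIRS], bp[MAX_PAIRS];
    uint64_t acc[2 * MAX_PAIRS] = {0};
    uint64_t carry = 0;
    int pa, pb;

    assert(na > 0 && na <= MAX_DIGITS && nb > 0 && nb <= MAX_DIGITS);
    pa = pack_pairs(a, na, ap);
    pb = pack_pairs(b, nb, bp);

    /* Accumulate 64-bit partial products without propagating carries. */
    for (int i = 0; i < pa; i++)
        for (int j = 0; j < pb; j++)
            acc[i + j + 1] += (uint64_t) ap[i] * bp[j];

    /* One carry pass, splitting each base-NBASE^2 digit into two digits. */
    for (int k = pa + pb - 1; k >= 0; k--)
    {
        uint64_t v = acc[k] + carry;
        uint32_t d = (uint32_t) (v % NBASE_SQR);

        carry = v / NBASE_SQR;
        res[2 * k + 1] = (int16_t) (d % NBASE);
        res[2 * k] = (int16_t) (d / NBASE);
    }
    assert(carry == 0);
    return 2 * (pa + pb);
}

/* Reference: plain base-NBASE schoolbook product, na + nb digits. */
static void
mul_schoolbook(const int16_t *a, int na, const int16_t *b, int nb, int16_t *res)
{
    uint64_t acc[2 * MAX_DIGITS] = {0};
    uint64_t carry = 0;

    for (int i = 0; i < na; i++)
        for (int j = 0; j < nb; j++)
            acc[i + j + 1] += (uint64_t) a[i] * (uint64_t) b[j];

    for (int k = na + nb - 1; k >= 0; k--)
    {
        uint64_t v = acc[k] + carry;

        res[k] = (int16_t) (v % NBASE);
        carry = v / NBASE;
    }
}
```

With these helpers, mul_pairs() performs roughly a quarter of the inner-loop multiplications that mul_schoolbook() does. Its first na + nb output digits match the schoolbook result, followed by one zero digit for each odd-length input. Each accumulator entry holds products below NBASE^4, which is why the sketch (like the commit) needs 64-bit accumulators.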

1 file changed, with 150 additions and 74 deletions.

src/backend/utils/adt/numeric.c

@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
 typedef int16 NumericDigit;
 #endif
 
+#define NBASE_SQR (NBASE * NBASE)
+
 /*
  * The Numeric type as stored on disk.
  *
@@ -8668,21 +8670,30 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
         int rscale)
 {
     int res_ndigits;
+    int res_ndigitpairs;
     int res_sign;
     int res_weight;
+    int pair_offset;
     int maxdigits;
-    int *dig;
-    int carry;
-    int maxdig;
-    int newdig;
+    int maxdigitpairs;
+    uint64 *dig,
+        *dig_i1_off;
+    uint64 maxdig;
+    uint64 carry;
+    uint64 newdig;
     int var1ndigits;
     int var2ndigits;
+    int var1ndigitpairs;
+    int var2ndigitpairs;
     NumericDigit *var1digits;
     NumericDigit *var2digits;
+    uint32 var1digitpair;
+    uint32 *var2digitpairs;
     NumericDigit *res_digits;
     int i,
         i1,
-        i2;
+        i2,
+        i2limit;
 
     /*
      * Arrange for var1 to be the shorter of the two numbers. This improves
@@ -8723,137 +8734,202 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
         return;
     }
 
-    /* Determine result sign and (maximum possible) weight */
+    /* Determine result sign */
     if (var1->sign == var2->sign)
         res_sign = NUMERIC_POS;
     else
         res_sign = NUMERIC_NEG;
-    res_weight = var1->weight + var2->weight + 2;
 
     /*
-     * Determine the number of result digits to compute. If the exact result
-     * would have more than rscale fractional digits, truncate the computation
-     * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
-     * would only contribute to the right of that. (This will give the exact
+     * Determine the number of result digits to compute and the (maximum
+     * possible) result weight. If the exact result would have more than
+     * rscale fractional digits, truncate the computation with
+     * MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that would
+     * only contribute to the right of that. (This will give the exact
      * rounded-to-rscale answer unless carries out of the ignored positions
      * would have propagated through more than MUL_GUARD_DIGITS digits.)
      *
      * Note: an exact computation could not produce more than var1ndigits +
-     * var2ndigits digits, but we allocate one extra output digit in case
-     * rscale-driven rounding produces a carry out of the highest exact digit.
+     * var2ndigits digits, but we allocate at least one extra output digit in
+     * case rscale-driven rounding produces a carry out of the highest exact
+     * digit.
+     *
+     * The computation itself is done using base-NBASE^2 arithmetic, so we
+     * actually process the input digits in pairs, producing a base-NBASE^2
+     * intermediate result. This significantly improves performance, since
+     * schoolbook multiplication is O(N^2) in the number of input digits, and
+     * working in base NBASE^2 effectively halves "N".
+     *
+     * Note: in a truncated computation, we must compute at least one extra
+     * output digit to ensure that all the guard digits are fully computed.
      */
-    res_ndigits = var1ndigits + var2ndigits + 1;
+    /* digit pairs in each input */
+    var1ndigitpairs = (var1ndigits + 1) / 2;
+    var2ndigitpairs = (var2ndigits + 1) / 2;
+
+    /* digits in exact result */
+    res_ndigits = var1ndigits + var2ndigits;
+
+    /* digit pairs in exact result with at least one extra output digit */
+    res_ndigitpairs = res_ndigits / 2 + 1;
+
+    /* pair offset to align result to end of dig[] */
+    pair_offset = res_ndigitpairs - var1ndigitpairs - var2ndigitpairs + 1;
+
+    /* maximum possible result weight (odd-length inputs shifted up below) */
+    res_weight = var1->weight + var2->weight + 1 + 2 * res_ndigitpairs -
+        res_ndigits - (var1ndigits & 1) - (var2ndigits & 1);
+
+    /* rscale-based truncation with at least one extra output digit */
     maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
         MUL_GUARD_DIGITS;
-    res_ndigits = Min(res_ndigits, maxdigits);
+    maxdigitpairs = maxdigits / 2 + 1;
+
+    res_ndigitpairs = Min(res_ndigitpairs, maxdigitpairs);
+    res_ndigits = 2 * res_ndigitpairs;
 
-    if (res_ndigits < 3)
+    /*
+     * In the computation below, digit pair i1 of var1 and digit pair i2 of
+     * var2 are multiplied and added to digit i1+i2+pair_offset of dig[]. Thus
+     * input digit pairs with index >= res_ndigitpairs - pair_offset don't
+     * contribute to the result, and can be ignored.
+     */
+    if (res_ndigitpairs <= pair_offset)
     {
         /* All input digits will be ignored; so result is zero */
         zero_var(result);
         result->dscale = rscale;
         return;
     }
+    var1ndigitpairs = Min(var1ndigitpairs, res_ndigitpairs - pair_offset);
+    var2ndigitpairs = Min(var2ndigitpairs, res_ndigitpairs - pair_offset);
 
     /*
-     * We do the arithmetic in an array "dig[]" of signed int's. Since
-     * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
-     * to avoid normalizing carries immediately.
+     * We do the arithmetic in an array "dig[]" of unsigned 64-bit integers.
+     * Since PG_UINT64_MAX is much larger than NBASE^4, this gives us a lot of
+     * headroom to avoid normalizing carries immediately.
      *
      * maxdig tracks the maximum possible value of any dig[] entry; when this
-     * threatens to exceed INT_MAX, we take the time to propagate carries.
-     * Furthermore, we need to ensure that overflow doesn't occur during the
-     * carry propagation passes either. The carry values could be as much as
-     * INT_MAX/NBASE, so really we must normalize when digits threaten to
-     * exceed INT_MAX - INT_MAX/NBASE.
+     * threatens to exceed PG_UINT64_MAX, we take the time to propagate
+     * carries. Furthermore, we need to ensure that overflow doesn't occur
+     * during the carry propagation passes either. The carry values could be
+     * as much as PG_UINT64_MAX / NBASE^2, so really we must normalize when
+     * digits threaten to exceed PG_UINT64_MAX - PG_UINT64_MAX / NBASE^2.
      *
-     * To avoid overflow in maxdig itself, it actually represents the max
-     * possible value divided by NBASE-1, ie, at the top of the loop it is
-     * known that no dig[] entry exceeds maxdig * (NBASE-1).
+     * To avoid overflow in maxdig itself, it actually represents the maximum
+     * possible value divided by NBASE^2-1, i.e., at the top of the loop it is
+     * known that no dig[] entry exceeds maxdig * (NBASE^2-1).
+     *
+     * The conversion of var1 to base NBASE^2 is done on the fly, as each new
+     * digit is required. The digits of var2 are converted upfront, and
+     * stored at the end of dig[]. To avoid loss of precision, the input
+     * digits are aligned with the start of digit pair array, effectively
+     * shifting them up (multiplying by NBASE) if the inputs have an odd
+     * number of NBASE digits.
      */
-    dig = (int *) palloc0(res_ndigits * sizeof(int));
-    maxdig = 0;
+    dig = (uint64 *) palloc(res_ndigitpairs * sizeof(uint64) +
+                            var2ndigitpairs * sizeof(uint32));
+
+    /* convert var2 to base NBASE^2, shifting up if its length is odd */
+    var2digitpairs = (uint32 *) (dig + res_ndigitpairs);
+
+    for (i2 = 0; i2 < var2ndigitpairs - 1; i2++)
+        var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+
+    if (2 * i2 + 1 < var2ndigits)
+        var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+    else
+        var2digitpairs[i2] = var2digits[2 * i2] * NBASE;
 
     /*
-     * The least significant digits of var1 should be ignored if they don't
-     * contribute directly to the first res_ndigits digits of the result that
-     * we are computing.
+     * Start by multiplying var2 by the least significant contributing digit
+     * pair from var1, storing the results at the end of dig[], and filling
+     * the leading digits with zeros.
      *
-     * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
-     * i1+i2+2 of the accumulator array, so we need only consider digits of
-     * var1 for which i1 <= res_ndigits - 3.
+     * The loop here is the same as the inner loop below, except that we set
+     * the results in dig[], rather than adding to them. This is the
+     * performance bottleneck for multiplication, so we want to keep it simple
+     * enough so that it can be auto-vectorized. Accordingly, process the
+     * digits left-to-right even though schoolbook multiplication would
+     * suggest right-to-left. Since we aren't propagating carries in this
+     * loop, the order does not matter.
+     */
+    i1 = var1ndigitpairs - 1;
+    if (2 * i1 + 1 < var1ndigits)
+        var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+    else
+        var1digitpair = var1digits[2 * i1] * NBASE;
+    maxdig = var1digitpair;
+
+    i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+    dig_i1_off = &dig[i1 + pair_offset];
+
+    memset(dig, 0, (i1 + pair_offset) * sizeof(uint64));
+    for (i2 = 0; i2 < i2limit; i2++)
+        dig_i1_off[i2] = (uint64) var1digitpair * var2digitpairs[i2];
+
+    /*
+     * Next, multiply var2 by the remaining digit pairs from var1, adding the
+     * results to dig[] at the appropriate offsets, and normalizing whenever
+     * there is a risk of any dig[] entry overflowing.
      */
-    for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+    for (i1 = i1 - 1; i1 >= 0; i1--)
     {
-        NumericDigit var1digit = var1digits[i1];
-
-        if (var1digit == 0)
+        var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+        if (var1digitpair == 0)
             continue;
 
         /* Time to normalize? */
-        maxdig += var1digit;
-        if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+        maxdig += var1digitpair;
+        if (maxdig > (PG_UINT64_MAX - PG_UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1))
         {
-            /* Yes, do it */
+            /* Yes, do it (to base NBASE^2) */
             carry = 0;
-            for (i = res_ndigits - 1; i >= 0; i--)
+            for (i = res_ndigitpairs - 1; i >= 0; i--)
             {
                 newdig = dig[i] + carry;
-                if (newdig >= NBASE)
+                if (newdig >= NBASE_SQR)
                 {
-                    carry = newdig / NBASE;
-                    newdig -= carry * NBASE;
+                    carry = newdig / NBASE_SQR;
+                    newdig -= carry * NBASE_SQR;
                 }
                 else
                     carry = 0;
                 dig[i] = newdig;
             }
             Assert(carry == 0);
             /* Reset maxdig to indicate new worst-case */
-            maxdig = 1 + var1digit;
+            maxdig = 1 + var1digitpair;
         }
 
-        /*
-         * Add the appropriate multiple of var2 into the accumulator.
-         *
-         * As above, digits of var2 can be ignored if they don't contribute,
-         * so we only include digits for which i1+i2+2 < res_ndigits.
-         *
-         * This inner loop is the performance bottleneck for multiplication,
-         * so we want to keep it simple enough so that it can be
-         * auto-vectorized. Accordingly, process the digits left-to-right
-         * even though schoolbook multiplication would suggest right-to-left.
-         * Since we aren't propagating carries in this loop, the order does
-         * not matter.
-         */
-        {
-            int i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
-            int *dig_i1_2 = &dig[i1 + 2];
+        /* Multiply and add */
+        i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+        dig_i1_off = &dig[i1 + pair_offset];
 
-            for (i2 = 0; i2 < i2limit; i2++)
-                dig_i1_2[i2] += var1digit * var2digits[i2];
-        }
+        for (i2 = 0; i2 < i2limit; i2++)
+            dig_i1_off[i2] += (uint64) var1digitpair * var2digitpairs[i2];
     }
 
     /*
-     * Now we do a final carry propagation pass to normalize the result, which
-     * we combine with storing the result digits into the output. Note that
-     * this is still done at full precision w/guard digits.
+     * Now we do a final carry propagation pass to normalize back to base
+     * NBASE^2, and construct the base-NBASE result digits. Note that this is
+     * still done at full precision w/guard digits.
      */
     alloc_var(result, res_ndigits);
     res_digits = result->digits;
     carry = 0;
-    for (i = res_ndigits - 1; i >= 0; i--)
+    for (i = res_ndigitpairs - 1; i >= 0; i--)
     {
         newdig = dig[i] + carry;
-        if (newdig >= NBASE)
+        if (newdig >= NBASE_SQR)
         {
-            carry = newdig / NBASE;
-            newdig -= carry * NBASE;
+            carry = newdig / NBASE_SQR;
+            newdig -= carry * NBASE_SQR;
         }
         else
             carry = 0;
-        res_digits[i] = newdig;
+        res_digits[2 * i + 1] = (NumericDigit) ((uint32) newdig % NBASE);
+        res_digits[2 * i] = (NumericDigit) ((uint32) newdig / NBASE);
     }
     Assert(carry == 0);
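
The normalization test in the new outer loop is easier to follow with concrete numbers. The following standalone sketch evaluates the same expression outside PostgreSQL. It assumes NBASE is 10000 (not shown in this diff) and uses the standard UINT64_MAX in place of PG_UINT64_MAX. The function names are hypothetical and exist only for illustration.

```c
#include <assert.h>
#include <stdint.h>

#define NBASE       10000       /* assumed value; not shown in this diff */
#define NBASE_SQR   ((uint64_t) NBASE * NBASE)

/*
 * Largest value of maxdig that does not trigger a normalization pass,
 * using the same expression as the test in mul_var()'s outer loop.
 */
static uint64_t
normalize_threshold(void)
{
    return (UINT64_MAX - UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1);
}

/*
 * Worst case: every var1 digit pair equals NBASE_SQR - 1.  Returns how many
 * such pairs can be accumulated before maxdig exceeds the threshold.
 */
static uint64_t
worst_case_pairs_without_normalizing(void)
{
    return normalize_threshold() / (NBASE_SQR - 1);
}

/*
 * Check the invariant from the comment: an entry bounded by
 * maxdig * (NBASE_SQR - 1), plus the largest possible incoming carry,
 * still fits in 64 bits.
 */
static int
threshold_is_safe(void)
{
    uint64_t max_entry = normalize_threshold() * (NBASE_SQR - 1);
    uint64_t max_carry = UINT64_MAX / NBASE_SQR;

    return max_entry <= UINT64_MAX - max_carry;
}
```

Under these assumptions, the threshold allows roughly 1,800 worst-case digit pairs of var1 to be accumulated between carry passes, so normalization inside the loop should be rare for typical inputs.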