Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit ca481d3

Browse files
committed
Optimise numeric multiplication for short inputs.
When either input has a small number of digits, and the exact product is requested, the speed of numeric multiplication can be increased significantly by using a faster direct multiplication algorithm. This works by fully computing each result digit in turn, starting with the least significant, and propagating the carry up. This saves cycles by not requiring a temporary buffer to store digit products, not making multiple passes over the digits of the longer input, and not requiring separate carry-propagation passes. For now, this is used when the shorter input has 1-4 NBASE digits (up to 13-16 decimal digits), and the longer input is of any size, which covers a lot of common real-world cases. Also, the relative benefit increases as the size of the longer input increases. Possible future work would be to try extending the technique to larger numbers of digits in the shorter input. Joel Jacobson and Dean Rasheed. Discussion: https://postgr.es/m/44d2ffca-d560-4919-b85a-4d07060946aa@app.fastmail.com
1 parent 42de72f commit ca481d3

File tree

1 file changed

+219
-1
lines changed

1 file changed

+219
-1
lines changed

src/backend/utils/adt/numeric.c

+219-1
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,8 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2,
558558
static void mul_var(const NumericVar *var1, const NumericVar *var2,
559559
NumericVar *result,
560560
int rscale);
561+
static void mul_var_short(const NumericVar *var1, const NumericVar *var2,
562+
NumericVar *result);
561563
static void div_var(const NumericVar *var1, const NumericVar *var2,
562564
NumericVar *result,
563565
int rscale, bool round);
@@ -8722,14 +8724,24 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
87228724
var1digits = var1->digits;
87238725
var2digits = var2->digits;
87248726

8725-
if (var1ndigits == 0 || var2ndigits == 0)
8727+
if (var1ndigits == 0)
87268728
{
87278729
/* one or both inputs is zero; so is result */
87288730
zero_var(result);
87298731
result->dscale = rscale;
87308732
return;
87318733
}
87328734

8735+
/*
8736+
* If var1 has 1-4 digits and the exact result was requested, delegate to
8737+
* mul_var_short() which uses a faster direct multiplication algorithm.
8738+
*/
8739+
if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
8740+
{
8741+
mul_var_short(var1, var2, result);
8742+
return;
8743+
}
8744+
87338745
/* Determine result sign and (maximum possible) weight */
87348746
if (var1->sign == var2->sign)
87358747
res_sign = NUMERIC_POS;
@@ -8880,6 +8892,212 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
88808892
}
88818893

88828894

8895+
/*
8896+
* mul_var_short() -
8897+
*
8898+
* Special-case multiplication function used when var1 has 1-4 digits, var2
8899+
* has at least as many digits as var1, and the exact product var1 * var2 is
8900+
* requested.
8901+
*/
8902+
static void
8903+
mul_var_short(const NumericVar *var1, const NumericVar *var2,
8904+
NumericVar *result)
8905+
{
8906+
int var1ndigits = var1->ndigits;
8907+
int var2ndigits = var2->ndigits;
8908+
NumericDigit *var1digits = var1->digits;
8909+
NumericDigit *var2digits = var2->digits;
8910+
int res_sign;
8911+
int res_weight;
8912+
int res_ndigits;
8913+
NumericDigit *res_buf;
8914+
NumericDigit *res_digits;
8915+
uint32 carry;
8916+
uint32 term;
8917+
8918+
/* Check preconditions */
8919+
Assert(var1ndigits >= 1);
8920+
Assert(var1ndigits <= 4);
8921+
Assert(var2ndigits >= var1ndigits);
8922+
8923+
/*
8924+
* Determine the result sign, weight, and number of digits to calculate.
8925+
* The weight figured here is correct if the product has no leading zero
8926+
* digits; otherwise strip_var() will fix things up. Note that, unlike
8927+
* mul_var(), we do not need to allocate an extra output digit, because we
8928+
* are not rounding here.
8929+
*/
8930+
if (var1->sign == var2->sign)
8931+
res_sign = NUMERIC_POS;
8932+
else
8933+
res_sign = NUMERIC_NEG;
8934+
res_weight = var1->weight + var2->weight + 1;
8935+
res_ndigits = var1ndigits + var2ndigits;
8936+
8937+
/* Allocate result digit array */
8938+
res_buf = digitbuf_alloc(res_ndigits + 1);
8939+
res_buf[0] = 0; /* spare digit for later rounding */
8940+
res_digits = res_buf + 1;
8941+
8942+
/*
8943+
* Compute the result digits in reverse, in one pass, propagating the
8944+
* carry up as we go. The i'th result digit consists of the sum of the
8945+
* products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1.
8946+
*/
8947+
switch (var1ndigits)
8948+
{
8949+
case 1:
8950+
/* ---------
8951+
* 1-digit case:
8952+
* var1ndigits = 1
8953+
* var2ndigits >= 1
8954+
* res_ndigits = var2ndigits + 1
8955+
* ----------
8956+
*/
8957+
carry = 0;
8958+
for (int i = res_ndigits - 2; i >= 0; i--)
8959+
{
8960+
term = (uint32) var1digits[0] * var2digits[i] + carry;
8961+
res_digits[i + 1] = (NumericDigit) (term % NBASE);
8962+
carry = term / NBASE;
8963+
}
8964+
res_digits[0] = (NumericDigit) carry;
8965+
break;
8966+
8967+
case 2:
8968+
/* ---------
8969+
* 2-digit case:
8970+
* var1ndigits = 2
8971+
* var2ndigits >= 2
8972+
* res_ndigits = var2ndigits + 2
8973+
* ----------
8974+
*/
8975+
/* last result digit and carry */
8976+
term = (uint32) var1digits[1] * var2digits[res_ndigits - 3];
8977+
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
8978+
carry = term / NBASE;
8979+
8980+
/* remaining digits, except for the first two */
8981+
for (int i = res_ndigits - 3; i >= 1; i--)
8982+
{
8983+
term = (uint32) var1digits[0] * var2digits[i] +
8984+
(uint32) var1digits[1] * var2digits[i - 1] + carry;
8985+
res_digits[i + 1] = (NumericDigit) (term % NBASE);
8986+
carry = term / NBASE;
8987+
}
8988+
8989+
/* first two digits */
8990+
term = (uint32) var1digits[0] * var2digits[0] + carry;
8991+
res_digits[1] = (NumericDigit) (term % NBASE);
8992+
res_digits[0] = (NumericDigit) (term / NBASE);
8993+
break;
8994+
8995+
case 3:
8996+
/* ---------
8997+
* 3-digit case:
8998+
* var1ndigits = 3
8999+
* var2ndigits >= 3
9000+
* res_ndigits = var2ndigits + 3
9001+
* ----------
9002+
*/
9003+
/* last two result digits */
9004+
term = (uint32) var1digits[2] * var2digits[res_ndigits - 4];
9005+
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
9006+
carry = term / NBASE;
9007+
9008+
term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] +
9009+
(uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry;
9010+
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
9011+
carry = term / NBASE;
9012+
9013+
/* remaining digits, except for the first three */
9014+
for (int i = res_ndigits - 4; i >= 2; i--)
9015+
{
9016+
term = (uint32) var1digits[0] * var2digits[i] +
9017+
(uint32) var1digits[1] * var2digits[i - 1] +
9018+
(uint32) var1digits[2] * var2digits[i - 2] + carry;
9019+
res_digits[i + 1] = (NumericDigit) (term % NBASE);
9020+
carry = term / NBASE;
9021+
}
9022+
9023+
/* first three digits */
9024+
term = (uint32) var1digits[0] * var2digits[1] +
9025+
(uint32) var1digits[1] * var2digits[0] + carry;
9026+
res_digits[2] = (NumericDigit) (term % NBASE);
9027+
carry = term / NBASE;
9028+
9029+
term = (uint32) var1digits[0] * var2digits[0] + carry;
9030+
res_digits[1] = (NumericDigit) (term % NBASE);
9031+
res_digits[0] = (NumericDigit) (term / NBASE);
9032+
break;
9033+
9034+
case 4:
9035+
/* ---------
9036+
* 4-digit case:
9037+
* var1ndigits = 4
9038+
* var2ndigits >= 4
9039+
* res_ndigits = var2ndigits + 4
9040+
* ----------
9041+
*/
9042+
/* last three result digits */
9043+
term = (uint32) var1digits[3] * var2digits[res_ndigits - 5];
9044+
res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE);
9045+
carry = term / NBASE;
9046+
9047+
term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] +
9048+
(uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry;
9049+
res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
9050+
carry = term / NBASE;
9051+
9052+
term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
9053+
(uint32) var1digits[2] * var2digits[res_ndigits - 6] +
9054+
(uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry;
9055+
res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
9056+
carry = term / NBASE;
9057+
9058+
/* remaining digits, except for the first four */
9059+
for (int i = res_ndigits - 5; i >= 3; i--)
9060+
{
9061+
term = (uint32) var1digits[0] * var2digits[i] +
9062+
(uint32) var1digits[1] * var2digits[i - 1] +
9063+
(uint32) var1digits[2] * var2digits[i - 2] +
9064+
(uint32) var1digits[3] * var2digits[i - 3] + carry;
9065+
res_digits[i + 1] = (NumericDigit) (term % NBASE);
9066+
carry = term / NBASE;
9067+
}
9068+
9069+
/* first four digits */
9070+
term = (uint32) var1digits[0] * var2digits[2] +
9071+
(uint32) var1digits[1] * var2digits[1] +
9072+
(uint32) var1digits[2] * var2digits[0] + carry;
9073+
res_digits[3] = (NumericDigit) (term % NBASE);
9074+
carry = term / NBASE;
9075+
9076+
term = (uint32) var1digits[0] * var2digits[1] +
9077+
(uint32) var1digits[1] * var2digits[0] + carry;
9078+
res_digits[2] = (NumericDigit) (term % NBASE);
9079+
carry = term / NBASE;
9080+
9081+
term = (uint32) var1digits[0] * var2digits[0] + carry;
9082+
res_digits[1] = (NumericDigit) (term % NBASE);
9083+
res_digits[0] = (NumericDigit) (term / NBASE);
9084+
break;
9085+
}
9086+
9087+
/* Store the product in result */
9088+
digitbuf_free(result->buf);
9089+
result->ndigits = res_ndigits;
9090+
result->buf = res_buf;
9091+
result->digits = res_digits;
9092+
result->weight = res_weight;
9093+
result->sign = res_sign;
9094+
result->dscale = var1->dscale + var2->dscale;
9095+
9096+
/* Strip leading and trailing zeroes */
9097+
strip_var(result);
9098+
}
9099+
9100+
88839101
/*
88849102
* div_var() -
88859103
*

0 commit comments

Comments
 (0)