rust_decimal/ops/
mul.rs

1use crate::constants::{BIG_POWERS_10, MAX_I64_SCALE, U32_MAX};
2use crate::decimal::{CalculationResult, Decimal};
3use crate::ops::common::Buf24;
4
5pub(crate) fn mul_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult {
6    if d1.is_zero() || d2.is_zero() {
7        // We should think about this - does zero need to maintain precision? This treats it like
8        // an absolute which I think is ok, especially since we have is_zero() functions etc.
9        return CalculationResult::Ok(Decimal::ZERO);
10    }
11
12    let mut scale = d1.scale() + d2.scale();
13    let negative = d1.is_sign_negative() ^ d2.is_sign_negative();
14    let mut product = Buf24::zero();
15
16    // See if we can optimize this calculation depending on whether the hi bits are set
17    if d1.hi() | d1.mid() == 0 {
18        if d2.hi() | d2.mid() == 0 {
19            // We're multiplying two 32 bit integers, so we can take some liberties to optimize this.
20            let mut low64 = d1.lo() as u64 * d2.lo() as u64;
21            if scale > Decimal::MAX_SCALE {
22                // We've exceeded maximum scale so we need to start reducing the precision (aka
23                // rounding) until we have something that fits.
24                // If we're too big then we effectively round to zero.
25                if scale > Decimal::MAX_SCALE + MAX_I64_SCALE {
26                    return CalculationResult::Ok(Decimal::ZERO);
27                }
28
29                scale -= Decimal::MAX_SCALE + 1;
30                let mut power = BIG_POWERS_10[scale as usize];
31
32                let tmp = low64 / power;
33                let remainder = low64 - tmp * power;
34                low64 = tmp;
35
36                // Round the result. Since the divisor was a power of 10, it's always even.
37                power >>= 1;
38                if remainder >= power && (remainder > power || (low64 as u32 & 1) > 0) {
39                    low64 += 1;
40                }
41
42                scale = Decimal::MAX_SCALE;
43            }
44
45            // Early exit
46            return CalculationResult::Ok(Decimal::from_parts(
47                low64 as u32,
48                (low64 >> 32) as u32,
49                0,
50                negative,
51                scale,
52            ));
53        }
54
55        // We know that the left hand side is just 32 bits but the right hand side is either
56        // 64 or 96 bits.
57        mul_by_32bit_lhs(d1.lo() as u64, d2, &mut product);
58    } else if d2.mid() | d2.hi() == 0 {
59        // We know that the right hand side is just 32 bits.
60        mul_by_32bit_lhs(d2.lo() as u64, d1, &mut product);
61    } else {
62        // We know we're not dealing with simple 32 bit operands on either side.
63        // We compute and accumulate the 9 partial products using long multiplication
64
65        // 1: ll * rl
66        let mut tmp = d1.lo() as u64 * d2.lo() as u64;
67        product.data[0] = tmp as u32;
68
69        // 2: ll * rm
70        let mut tmp2 = (d1.lo() as u64 * d2.mid() as u64).wrapping_add(tmp >> 32);
71
72        // 3: lm * rl
73        tmp = d1.mid() as u64 * d2.lo() as u64;
74        tmp = tmp.wrapping_add(tmp2);
75        product.data[1] = tmp as u32;
76
77        // Detect if carry happened from the wrapping add
78        if tmp < tmp2 {
79            tmp2 = (tmp >> 32) | (1u64 << 32);
80        } else {
81            tmp2 = tmp >> 32;
82        }
83
84        // 4: lm * rm
85        tmp = (d1.mid() as u64 * d2.mid() as u64) + tmp2;
86
87        // If the high bit isn't set then we can stop here. Otherwise, we need to continue calculating
88        // using the high bits.
89        if (d1.hi() | d2.hi()) > 0 {
90            // 5. ll * rh
91            tmp2 = d1.lo() as u64 * d2.hi() as u64;
92            tmp = tmp.wrapping_add(tmp2);
93            // Detect if we carried
94            let mut tmp3 = if tmp < tmp2 { 1 } else { 0 };
95
96            // 6. lh * rl
97            tmp2 = d1.hi() as u64 * d2.lo() as u64;
98            tmp = tmp.wrapping_add(tmp2);
99            product.data[2] = tmp as u32;
100            // Detect if we carried
101            if tmp < tmp2 {
102                tmp3 += 1;
103            }
104            tmp2 = (tmp3 << 32) | (tmp >> 32);
105
106            // 7. lm * rh
107            tmp = d1.mid() as u64 * d2.hi() as u64;
108            tmp = tmp.wrapping_add(tmp2);
109            // Check for carry
110            tmp3 = if tmp < tmp2 { 1 } else { 0 };
111
112            // 8. lh * rm
113            tmp2 = d1.hi() as u64 * d2.mid() as u64;
114            tmp = tmp.wrapping_add(tmp2);
115            product.data[3] = tmp as u32;
116            // Check for carry
117            if tmp < tmp2 {
118                tmp3 += 1;
119            }
120            tmp = (tmp3 << 32) | (tmp >> 32);
121
122            // 9. lh * rh
123            product.set_high64(d1.hi() as u64 * d2.hi() as u64 + tmp);
124        } else {
125            product.set_mid64(tmp);
126        }
127    }
128
129    // We may want to "rescale". This is the case if the mantissa is > 96 bits or if the scale
130    // exceeds the maximum precision.
131    let upper_word = product.upper_word();
132    if upper_word > 2 || scale > Decimal::MAX_SCALE {
133        scale = if let Some(new_scale) = product.rescale(upper_word, scale) {
134            new_scale
135        } else {
136            return CalculationResult::Overflow;
137        }
138    }
139
140    CalculationResult::Ok(Decimal::from_parts(
141        product.data[0],
142        product.data[1],
143        product.data[2],
144        negative,
145        scale,
146    ))
147}
148
149#[inline(always)]
150fn mul_by_32bit_lhs(d1: u64, d2: &Decimal, product: &mut Buf24) {
151    let mut tmp = d1 * d2.lo() as u64;
152    product.data[0] = tmp as u32;
153    tmp = (d1 * d2.mid() as u64).wrapping_add(tmp >> 32);
154    product.data[1] = tmp as u32;
155    tmp >>= 32;
156
157    // If we're multiplying by a 96 bit integer then continue the calculation
158    if d2.hi() > 0 {
159        tmp = tmp.wrapping_add(d1 * d2.hi() as u64);
160        if tmp > U32_MAX {
161            product.set_mid64(tmp);
162        } else {
163            product.data[2] = tmp as u32;
164        }
165    } else {
166        product.data[2] = tmp as u32;
167    }
168}