rust_decimal/ops/
add.rs

1use crate::constants::{MAX_I32_SCALE, POWERS_10, SCALE_MASK, SCALE_SHIFT, SIGN_MASK, U32_MASK, U32_MAX};
2use crate::decimal::{CalculationResult, Decimal};
3use crate::ops::common::{Buf24, Dec64};
4
5pub(crate) fn add_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult {
6    add_sub_internal(d1, d2, false)
7}
8
9pub(crate) fn sub_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult {
10    add_sub_internal(d1, d2, true)
11}
12
13#[inline]
14fn add_sub_internal(d1: &Decimal, d2: &Decimal, subtract: bool) -> CalculationResult {
15    if d1.is_zero() {
16        // 0 - x or 0 + x
17        let mut result = *d2;
18        if subtract && !d2.is_zero() {
19            result.set_sign_negative(d2.is_sign_positive());
20        }
21        return CalculationResult::Ok(result);
22    }
23    if d2.is_zero() {
24        // x - 0 or x + 0
25        return CalculationResult::Ok(*d1);
26    }
27
28    // Work out whether we need to rescale and/or if it's a subtract still given the signs of the
29    // numbers.
30    let flags = d1.flags() ^ d2.flags();
31    let subtract = subtract ^ ((flags & SIGN_MASK) != 0);
32    let rescale = (flags & SCALE_MASK) > 0;
33
34    // We optimize towards using 32 bit logic as much as possible. It's noticeably faster at
35    // scale, even on 64 bit machines
36    if d1.mid() | d1.hi() == 0 && d2.mid() | d2.hi() == 0 {
37        // We'll try to rescale, however we may end up with 64 bit (or more) numbers
38        // If we do, we'll choose a different flow than fast_add
39        if rescale {
40            // This is less optimized if we scale to a 64 bit integer. We can add some further logic
41            // here later on.
42            let rescale_factor = ((d2.flags() & SCALE_MASK) as i32 - (d1.flags() & SCALE_MASK) as i32) >> SCALE_SHIFT;
43            if rescale_factor < 0 {
44                // We try to rescale the rhs
45                if let Some(rescaled) = rescale32(d2.lo(), -rescale_factor) {
46                    return fast_add(d1.lo(), rescaled, d1.flags(), subtract);
47                }
48            } else {
49                // We try to rescale the lhs
50                if let Some(rescaled) = rescale32(d1.lo(), rescale_factor) {
51                    return fast_add(
52                        rescaled,
53                        d2.lo(),
54                        (d2.flags() & SCALE_MASK) | (d1.flags() & SIGN_MASK),
55                        subtract,
56                    );
57                }
58            }
59        } else {
60            return fast_add(d1.lo(), d2.lo(), d1.flags(), subtract);
61        }
62    }
63
64    // Continue on with the slower 64 bit method
65    let d1 = Dec64::new(d1);
66    let d2 = Dec64::new(d2);
67
68    // If we're not the same scale then make sure we're there first before starting addition
69    if rescale {
70        let rescale_factor = d2.scale as i32 - d1.scale as i32;
71        if rescale_factor < 0 {
72            let negative = subtract ^ d1.negative;
73            let scale = d1.scale;
74            unaligned_add(d2, d1, negative, scale, -rescale_factor, subtract)
75        } else {
76            let negative = d1.negative;
77            let scale = d2.scale;
78            unaligned_add(d1, d2, negative, scale, rescale_factor, subtract)
79        }
80    } else {
81        let neg = d1.negative;
82        let scale = d1.scale;
83        aligned_add(d1, d2, neg, scale, subtract)
84    }
85}
86
87#[inline(always)]
88fn rescale32(num: u32, rescale_factor: i32) -> Option<u32> {
89    if rescale_factor > MAX_I32_SCALE {
90        return None;
91    }
92    num.checked_mul(POWERS_10[rescale_factor as usize])
93}
94
95fn fast_add(lo1: u32, lo2: u32, flags: u32, subtract: bool) -> CalculationResult {
96    if subtract {
97        // Sub can't overflow because we're ensuring the bigger number always subtracts the smaller number
98        if lo1 < lo2 {
99            return CalculationResult::Ok(Decimal::from_parts_raw(lo2 - lo1, 0, 0, flags ^ SIGN_MASK));
100        }
101        return CalculationResult::Ok(Decimal::from_parts_raw(lo1 - lo2, 0, 0, flags));
102    }
103    // Add can overflow however, so we check for that explicitly
104    let lo = lo1.wrapping_add(lo2);
105    let mid = if lo < lo1 { 1 } else { 0 };
106    CalculationResult::Ok(Decimal::from_parts_raw(lo, mid, 0, flags))
107}
108
109fn aligned_add(lhs: Dec64, rhs: Dec64, negative: bool, scale: u32, subtract: bool) -> CalculationResult {
110    if subtract {
111        // Signs differ, so subtract
112        let mut result = Dec64 {
113            negative,
114            scale,
115            low64: lhs.low64.wrapping_sub(rhs.low64),
116            hi: lhs.hi.wrapping_sub(rhs.hi),
117        };
118
119        // Check for carry
120        if result.low64 > lhs.low64 {
121            result.hi = result.hi.wrapping_sub(1);
122            if result.hi >= lhs.hi {
123                flip_sign(&mut result);
124            }
125        } else if result.hi > lhs.hi {
126            flip_sign(&mut result);
127        }
128        CalculationResult::Ok(result.to_decimal())
129    } else {
130        // Signs are the same, so add
131        let mut result = Dec64 {
132            negative,
133            scale,
134            low64: lhs.low64.wrapping_add(rhs.low64),
135            hi: lhs.hi.wrapping_add(rhs.hi),
136        };
137
138        // Check for carry
139        if result.low64 < lhs.low64 {
140            result.hi = result.hi.wrapping_add(1);
141            if result.hi <= lhs.hi {
142                if result.scale == 0 {
143                    return CalculationResult::Overflow;
144                }
145                reduce_scale(&mut result);
146            }
147        } else if result.hi < lhs.hi {
148            if result.scale == 0 {
149                return CalculationResult::Overflow;
150            }
151            reduce_scale(&mut result);
152        }
153        CalculationResult::Ok(result.to_decimal())
154    }
155}
156
157fn flip_sign(result: &mut Dec64) {
158    // Bitwise not the high portion
159    result.hi = !result.hi;
160    let low64 = ((result.low64 as i64).wrapping_neg()) as u64;
161    if low64 == 0 {
162        result.hi += 1;
163    }
164    result.low64 = low64;
165    result.negative = !result.negative;
166}
167
168fn reduce_scale(result: &mut Dec64) {
169    let mut low64 = result.low64;
170    let mut hi = result.hi;
171
172    let mut num = (hi as u64) + (1u64 << 32);
173    hi = (num / 10u64) as u32;
174    num = ((num - (hi as u64) * 10u64) << 32) + (low64 >> 32);
175    let mut div = (num / 10) as u32;
176    num = ((num - (div as u64) * 10u64) << 32) + (low64 & U32_MASK);
177    low64 = (div as u64) << 32;
178    div = (num / 10u64) as u32;
179    low64 = low64.wrapping_add(div as u64);
180    let remainder = (num as u32).wrapping_sub(div.wrapping_mul(10));
181
182    // Finally, round. This is optimizing slightly toward non-rounded numbers
183    if remainder >= 5 && (remainder > 5 || (low64 & 1) > 0) {
184        low64 = low64.wrapping_add(1);
185        if low64 == 0 {
186            hi += 1;
187        }
188    }
189
190    result.low64 = low64;
191    result.hi = hi;
192    result.scale -= 1;
193}
194
195// Assumption going into this function is that the LHS is the larger number and will "absorb" the
196// smaller number.
197fn unaligned_add(
198    lhs: Dec64,
199    rhs: Dec64,
200    negative: bool,
201    scale: u32,
202    rescale_factor: i32,
203    subtract: bool,
204) -> CalculationResult {
205    let mut lhs = lhs;
206    let mut low64 = lhs.low64;
207    let mut high = lhs.hi;
208    let mut rescale_factor = rescale_factor;
209
210    // First off, we see if we can get away with scaling small amounts (or none at all)
211    if high == 0 {
212        if low64 <= U32_MAX {
213            // We know it's not zero, so we start scaling.
214            // Start with reducing the scale down for the low portion
215            while low64 <= U32_MAX {
216                if rescale_factor <= MAX_I32_SCALE {
217                    low64 *= POWERS_10[rescale_factor as usize] as u64;
218                    lhs.low64 = low64;
219                    return aligned_add(lhs, rhs, negative, scale, subtract);
220                }
221                rescale_factor -= MAX_I32_SCALE;
222                low64 *= POWERS_10[9] as u64;
223            }
224        }
225
226        // Reduce the scale for the high portion
227        while high == 0 {
228            let power = if rescale_factor <= MAX_I32_SCALE {
229                POWERS_10[rescale_factor as usize] as u64
230            } else {
231                POWERS_10[9] as u64
232            };
233
234            let tmp_low = (low64 & U32_MASK) * power;
235            let tmp_hi = (low64 >> 32) * power + (tmp_low >> 32);
236            low64 = (tmp_low & U32_MASK) + (tmp_hi << 32);
237            high = (tmp_hi >> 32) as u32;
238            rescale_factor -= MAX_I32_SCALE;
239            if rescale_factor <= 0 {
240                lhs.low64 = low64;
241                lhs.hi = high;
242                return aligned_add(lhs, rhs, negative, scale, subtract);
243            }
244        }
245    }
246
247    // See if we can get away with keeping it in the 96 bits. Otherwise, we need a buffer
248    let mut tmp64: u64;
249    loop {
250        let power = if rescale_factor <= MAX_I32_SCALE {
251            POWERS_10[rescale_factor as usize] as u64
252        } else {
253            POWERS_10[9] as u64
254        };
255
256        let tmp_low = (low64 & U32_MASK) * power;
257        tmp64 = (low64 >> 32) * power + (tmp_low >> 32);
258        low64 = (tmp_low & U32_MASK) + (tmp64 << 32);
259        tmp64 >>= 32;
260        tmp64 += (high as u64) * power;
261
262        rescale_factor -= MAX_I32_SCALE;
263
264        if tmp64 > U32_MAX || scale > Decimal::MAX_SCALE {
265            break;
266        } else {
267            high = tmp64 as u32;
268            if rescale_factor <= 0 {
269                lhs.low64 = low64;
270                lhs.hi = high;
271                return aligned_add(lhs, rhs, negative, scale, subtract);
272            }
273        }
274    }
275
276    let mut buffer = Buf24::zero();
277    buffer.set_low64(low64);
278    buffer.set_mid64(tmp64);
279
280    let mut upper_word = buffer.upper_word();
281    while rescale_factor > 0 {
282        let power = if rescale_factor <= MAX_I32_SCALE {
283            POWERS_10[rescale_factor as usize] as u64
284        } else {
285            POWERS_10[9] as u64
286        };
287        tmp64 = 0;
288        for (index, part) in buffer.data.iter_mut().enumerate() {
289            tmp64 = tmp64.wrapping_add((*part as u64) * power);
290            *part = tmp64 as u32;
291            tmp64 >>= 32;
292            if index + 1 > upper_word {
293                break;
294            }
295        }
296
297        if tmp64 & U32_MASK > 0 {
298            // Extend the result
299            upper_word += 1;
300            buffer.data[upper_word] = tmp64 as u32;
301        }
302
303        rescale_factor -= MAX_I32_SCALE;
304    }
305
306    // Do the add
307    tmp64 = buffer.low64();
308    low64 = rhs.low64;
309    let tmp_hi = buffer.data[2];
310    high = rhs.hi;
311
312    if subtract {
313        low64 = tmp64.wrapping_sub(low64);
314        high = tmp_hi.wrapping_sub(high);
315
316        // Check for carry
317        let carry = if low64 > tmp64 {
318            high = high.wrapping_sub(1);
319            high >= tmp_hi
320        } else {
321            high > tmp_hi
322        };
323
324        if carry {
325            for part in buffer.data.iter_mut().skip(3) {
326                *part = part.wrapping_sub(1);
327                if *part > 0 {
328                    break;
329                }
330            }
331
332            if buffer.data[upper_word] == 0 && upper_word < 3 {
333                return CalculationResult::Ok(Decimal::from_parts(
334                    low64 as u32,
335                    (low64 >> 32) as u32,
336                    high,
337                    negative,
338                    scale,
339                ));
340            }
341        }
342    } else {
343        low64 = low64.wrapping_add(tmp64);
344        high = high.wrapping_add(tmp_hi);
345
346        // Check for carry
347        let carry = if low64 < tmp64 {
348            high = high.wrapping_add(1);
349            high <= tmp_hi
350        } else {
351            high < tmp_hi
352        };
353
354        if carry {
355            for (index, part) in buffer.data.iter_mut().enumerate().skip(3) {
356                if upper_word < index {
357                    *part = 1;
358                    upper_word = index;
359                    break;
360                }
361                *part = part.wrapping_add(1);
362                if *part > 0 {
363                    break;
364                }
365            }
366        }
367    }
368
369    buffer.set_low64(low64);
370    buffer.data[2] = high;
371    if let Some(scale) = buffer.rescale(upper_word, scale) {
372        CalculationResult::Ok(Decimal::from_parts(
373            buffer.data[0],
374            buffer.data[1],
375            buffer.data[2],
376            negative,
377            scale,
378        ))
379    } else {
380        CalculationResult::Overflow
381    }
382}