// rust_decimal/ops/rem.rs

use core::cmp::Ordering;

use crate::constants::{MAX_I32_SCALE, MAX_SCALE_I32, POWERS_10};
use crate::decimal::{CalculationResult, Decimal};
use crate::ops::common::{Buf12, Buf16, Buf24, Dec64};

5pub(crate) fn rem_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult {
6    if d2.is_zero() {
7        return CalculationResult::DivByZero;
8    }
9    if d1.is_zero() {
10        return CalculationResult::Ok(Decimal::ZERO);
11    }
12
13    // We handle the structs a bit different here. Firstly, we ignore both the sign/scale of d2.
14    // This is because during a remainder operation we do not care about the sign of the divisor
15    // and only concern ourselves with that of the dividend.
16    let mut d1 = Dec64::new(d1);
17    let d2_scale = d2.scale();
18    let mut d2 = Buf12::from_decimal(d2);
19
20    let cmp = crate::ops::cmp::cmp_internal(
21        &d1,
22        &Dec64 {
23            negative: d1.negative,
24            scale: d2_scale,
25            hi: d2.hi(),
26            low64: d2.low64(),
27        },
28    );
29    match cmp {
30        core::cmp::Ordering::Equal => {
31            // Same numbers meaning that remainder is zero
32            return CalculationResult::Ok(Decimal::ZERO);
33        }
34        core::cmp::Ordering::Less => {
35            // d1 < d2, e.g. 1/2. This means that the result is the value of d1
36            return CalculationResult::Ok(d1.to_decimal());
37        }
38        core::cmp::Ordering::Greater => {}
39    }
40
41    // At this point we know that the dividend > divisor and that they are both non-zero.
42    let mut scale = d1.scale as i32 - d2_scale as i32;
43    if scale > 0 {
44        // Scale up the divisor
45        loop {
46            let power = if scale >= MAX_I32_SCALE {
47                POWERS_10[9]
48            } else {
49                POWERS_10[scale as usize]
50            } as u64;
51
52            let mut tmp = d2.lo() as u64 * power;
53            d2.set_lo(tmp as u32);
54            tmp >>= 32;
55            tmp = tmp.wrapping_add((d2.mid() as u64 + ((d2.hi() as u64) << 32)) * power);
56            d2.set_mid(tmp as u32);
57            d2.set_hi((tmp >> 32) as u32);
58
59            // Keep scaling if there is more to go
60            scale -= MAX_I32_SCALE;
61            if scale <= 0 {
62                break;
63            }
64        }
65        scale = 0;
66    }
67
68    loop {
69        // If the dividend is smaller than the divisor then try to scale that up first
70        if scale < 0 {
71            let mut quotient = Buf12 {
72                data: [d1.lo(), d1.mid(), d1.hi],
73            };
74            loop {
75                // Figure out how much we can scale by
76                let power_scale;
77                if let Some(u) = quotient.find_scale(MAX_SCALE_I32 + scale) {
78                    if u >= POWERS_10.len() {
79                        power_scale = 9;
80                    } else {
81                        power_scale = u;
82                    }
83                } else {
84                    return CalculationResult::Overflow;
85                };
86                if power_scale == 0 {
87                    break;
88                }
89                let power = POWERS_10[power_scale] as u64;
90                scale += power_scale as i32;
91
92                let mut tmp = quotient.data[0] as u64 * power;
93                quotient.data[0] = tmp as u32;
94                tmp >>= 32;
95                quotient.set_high64(tmp.wrapping_add(quotient.high64().wrapping_mul(power)));
96                if power_scale != 9 {
97                    break;
98                }
99                if scale >= 0 {
100                    break;
101                }
102            }
103            d1.low64 = quotient.low64();
104            d1.hi = quotient.data[2];
105            d1.scale = d2_scale;
106        }
107
108        // if the high portion is empty then return the modulus of the bottom portion
109        if d1.hi == 0 {
110            d1.low64 %= d2.low64();
111            return CalculationResult::Ok(d1.to_decimal());
112        } else if (d2.mid() | d2.hi()) == 0 {
113            let mut tmp = d1.high64();
114            tmp = ((tmp % d2.lo() as u64) << 32) | (d1.lo() as u64);
115            d1.low64 = tmp % d2.lo() as u64;
116            d1.hi = 0;
117        } else {
118            // Divisor is > 32 bits
119            return rem_full(&d1, &d2, scale);
120        }
121
122        if scale >= 0 {
123            break;
124        }
125    }
126
127    CalculationResult::Ok(d1.to_decimal())
128}
129
130fn rem_full(d1: &Dec64, d2: &Buf12, scale: i32) -> CalculationResult {
131    let mut scale = scale;
132
133    // First normalize the divisor
134    let shift = if d2.hi() == 0 {
135        d2.mid().leading_zeros()
136    } else {
137        d2.hi().leading_zeros()
138    };
139
140    let mut buffer = Buf24::zero();
141    let mut overflow = 0u32;
142    buffer.set_low64(d1.low64 << shift);
143    buffer.set_mid64(((d1.mid() as u64).wrapping_add((d1.hi as u64) << 32)) >> (32 - shift));
144    let mut upper = 3; // We start at 3 due to bit shifting
145
146    while scale < 0 {
147        let power = if -scale >= MAX_I32_SCALE {
148            POWERS_10[9]
149        } else {
150            POWERS_10[-scale as usize]
151        } as u64;
152        let mut tmp64 = buffer.data[0] as u64 * power;
153        buffer.data[0] = tmp64 as u32;
154
155        for (index, part) in buffer.data.iter_mut().enumerate().skip(1) {
156            if index > upper {
157                break;
158            }
159            tmp64 >>= 32;
160            tmp64 = tmp64.wrapping_add((*part as u64).wrapping_mul(power));
161            *part = tmp64 as u32;
162        }
163        // If we have overflow then also process that
164        if upper == 6 {
165            tmp64 >>= 32;
166            tmp64 = tmp64.wrapping_add((overflow as u64).wrapping_mul(power));
167            overflow = tmp64 as u32;
168        }
169
170        // Make sure the high bit is not set
171        if tmp64 > 0x7FFF_FFFF {
172            upper += 1;
173            if upper > 5 {
174                overflow = (tmp64 >> 32) as u32;
175            } else {
176                buffer.data[upper] = (tmp64 >> 32) as u32;
177            }
178        }
179        scale += MAX_I32_SCALE;
180    }
181
182    // TODO: Optimize slice logic
183
184    let mut tmp = Buf16::zero();
185    let divisor = d2.low64() << shift;
186    if d2.hi() == 0 {
187        // Do some division
188        if upper == 6 {
189            upper -= 1;
190
191            tmp.data = [buffer.data[4], buffer.data[5], overflow, 0];
192            tmp.partial_divide_64(divisor);
193            buffer.data[4] = tmp.data[0];
194            buffer.data[5] = tmp.data[1];
195        }
196        if upper == 5 {
197            upper -= 1;
198            tmp.data = [buffer.data[3], buffer.data[4], buffer.data[5], 0];
199            tmp.partial_divide_64(divisor);
200            buffer.data[3] = tmp.data[0];
201            buffer.data[4] = tmp.data[1];
202            buffer.data[5] = tmp.data[2];
203        }
204        if upper == 4 {
205            tmp.data = [buffer.data[2], buffer.data[3], buffer.data[4], 0];
206            tmp.partial_divide_64(divisor);
207            buffer.data[2] = tmp.data[0];
208            buffer.data[3] = tmp.data[1];
209            buffer.data[4] = tmp.data[2];
210        }
211
212        tmp.data = [buffer.data[1], buffer.data[2], buffer.data[3], 0];
213        tmp.partial_divide_64(divisor);
214        buffer.data[1] = tmp.data[0];
215        buffer.data[2] = tmp.data[1];
216        buffer.data[3] = tmp.data[2];
217
218        tmp.data = [buffer.data[0], buffer.data[1], buffer.data[2], 0];
219        tmp.partial_divide_64(divisor);
220        buffer.data[0] = tmp.data[0];
221        buffer.data[1] = tmp.data[1];
222        buffer.data[2] = tmp.data[2];
223
224        let low64 = buffer.low64() >> shift;
225        CalculationResult::Ok(Decimal::from_parts(
226            low64 as u32,
227            (low64 >> 32) as u32,
228            0,
229            d1.negative,
230            d1.scale,
231        ))
232    } else {
233        let divisor_low64 = divisor;
234        let divisor = Buf12 {
235            data: [
236                divisor_low64 as u32,
237                (divisor_low64 >> 32) as u32,
238                (((d2.mid() as u64) + ((d2.hi() as u64) << 32)) >> (32 - shift)) as u32,
239            ],
240        };
241
242        // Do some division
243        if upper == 6 {
244            upper -= 1;
245            tmp.data = [buffer.data[3], buffer.data[4], buffer.data[5], overflow];
246            tmp.partial_divide_96(&divisor);
247            buffer.data[3] = tmp.data[0];
248            buffer.data[4] = tmp.data[1];
249            buffer.data[5] = tmp.data[2];
250        }
251        if upper == 5 {
252            upper -= 1;
253            tmp.data = [buffer.data[2], buffer.data[3], buffer.data[4], buffer.data[5]];
254            tmp.partial_divide_96(&divisor);
255            buffer.data[2] = tmp.data[0];
256            buffer.data[3] = tmp.data[1];
257            buffer.data[4] = tmp.data[2];
258            buffer.data[5] = tmp.data[3];
259        }
260        if upper == 4 {
261            tmp.data = [buffer.data[1], buffer.data[2], buffer.data[3], buffer.data[4]];
262            tmp.partial_divide_96(&divisor);
263            buffer.data[1] = tmp.data[0];
264            buffer.data[2] = tmp.data[1];
265            buffer.data[3] = tmp.data[2];
266            buffer.data[4] = tmp.data[3];
267        }
268
269        tmp.data = [buffer.data[0], buffer.data[1], buffer.data[2], buffer.data[3]];
270        tmp.partial_divide_96(&divisor);
271        buffer.data[0] = tmp.data[0];
272        buffer.data[1] = tmp.data[1];
273        buffer.data[2] = tmp.data[2];
274        buffer.data[3] = tmp.data[3];
275
276        let low64 = (buffer.low64() >> shift) + ((buffer.data[2] as u64) << (32 - shift) << 32);
277        CalculationResult::Ok(Decimal::from_parts(
278            low64 as u32,
279            (low64 >> 32) as u32,
280            buffer.data[2] >> shift,
281            d1.negative,
282            d1.scale,
283        ))
284    }
285}