rust_decimal/
str.rs

1use crate::{
2    constants::{BYTES_TO_OVERFLOW_U64, MAX_SCALE, MAX_STR_BUFFER_SIZE, OVERFLOW_U96, WILL_OVERFLOW_U64},
3    error::{tail_error, Error},
4    ops::array::{add_by_internal_flattened, add_one_internal, div_by_u32, is_all_zero, mul_by_u32},
5    Decimal,
6};
7
8use arrayvec::{ArrayString, ArrayVec};
9
10use alloc::{string::String, vec::Vec};
11use core::fmt;
12
13// impl that doesn't allocate for serialization purposes.
14pub(crate) fn to_str_internal(
15    value: &Decimal,
16    append_sign: bool,
17    precision: Option<usize>,
18) -> (ArrayString<MAX_STR_BUFFER_SIZE>, Option<usize>) {
19    // Get the scale - where we need to put the decimal point
20    let scale = value.scale() as usize;
21
22    // Convert to a string and manipulate that (neg at front, inject decimal)
23    let mut chars = ArrayVec::<_, MAX_STR_BUFFER_SIZE>::new();
24    let mut working = value.mantissa_array3();
25    while !is_all_zero(&working) {
26        let remainder = div_by_u32(&mut working, 10u32);
27        chars.push(char::from(b'0' + remainder as u8));
28    }
29    while scale > chars.len() {
30        chars.push('0');
31    }
32
33    let (prec, additional) = match precision {
34        Some(prec) => {
35            let max: usize = MAX_SCALE.into();
36            if prec > max {
37                (max, Some(prec - max))
38            } else {
39                (prec, None)
40            }
41        }
42        None => (scale, None),
43    };
44
45    let len = chars.len();
46    let whole_len = len - scale;
47    let mut rep = ArrayString::new();
48    // Append the negative sign if necessary while also keeping track of the length of an "empty" string representation
49    let empty_len = if append_sign && value.is_sign_negative() {
50        rep.push('-');
51        1
52    } else {
53        0
54    };
55    for i in 0..whole_len + prec {
56        if i == len - scale {
57            if i == 0 {
58                rep.push('0');
59            }
60            rep.push('.');
61        }
62
63        if i >= len {
64            rep.push('0');
65        } else {
66            let c = chars[len - i - 1];
67            rep.push(c);
68        }
69    }
70
71    // corner case for when we truncated everything in a low fractional
72    if rep.len() == empty_len {
73        rep.push('0');
74    }
75
76    (rep, additional)
77}
78
79pub(crate) fn fmt_scientific_notation(
80    value: &Decimal,
81    exponent_symbol: &str,
82    f: &mut fmt::Formatter<'_>,
83) -> fmt::Result {
84    #[cfg(not(feature = "std"))]
85    use alloc::string::ToString;
86
87    // Get the scale - this is the e value. With multiples of 10 this may get bigger.
88    let mut exponent = -(value.scale() as isize);
89
90    // Convert the integral to a string
91    let mut chars = Vec::new();
92    let mut working = value.mantissa_array3();
93    while !is_all_zero(&working) {
94        let remainder = div_by_u32(&mut working, 10u32);
95        chars.push(char::from(b'0' + remainder as u8));
96    }
97
98    // First of all, apply scientific notation rules. That is:
99    //  1. If non-zero digit comes first, move decimal point left so that e is a positive integer
100    //  2. If decimal point comes first, move decimal point right until after the first non-zero digit
101    // Since decimal notation naturally lends itself this way, we just need to inject the decimal
102    // point in the right place and adjust the exponent accordingly.
103
104    let len = chars.len();
105    let mut rep;
106    // We either are operating with a precision specified, or on defaults. Defaults will perform "smart"
107    // reduction of precision.
108    if let Some(precision) = f.precision() {
109        if len > 1 {
110            // If we're zero precision AND it's trailing zeros then strip them
111            if precision == 0 && chars.iter().take(len - 1).all(|c| *c == '0') {
112                rep = chars.iter().skip(len - 1).collect::<String>();
113            } else {
114                // We may still be zero precision, however we aren't trailing zeros
115                if precision > 0 {
116                    chars.insert(len - 1, '.');
117                }
118                rep = chars
119                    .iter()
120                    .rev()
121                    // Add on extra zeros according to the precision. At least one, since we added a decimal place.
122                    .chain(core::iter::repeat(&'0'))
123                    .take(if precision == 0 { 1 } else { 2 + precision })
124                    .collect::<String>();
125            }
126            exponent += (len - 1) as isize;
127        } else if precision > 0 {
128            // We have precision that we want to add
129            chars.push('.');
130            rep = chars
131                .iter()
132                .chain(core::iter::repeat(&'0'))
133                .take(2 + precision)
134                .collect::<String>();
135        } else {
136            rep = chars.iter().collect::<String>();
137        }
138    } else if len > 1 {
139        // If the number is just trailing zeros then we treat it like 0 precision
140        if chars.iter().take(len - 1).all(|c| *c == '0') {
141            rep = chars.iter().skip(len - 1).collect::<String>();
142        } else {
143            // Otherwise, we need to insert a decimal place and make it a scientific number
144            chars.insert(len - 1, '.');
145            rep = chars.iter().rev().collect::<String>();
146        }
147        exponent += (len - 1) as isize;
148    } else {
149        rep = chars.iter().collect::<String>();
150    }
151
152    rep.push_str(exponent_symbol);
153    rep.push_str(&exponent.to_string());
154    f.pad_integral(value.is_sign_positive(), "", &rep)
155}
156
157// dedicated implementation for the most common case.
158#[inline]
159pub(crate) fn parse_str_radix_10(str: &str) -> Result<Decimal, Error> {
160    let bytes = str.as_bytes();
161    if bytes.len() < BYTES_TO_OVERFLOW_U64 {
162        parse_str_radix_10_dispatch::<false, true>(bytes)
163    } else {
164        parse_str_radix_10_dispatch::<true, true>(bytes)
165    }
166}
167
168#[inline]
169pub(crate) fn parse_str_radix_10_exact(str: &str) -> Result<Decimal, Error> {
170    let bytes = str.as_bytes();
171    if bytes.len() < BYTES_TO_OVERFLOW_U64 {
172        parse_str_radix_10_dispatch::<false, false>(bytes)
173    } else {
174        parse_str_radix_10_dispatch::<true, false>(bytes)
175    }
176}
177
178#[inline]
179fn parse_str_radix_10_dispatch<const BIG: bool, const ROUND: bool>(bytes: &[u8]) -> Result<Decimal, Error> {
180    match bytes {
181        [b, rest @ ..] => byte_dispatch_u64::<false, false, false, BIG, true, ROUND>(rest, 0, 0, *b),
182        [] => tail_error("Invalid decimal: empty"),
183    }
184}
185
186#[inline]
187fn overflow_64(val: u64) -> bool {
188    val >= WILL_OVERFLOW_U64
189}
190
191#[inline]
192pub fn overflow_128(val: u128) -> bool {
193    val >= OVERFLOW_U96
194}
195
196/// Dispatch the next byte:
197///
198/// * POINT - a decimal point has been seen
199/// * NEG - we've encountered a `-` and the number is negative
200/// * HAS - a digit has been encountered (when HAS is false it's invalid)
201/// * BIG - a number that uses 96 bits instead of only 64 bits
202/// * FIRST - true if it is the first byte in the string
203#[inline]
204fn dispatch_next<const POINT: bool, const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>(
205    bytes: &[u8],
206    data64: u64,
207    scale: u8,
208) -> Result<Decimal, Error> {
209    if let Some((next, bytes)) = bytes.split_first() {
210        byte_dispatch_u64::<POINT, NEG, HAS, BIG, false, ROUND>(bytes, data64, scale, *next)
211    } else {
212        handle_data::<NEG, HAS>(data64 as u128, scale)
213    }
214}
215
216/// Dispatch the next non-digit byte:
217///
218/// * POINT - a decimal point has been seen
219/// * NEG - we've encountered a `-` and the number is negative
220/// * HAS - a digit has been encountered (when HAS is false it's invalid)
221/// * BIG - a number that uses 96 bits instead of only 64 bits
222/// * FIRST - true if it is the first byte in the string
223/// * ROUND - attempt to round underflow
224#[inline(never)]
225fn non_digit_dispatch_u64<
226    const POINT: bool,
227    const NEG: bool,
228    const HAS: bool,
229    const BIG: bool,
230    const FIRST: bool,
231    const ROUND: bool,
232>(
233    bytes: &[u8],
234    data64: u64,
235    scale: u8,
236    b: u8,
237) -> Result<Decimal, Error> {
238    match b {
239        b'-' if FIRST && !HAS => dispatch_next::<false, true, false, BIG, ROUND>(bytes, data64, scale),
240        b'+' if FIRST && !HAS => dispatch_next::<false, false, false, BIG, ROUND>(bytes, data64, scale),
241        b'_' if HAS => handle_separator::<POINT, NEG, BIG, ROUND>(bytes, data64, scale),
242        b => tail_invalid_digit(b),
243    }
244}
245
246#[inline]
247fn byte_dispatch_u64<
248    const POINT: bool,
249    const NEG: bool,
250    const HAS: bool,
251    const BIG: bool,
252    const FIRST: bool,
253    const ROUND: bool,
254>(
255    bytes: &[u8],
256    data64: u64,
257    scale: u8,
258    b: u8,
259) -> Result<Decimal, Error> {
260    match b {
261        b'0'..=b'9' => handle_digit_64::<POINT, NEG, BIG, ROUND>(bytes, data64, scale, b - b'0'),
262        b'.' if !POINT => handle_point::<NEG, HAS, BIG, ROUND>(bytes, data64, scale),
263        b => non_digit_dispatch_u64::<POINT, NEG, HAS, BIG, FIRST, ROUND>(bytes, data64, scale, b),
264    }
265}
266
267#[inline(never)]
268fn handle_digit_64<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>(
269    bytes: &[u8],
270    data64: u64,
271    scale: u8,
272    digit: u8,
273) -> Result<Decimal, Error> {
274    // we have already validated that we cannot overflow
275    let data64 = data64 * 10 + digit as u64;
276    let scale = if POINT { scale + 1 } else { 0 };
277
278    if let Some((next, bytes)) = bytes.split_first() {
279        let next = *next;
280        if POINT && BIG && scale >= 28 {
281            if ROUND {
282                maybe_round(data64 as u128, next, scale, POINT, NEG)
283            } else {
284                Err(Error::Underflow)
285            }
286        } else if BIG && overflow_64(data64) {
287            handle_full_128::<POINT, NEG, ROUND>(data64 as u128, bytes, scale, next)
288        } else {
289            byte_dispatch_u64::<POINT, NEG, true, BIG, false, ROUND>(bytes, data64, scale, next)
290        }
291    } else {
292        let data: u128 = data64 as u128;
293
294        handle_data::<NEG, true>(data, scale)
295    }
296}
297
298#[inline(never)]
299fn handle_point<const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>(
300    bytes: &[u8],
301    data64: u64,
302    scale: u8,
303) -> Result<Decimal, Error> {
304    dispatch_next::<true, NEG, HAS, BIG, ROUND>(bytes, data64, scale)
305}
306
307#[inline(never)]
308fn handle_separator<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>(
309    bytes: &[u8],
310    data64: u64,
311    scale: u8,
312) -> Result<Decimal, Error> {
313    dispatch_next::<POINT, NEG, true, BIG, ROUND>(bytes, data64, scale)
314}
315
316#[inline(never)]
317#[cold]
318fn tail_invalid_digit(digit: u8) -> Result<Decimal, Error> {
319    match digit {
320        b'.' => tail_error("Invalid decimal: two decimal points"),
321        b'_' => tail_error("Invalid decimal: must start lead with a number"),
322        _ => tail_error("Invalid decimal: unknown character"),
323    }
324}
325
326#[inline(never)]
327#[cold]
328fn handle_full_128<const POINT: bool, const NEG: bool, const ROUND: bool>(
329    mut data: u128,
330    bytes: &[u8],
331    scale: u8,
332    next_byte: u8,
333) -> Result<Decimal, Error> {
334    let b = next_byte;
335    match b {
336        b'0'..=b'9' => {
337            let digit = u32::from(b - b'0');
338
339            // If the data is going to overflow then we should go into recovery mode
340            let next = (data * 10) + digit as u128;
341            if overflow_128(next) {
342                if !POINT {
343                    return tail_error("Invalid decimal: overflow from too many digits");
344                }
345
346                if ROUND {
347                    maybe_round(data, next_byte, scale, POINT, NEG)
348                } else {
349                    Err(Error::Underflow)
350                }
351            } else {
352                data = next;
353                let scale = scale + POINT as u8;
354                if let Some((next, bytes)) = bytes.split_first() {
355                    let next = *next;
356                    if POINT && scale >= 28 {
357                        if ROUND {
358                            // If it is an underscore at the rounding position we require slightly different handling to look ahead another digit
359                            if next == b'_' {
360                                if let Some((next, bytes)) = bytes.split_first() {
361                                    handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, *next)
362                                } else {
363                                    handle_data::<NEG, true>(data, scale)
364                                }
365                            } else {
366                                // Otherwise, we round as usual
367                                maybe_round(data, next, scale, POINT, NEG)
368                            }
369                        } else {
370                            Err(Error::Underflow)
371                        }
372                    } else {
373                        handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, next)
374                    }
375                } else {
376                    handle_data::<NEG, true>(data, scale)
377                }
378            }
379        }
380        b'.' if !POINT => {
381            // This call won't tail?
382            if let Some((next, bytes)) = bytes.split_first() {
383                handle_full_128::<true, NEG, ROUND>(data, bytes, scale, *next)
384            } else {
385                handle_data::<NEG, true>(data, scale)
386            }
387        }
388        b'_' => {
389            if let Some((next, bytes)) = bytes.split_first() {
390                handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, *next)
391            } else {
392                handle_data::<NEG, true>(data, scale)
393            }
394        }
395        b => tail_invalid_digit(b),
396    }
397}
398
399#[inline(never)]
400#[cold]
401fn maybe_round(mut data: u128, next_byte: u8, mut scale: u8, point: bool, negative: bool) -> Result<Decimal, Error> {
402    let digit = match next_byte {
403        b'0'..=b'9' => u32::from(next_byte - b'0'),
404        b'_' => 0, // This is perhaps an error case, but keep this here for compatibility
405        b'.' if !point => 0,
406        b => return tail_invalid_digit(b),
407    };
408
409    // Round at midpoint
410    if digit >= 5 {
411        data += 1;
412
413        // If the mantissa is now overflowing, round to the next
414        // next least significant digit and discard precision
415        if overflow_128(data) {
416            if scale == 0 {
417                return tail_error("Invalid decimal: overflow from mantissa after rounding");
418            }
419            data += 4;
420            data /= 10;
421            scale -= 1;
422        }
423    }
424
425    if negative {
426        handle_data::<true, true>(data, scale)
427    } else {
428        handle_data::<false, true>(data, scale)
429    }
430}
431
432#[inline(never)]
433fn tail_no_has() -> Result<Decimal, Error> {
434    tail_error("Invalid decimal: no digits found")
435}
436
437#[inline]
438fn handle_data<const NEG: bool, const HAS: bool>(data: u128, scale: u8) -> Result<Decimal, Error> {
439    debug_assert_eq!(data >> 96, 0);
440    if !HAS {
441        tail_no_has()
442    } else {
443        Ok(Decimal::from_parts(
444            data as u32,
445            (data >> 32) as u32,
446            (data >> 64) as u32,
447            NEG,
448            scale as u32,
449        ))
450    }
451}
452
453pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result<Decimal, Error> {
454    if str.is_empty() {
455        return Err(Error::from("Invalid decimal: empty"));
456    }
457    if radix < 2 {
458        return Err(Error::from("Unsupported radix < 2"));
459    }
460    if radix > 36 {
461        // As per trait documentation
462        return Err(Error::from("Unsupported radix > 36"));
463    }
464
465    let mut offset = 0;
466    let mut len = str.len();
467    let bytes = str.as_bytes();
468    let mut negative = false; // assume positive
469
470    // handle the sign
471    if bytes[offset] == b'-' {
472        negative = true; // leading minus means negative
473        offset += 1;
474        len -= 1;
475    } else if bytes[offset] == b'+' {
476        // leading + allowed
477        offset += 1;
478        len -= 1;
479    }
480
481    // should now be at numeric part of the significand
482    let mut digits_before_dot: i32 = -1; // digits before '.', -1 if no '.'
483    let mut coeff = ArrayVec::<_, 96>::new(); // integer significand array
484
485    // Supporting different radix
486    let (max_n, max_alpha_lower, max_alpha_upper) = if radix <= 10 {
487        (b'0' + (radix - 1) as u8, 0, 0)
488    } else {
489        let adj = (radix - 11) as u8;
490        (b'9', adj + b'a', adj + b'A')
491    };
492
493    // Estimate the max precision. All in all, it needs to fit into 96 bits.
494    // Rather than try to estimate, I've included the constants directly in here. We could,
495    // perhaps, replace this with a formula if it's faster - though it does appear to be log2.
496    let estimated_max_precision = match radix {
497        2 => 96,
498        3 => 61,
499        4 => 48,
500        5 => 42,
501        6 => 38,
502        7 => 35,
503        8 => 32,
504        9 => 31,
505        10 => 28,
506        11 => 28,
507        12 => 27,
508        13 => 26,
509        14 => 26,
510        15 => 25,
511        16 => 24,
512        17 => 24,
513        18 => 24,
514        19 => 23,
515        20 => 23,
516        21 => 22,
517        22 => 22,
518        23 => 22,
519        24 => 21,
520        25 => 21,
521        26 => 21,
522        27 => 21,
523        28 => 20,
524        29 => 20,
525        30 => 20,
526        31 => 20,
527        32 => 20,
528        33 => 20,
529        34 => 19,
530        35 => 19,
531        36 => 19,
532        _ => return Err(Error::from("Unsupported radix")),
533    };
534
535    let mut maybe_round = false;
536    while len > 0 {
537        let b = bytes[offset];
538        match b {
539            b'0'..=b'9' => {
540                if b > max_n {
541                    return Err(Error::from("Invalid decimal: invalid character"));
542                }
543                coeff.push(u32::from(b - b'0'));
544                offset += 1;
545                len -= 1;
546
547                // If the coefficient is longer than the max, exit early
548                if coeff.len() as u32 > estimated_max_precision {
549                    maybe_round = true;
550                    break;
551                }
552            }
553            b'a'..=b'z' => {
554                if b > max_alpha_lower {
555                    return Err(Error::from("Invalid decimal: invalid character"));
556                }
557                coeff.push(u32::from(b - b'a') + 10);
558                offset += 1;
559                len -= 1;
560
561                if coeff.len() as u32 > estimated_max_precision {
562                    maybe_round = true;
563                    break;
564                }
565            }
566            b'A'..=b'Z' => {
567                if b > max_alpha_upper {
568                    return Err(Error::from("Invalid decimal: invalid character"));
569                }
570                coeff.push(u32::from(b - b'A') + 10);
571                offset += 1;
572                len -= 1;
573
574                if coeff.len() as u32 > estimated_max_precision {
575                    maybe_round = true;
576                    break;
577                }
578            }
579            b'.' => {
580                if digits_before_dot >= 0 {
581                    return Err(Error::from("Invalid decimal: two decimal points"));
582                }
583                digits_before_dot = coeff.len() as i32;
584                offset += 1;
585                len -= 1;
586            }
587            b'_' => {
588                // Must start with a number...
589                if coeff.is_empty() {
590                    return Err(Error::from("Invalid decimal: must start lead with a number"));
591                }
592                offset += 1;
593                len -= 1;
594            }
595            _ => return Err(Error::from("Invalid decimal: unknown character")),
596        }
597    }
598
599    // If we exited before the end of the string then do some rounding if necessary
600    if maybe_round && offset < bytes.len() {
601        let next_byte = bytes[offset];
602        let digit = match next_byte {
603            b'0'..=b'9' => {
604                if next_byte > max_n {
605                    return Err(Error::from("Invalid decimal: invalid character"));
606                }
607                u32::from(next_byte - b'0')
608            }
609            b'a'..=b'z' => {
610                if next_byte > max_alpha_lower {
611                    return Err(Error::from("Invalid decimal: invalid character"));
612                }
613                u32::from(next_byte - b'a') + 10
614            }
615            b'A'..=b'Z' => {
616                if next_byte > max_alpha_upper {
617                    return Err(Error::from("Invalid decimal: invalid character"));
618                }
619                u32::from(next_byte - b'A') + 10
620            }
621            b'_' => 0,
622            b'.' => {
623                // Still an error if we have a second dp
624                if digits_before_dot >= 0 {
625                    return Err(Error::from("Invalid decimal: two decimal points"));
626                }
627                0
628            }
629            _ => return Err(Error::from("Invalid decimal: unknown character")),
630        };
631
632        // Round at midpoint
633        let midpoint = if radix & 0x1 == 1 { radix / 2 } else { (radix + 1) / 2 };
634        if digit >= midpoint {
635            let mut index = coeff.len() - 1;
636            loop {
637                let new_digit = coeff[index] + 1;
638                if new_digit <= 9 {
639                    coeff[index] = new_digit;
640                    break;
641                } else {
642                    coeff[index] = 0;
643                    if index == 0 {
644                        coeff.insert(0, 1u32);
645                        digits_before_dot += 1;
646                        coeff.pop();
647                        break;
648                    }
649                }
650                index -= 1;
651            }
652        }
653    }
654
655    // here when no characters left
656    if coeff.is_empty() {
657        return Err(Error::from("Invalid decimal: no digits found"));
658    }
659
660    let mut scale = if digits_before_dot >= 0 {
661        // we had a decimal place so set the scale
662        (coeff.len() as u32) - (digits_before_dot as u32)
663    } else {
664        0
665    };
666
667    // Parse this using specified radix
668    let mut data = [0u32, 0u32, 0u32];
669    let mut tmp = [0u32, 0u32, 0u32];
670    let len = coeff.len();
671    for (i, digit) in coeff.iter().enumerate() {
672        // If the data is going to overflow then we should go into recovery mode
673        tmp[0] = data[0];
674        tmp[1] = data[1];
675        tmp[2] = data[2];
676        let overflow = mul_by_u32(&mut tmp, radix);
677        if overflow > 0 {
678            // This means that we have more data to process, that we're not sure what to do with.
679            // This may or may not be an issue - depending on whether we're past a decimal point
680            // or not.
681            if (i as i32) < digits_before_dot && i + 1 < len {
682                return Err(Error::from("Invalid decimal: overflow from too many digits"));
683            }
684
685            if *digit >= 5 {
686                let carry = add_one_internal(&mut data);
687                if carry > 0 {
688                    // Highly unlikely scenario which is more indicative of a bug
689                    return Err(Error::from("Invalid decimal: overflow when rounding"));
690                }
691            }
692            // We're also one less digit so reduce the scale
693            let diff = (len - i) as u32;
694            if diff > scale {
695                return Err(Error::from("Invalid decimal: overflow from scale mismatch"));
696            }
697            scale -= diff;
698            break;
699        } else {
700            data[0] = tmp[0];
701            data[1] = tmp[1];
702            data[2] = tmp[2];
703            let carry = add_by_internal_flattened(&mut data, *digit);
704            if carry > 0 {
705                // Highly unlikely scenario which is more indicative of a bug
706                return Err(Error::from("Invalid decimal: overflow from carry"));
707            }
708        }
709    }
710
711    Ok(Decimal::from_parts(data[0], data[1], data[2], negative, scale))
712}
713
714#[cfg(test)]
715mod test {
716    use super::*;
717    use crate::Decimal;
718    use arrayvec::ArrayString;
719    use core::{fmt::Write, str::FromStr};
720
721    #[test]
722    fn display_does_not_overflow_max_capacity() {
723        let num = Decimal::from_str("1.2").unwrap();
724        let mut buffer = ArrayString::<64>::new();
725        buffer.write_fmt(format_args!("{num:.31}")).unwrap();
726        assert_eq!("1.2000000000000000000000000000000", buffer.as_str());
727    }
728
729    #[test]
730    fn from_str_rounding_0() {
731        assert_eq!(
732            parse_str_radix_10("1.234").unwrap().unpack(),
733            Decimal::new(1234, 3).unpack()
734        );
735    }
736
737    #[test]
738    fn from_str_rounding_1() {
739        assert_eq!(
740            parse_str_radix_10("11111_11111_11111.11111_11111_11111")
741                .unwrap()
742                .unpack(),
743            Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_111, 14).unpack()
744        );
745    }
746
747    #[test]
748    fn from_str_rounding_2() {
749        assert_eq!(
750            parse_str_radix_10("11111_11111_11111.11111_11111_11115")
751                .unwrap()
752                .unpack(),
753            Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_112, 14).unpack()
754        );
755    }
756
757    #[test]
758    fn from_str_rounding_3() {
759        assert_eq!(
760            parse_str_radix_10("11111_11111_11111.11111_11111_11195")
761                .unwrap()
762                .unpack(),
763            Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_1120, 14).unpack() // was Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_112, 13)
764        );
765    }
766
767    #[test]
768    fn from_str_rounding_4() {
769        assert_eq!(
770            parse_str_radix_10("99999_99999_99999.99999_99999_99995")
771                .unwrap()
772                .unpack(),
773            Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 13).unpack() // was Decimal::from_i128_with_scale(1_000_000_000_000_000_000_000_000_000, 12)
774        );
775    }
776
777    #[test]
778    fn from_str_no_rounding_0() {
779        assert_eq!(
780            parse_str_radix_10_exact("1.234").unwrap().unpack(),
781            Decimal::new(1234, 3).unpack()
782        );
783    }
784
785    #[test]
786    fn from_str_no_rounding_1() {
787        assert_eq!(
788            parse_str_radix_10_exact("11111_11111_11111.11111_11111_11111"),
789            Err(Error::Underflow)
790        );
791    }
792
793    #[test]
794    fn from_str_no_rounding_2() {
795        assert_eq!(
796            parse_str_radix_10_exact("11111_11111_11111.11111_11111_11115"),
797            Err(Error::Underflow)
798        );
799    }
800
801    #[test]
802    fn from_str_no_rounding_3() {
803        assert_eq!(
804            parse_str_radix_10_exact("11111_11111_11111.11111_11111_11195"),
805            Err(Error::Underflow)
806        );
807    }
808
809    #[test]
810    fn from_str_no_rounding_4() {
811        assert_eq!(
812            parse_str_radix_10_exact("99999_99999_99999.99999_99999_99995"),
813            Err(Error::Underflow)
814        );
815    }
816
817    #[test]
818    fn from_str_many_pointless_chars() {
819        assert_eq!(
820            parse_str_radix_10("00________________________________________________________________001.1")
821                .unwrap()
822                .unpack(),
823            Decimal::from_i128_with_scale(11, 1).unpack()
824        );
825    }
826
827    #[test]
828    fn from_str_leading_0s_1() {
829        assert_eq!(
830            parse_str_radix_10("00001.1").unwrap().unpack(),
831            Decimal::from_i128_with_scale(11, 1).unpack()
832        );
833    }
834
835    #[test]
836    fn from_str_leading_0s_2() {
837        assert_eq!(
838            parse_str_radix_10("00000_00000_00000_00000_00001.00001")
839                .unwrap()
840                .unpack(),
841            Decimal::from_i128_with_scale(100001, 5).unpack()
842        );
843    }
844
845    #[test]
846    fn from_str_leading_0s_3() {
847        assert_eq!(
848            parse_str_radix_10("0.00000_00000_00000_00000_00000_00100")
849                .unwrap()
850                .unpack(),
851            Decimal::from_i128_with_scale(1, 28).unpack()
852        );
853    }
854
855    #[test]
856    fn from_str_trailing_0s_1() {
857        assert_eq!(
858            parse_str_radix_10("0.00001_00000_00000").unwrap().unpack(),
859            Decimal::from_i128_with_scale(10_000_000_000, 15).unpack()
860        );
861    }
862
863    #[test]
864    fn from_str_trailing_0s_2() {
865        assert_eq!(
866            parse_str_radix_10("0.00001_00000_00000_00000_00000_00000")
867                .unwrap()
868                .unpack(),
869            Decimal::from_i128_with_scale(100_000_000_000_000_000_000_000, 28).unpack()
870        );
871    }
872
873    #[test]
874    fn from_str_overflow_1() {
875        assert_eq!(
876            parse_str_radix_10("99999_99999_99999_99999_99999_99999.99999"),
877            // The original implementation returned
878            //              Ok(10000_00000_00000_00000_00000_0000)
879            // Which is a bug!
880            Err(Error::from("Invalid decimal: overflow from too many digits"))
881        );
882    }
883
884    #[test]
885    fn from_str_overflow_2() {
886        assert!(
887            parse_str_radix_10("99999_99999_99999_99999_99999_11111.11111").is_err(),
888            // The original implementation is 'overflow from scale mismatch'
889            // but we got rid of that now
890        );
891    }
892
893    #[test]
894    fn from_str_overflow_3() {
895        assert!(
896            parse_str_radix_10("99999_99999_99999_99999_99999_99994").is_err() // We could not get into 'overflow when rounding' or 'overflow from carry'
897                                                                               // in the original implementation because the rounding logic before prevented it
898        );
899    }
900
901    #[test]
902    fn from_str_overflow_4() {
903        assert_eq!(
904            // This does not overflow, moving the decimal point 1 more step would result in
905            // 'overflow from too many digits'
906            parse_str_radix_10("99999_99999_99999_99999_99999_999.99")
907                .unwrap()
908                .unpack(),
909            Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 0).unpack()
910        );
911    }
912
913    #[test]
914    fn from_str_mantissa_overflow_1() {
915        // reminder:
916        assert_eq!(OVERFLOW_U96, 79_228_162_514_264_337_593_543_950_336);
917        assert_eq!(
918            parse_str_radix_10("79_228_162_514_264_337_593_543_950_33.56")
919                .unwrap()
920                .unpack(),
921            Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 0).unpack()
922        );
923        // This is a mantissa of OVERFLOW_U96 - 1 just before reaching the last digit.
924        // Previously, this would return Err("overflow from mantissa after rounding")
925        // instead of successfully rounding.
926    }
927
928    #[test]
929    fn from_str_mantissa_overflow_2() {
930        assert_eq!(
931            parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.6"),
932            Err(Error::from("Invalid decimal: overflow from mantissa after rounding"))
933        );
934        // this case wants to round to 79_228_162_514_264_337_593_543_950_340.
935        // (79_228_162_514_264_337_593_543_950_336 is OVERFLOW_U96 and too large
936        // to fit in 96 bits) which is also too large for the mantissa so fails.
937    }
938
939    #[test]
940    fn from_str_mantissa_overflow_3() {
941        // this hits the other avoidable overflow case in maybe_round
942        assert_eq!(
943            parse_str_radix_10("7.92281625142643375935439503356").unwrap().unpack(),
944            Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack()
945        );
946    }
947
948    #[test]
949    fn from_str_mantissa_overflow_4() {
950        // Same test as above, however with underscores. This causes issues.
951        assert_eq!(
952            parse_str_radix_10("7.9_228_162_514_264_337_593_543_950_335_6")
953                .unwrap()
954                .unpack(),
955            Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack()
956        );
957    }
958
959    #[test]
960    fn invalid_input_1() {
961        assert_eq!(
962            parse_str_radix_10("1.0000000000000000000000000000.5"),
963            Err(Error::from("Invalid decimal: two decimal points"))
964        );
965    }
966
967    #[test]
968    fn invalid_input_2() {
969        assert_eq!(
970            parse_str_radix_10("1.0.5"),
971            Err(Error::from("Invalid decimal: two decimal points"))
972        );
973    }
974
975    #[test]
976    fn character_at_rounding_position() {
977        let tests = [
978            // digit is at the rounding position
979            (
980                "1.000_000_000_000_000_000_000_000_000_04",
981                Ok(Decimal::from_i128_with_scale(
982                    1_000_000_000_000_000_000_000_000_000_0,
983                    28,
984                )),
985            ),
986            (
987                "1.000_000_000_000_000_000_000_000_000_06",
988                Ok(Decimal::from_i128_with_scale(
989                    1_000_000_000_000_000_000_000_000_000_1,
990                    28,
991                )),
992            ),
993            // Decimal point is at the rounding position
994            (
995                "1_000_000_000_000_000_000_000_000_000_0.4",
996                Ok(Decimal::from_i128_with_scale(
997                    1_000_000_000_000_000_000_000_000_000_0,
998                    0,
999                )),
1000            ),
1001            (
1002                "1_000_000_000_000_000_000_000_000_000_0.6",
1003                Ok(Decimal::from_i128_with_scale(
1004                    1_000_000_000_000_000_000_000_000_000_1,
1005                    0,
1006                )),
1007            ),
1008            // Placeholder is at the rounding position
1009            (
1010                "1.000_000_000_000_000_000_000_000_000_0_4",
1011                Ok(Decimal::from_i128_with_scale(
1012                    1_000_000_000_000_000_000_000_000_000_0,
1013                    28,
1014                )),
1015            ),
1016            (
1017                "1.000_000_000_000_000_000_000_000_000_0_6",
1018                Ok(Decimal::from_i128_with_scale(
1019                    1_000_000_000_000_000_000_000_000_000_1,
1020                    28,
1021                )),
1022            ),
1023            // Multiple placeholders at rounding position
1024            (
1025                "1.000_000_000_000_000_000_000_000_000_0__4",
1026                Ok(Decimal::from_i128_with_scale(
1027                    1_000_000_000_000_000_000_000_000_000_0,
1028                    28,
1029                )),
1030            ),
1031            (
1032                "1.000_000_000_000_000_000_000_000_000_0__6",
1033                Ok(Decimal::from_i128_with_scale(
1034                    1_000_000_000_000_000_000_000_000_000_1,
1035                    28,
1036                )),
1037            ),
1038        ];
1039
1040        for (input, expected) in tests.iter() {
1041            assert_eq!(parse_str_radix_10(input), *expected, "Test input {}", input);
1042        }
1043    }
1044
1045    #[test]
1046    fn from_str_edge_cases_1() {
1047        assert_eq!(parse_str_radix_10(""), Err(Error::from("Invalid decimal: empty")));
1048    }
1049
1050    #[test]
1051    fn from_str_edge_cases_2() {
1052        assert_eq!(
1053            parse_str_radix_10("0.1."),
1054            Err(Error::from("Invalid decimal: two decimal points"))
1055        );
1056    }
1057
1058    #[test]
1059    fn from_str_edge_cases_3() {
1060        assert_eq!(
1061            parse_str_radix_10("_"),
1062            Err(Error::from("Invalid decimal: must start lead with a number"))
1063        );
1064    }
1065
1066    #[test]
1067    fn from_str_edge_cases_4() {
1068        assert_eq!(
1069            parse_str_radix_10("1?2"),
1070            Err(Error::from("Invalid decimal: unknown character"))
1071        );
1072    }
1073
1074    #[test]
1075    fn from_str_edge_cases_5() {
1076        assert_eq!(
1077            parse_str_radix_10("."),
1078            Err(Error::from("Invalid decimal: no digits found"))
1079        );
1080    }
1081
1082    #[test]
1083    fn from_str_edge_cases_6() {
1084        // Decimal::MAX + 0.99999
1085        assert_eq!(
1086            parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.99999"),
1087            Err(Error::from("Invalid decimal: overflow from mantissa after rounding"))
1088        );
1089    }
1090}