tinc/private/
float_with_non_finite.rs

1use core::fmt;
2use std::collections::{BTreeMap, HashMap};
3use std::fmt::Display;
4use std::marker::PhantomData;
5
6use num_traits::{Float, FromPrimitive, ToPrimitive};
7use serde::Serialize;
8use serde::de::Error;
9
10use super::{DeserializeContent, DeserializeHelper, Expected, Tracker, TrackerDeserializer, TrackerFor};
11
/// Tracker associated (via `TrackerFor`) with [`FloatWithNonFinite<T>`].
///
/// It holds no data; `T` is recorded only at the type level.
pub struct FloatWithNonFinTracker<T>(PhantomData<T>);

impl<T> Default for FloatWithNonFinTracker<T> {
    /// Builds the stateless tracker.
    ///
    /// Written by hand because `#[derive(Default)]` would add a `T: Default` bound.
    fn default() -> Self {
        FloatWithNonFinTracker(PhantomData::default())
    }
}

impl<T> fmt::Debug for FloatWithNonFinTracker<T> {
    /// Prints `FloatWithNonFinTracker<` + the type name of `T` + `>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FloatWithNonFinTracker<")?;
        f.write_str(std::any::type_name::<T>())?;
        f.write_str(">")
    }
}
25
26impl<T: Expected> Tracker for FloatWithNonFinTracker<T> {
27    type Target = T;
28
29    #[inline(always)]
30    fn allow_duplicates(&self) -> bool {
31        false
32    }
33}
34
/// Wrapper around a float `T` whose serde support also handles the non-finite
/// values, spelled as the strings `"NaN"`, `"Infinity"` and `"-Infinity"`.
///
/// `#[repr(transparent)]` is load-bearing: the `FloatWithNonFinSerHelper` impls
/// reinterpret `&f32` / `&f64` as references to this wrapper.
#[repr(transparent)]
pub struct FloatWithNonFinite<T>(T);

impl<T> Default for FloatWithNonFinite<T>
where
    T: Default,
{
    /// Wraps `T::default()`.
    fn default() -> Self {
        FloatWithNonFinite(T::default())
    }
}
43
44impl<T: Expected> TrackerFor for FloatWithNonFinite<T> {
45    type Tracker = FloatWithNonFinTracker<T>;
46}
47
48impl<T> Expected for FloatWithNonFinite<T> {
49    fn expecting(formatter: &mut fmt::Formatter) -> fmt::Result {
50        write!(formatter, stringify!(T))
51    }
52}
53
54// Deserialization
55
56pub trait FloatWithNonFinDesHelper: Sized {
57    type Target;
58}
59
60impl FloatWithNonFinDesHelper for f32 {
61    type Target = FloatWithNonFinite<f32>;
62}
63
64impl FloatWithNonFinDesHelper for f64 {
65    type Target = FloatWithNonFinite<f64>;
66}
67
68impl<T: FloatWithNonFinDesHelper> FloatWithNonFinDesHelper for Option<T> {
69    type Target = Option<T::Target>;
70}
71
72impl<T: FloatWithNonFinDesHelper> FloatWithNonFinDesHelper for Vec<T> {
73    type Target = Vec<T::Target>;
74}
75
76impl<K, V: FloatWithNonFinDesHelper> FloatWithNonFinDesHelper for BTreeMap<K, V> {
77    type Target = BTreeMap<K, V::Target>;
78}
79
80impl<K, V: FloatWithNonFinDesHelper, S> FloatWithNonFinDesHelper for HashMap<K, V, S> {
81    type Target = HashMap<K, V::Target, S>;
82}
83
84impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, FloatWithNonFinTracker<T>>
85where
86    T: serde::Deserialize<'de> + Float + ToPrimitive + FromPrimitive,
87    FloatWithNonFinTracker<T>: Tracker<Target = T>,
88{
89    type Value = ();
90
91    fn deserialize<D>(self, de: D) -> Result<Self::Value, D::Error>
92    where
93        D: serde::Deserializer<'de>,
94    {
95        struct Visitor<T>(PhantomData<T>);
96
97        impl<T> Default for Visitor<T> {
98            fn default() -> Self {
99                Self(PhantomData)
100            }
101        }
102
103        macro_rules! visit_convert_to_float {
104            ($visitor_func:ident, $conv_func:ident, $ty:ident) => {
105                fn $visitor_func<E>(self, v: $ty) -> Result<Self::Value, E>
106                where
107                    E: Error,
108                {
109                    match T::$conv_func(v) {
110                        Some(v) => Ok(v),
111                        None => Err(E::custom(format!("unable to extract float-type from {}", v))),
112                    }
113                }
114            };
115        }
116
117        impl<'de, T> serde::de::Visitor<'de> for Visitor<T>
118        where
119            T: serde::Deserialize<'de> + Float + ToPrimitive + FromPrimitive,
120        {
121            type Value = T;
122
123            visit_convert_to_float!(visit_f32, from_f32, f32);
124
125            visit_convert_to_float!(visit_f64, from_f64, f64);
126
127            visit_convert_to_float!(visit_u8, from_u8, u8);
128
129            visit_convert_to_float!(visit_u16, from_u16, u16);
130
131            visit_convert_to_float!(visit_u32, from_u32, u32);
132
133            visit_convert_to_float!(visit_u64, from_u64, u64);
134
135            visit_convert_to_float!(visit_i8, from_i8, i8);
136
137            visit_convert_to_float!(visit_i16, from_i16, i16);
138
139            visit_convert_to_float!(visit_i32, from_i32, i32);
140
141            visit_convert_to_float!(visit_i64, from_i64, i64);
142
143            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
144                write!(formatter, stringify!(T))
145            }
146
147            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
148            where
149                E: Error,
150            {
151                match v {
152                    "Infinity" => Ok(T::infinity()),
153                    "-Infinity" => Ok(T::neg_infinity()),
154                    "NaN" => Ok(T::nan()),
155                    _ => Err(E::custom(format!("unrecognized floating string: {}", v))),
156                }
157            }
158
159            fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
160            where
161                E: Error,
162            {
163                self.visit_str(v)
164            }
165
166            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
167            where
168                E: Error,
169            {
170                self.visit_str(&v)
171            }
172        }
173
174        *self.value = de.deserialize_any(Visitor::default())?;
175        Ok(())
176    }
177}
178
179impl<'de, T> TrackerDeserializer<'de> for FloatWithNonFinTracker<T>
180where
181    T: serde::Deserialize<'de> + Float + FromPrimitive,
182    FloatWithNonFinTracker<T>: Tracker<Target = T>,
183{
184    fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
185    where
186        D: DeserializeContent<'de>,
187    {
188        deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
189    }
190}
191
192// Serialization
193
194impl<T: Float + FromPrimitive + Display> serde::Serialize for FloatWithNonFinite<T> {
195    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
196    where
197        S: serde::Serializer,
198    {
199        match (self.0.is_nan(), self.0.is_infinite(), self.0.is_sign_negative()) {
200            (true, _, _) => serializer.serialize_str("NaN"),
201            (false, true, true) => serializer.serialize_str("-Infinity"),
202            (false, true, false) => serializer.serialize_str("Infinity"),
203            _ => {
204                let converted = self
205                    .0
206                    .to_f64()
207                    .ok_or_else(|| serde::ser::Error::custom(format!("Failed to convert {} to f64", self.0)))?;
208                serializer.serialize_f64(converted)
209            }
210        }
211    }
212}
213
/// Reference-level conversion from a field type to a same-layout "helper" type.
///
/// `serialize_floats_with_non_finite` serializes a value through its helper.
///
/// # Safety
/// This trait is unsafe to implement. Implementors must guarantee that
/// `Self::Helper` has exactly the same size, alignment and memory representation
/// as `Self`, so that a `&Self` may be reinterpreted as a `&Self::Helper`.
unsafe trait FloatWithNonFinSerHelper: Sized {
    /// A type whose layout is identical to `Self`.
    type Helper: Sized;

    /// Reinterprets `value` as a reference to the helper type (same address and lifetime).
    fn cast(value: &Self) -> &Self::Helper {
        let raw: *const Self = value;
        // SAFETY: the trait's contract guarantees `Self::Helper` has the same layout as
        // `Self`; `raw` comes from a live `&Self`, so it is non-null, aligned and valid
        // for as long as the returned reference (whose lifetime is tied to `value`).
        unsafe { &*raw.cast::<Self::Helper>() }
    }
}
225
226/// Safety: [`FloatWithNonFinite`] is `#[repr(transparent)]` for [`f32`].
227unsafe impl FloatWithNonFinSerHelper for f32 {
228    type Helper = FloatWithNonFinite<f32>;
229}
230
231/// Safety: [`FloatWithNonFinite`] is `#[repr(transparent)]` for [`f64`].
232unsafe impl FloatWithNonFinSerHelper for f64 {
233    type Helper = FloatWithNonFinite<f64>;
234}
235
236/// Safety: [`FloatWithNonFinite<T>`] is naturally same as [`FloatWithNonFinite<T>`].
237unsafe impl<T: Float + FromPrimitive> FloatWithNonFinSerHelper for FloatWithNonFinite<T> {
238    type Helper = FloatWithNonFinite<T>;
239}
240
241/// Safety: If `T` is a [`FloatWithNonFinSerHelper`] type, then `Option<T>` can be cast to `Option<T::Helper>`
242unsafe impl<T: FloatWithNonFinSerHelper> FloatWithNonFinSerHelper for Option<T> {
243    type Helper = Option<T::Helper>;
244}
245
246/// Safety: If `T` is a [`FloatWithNonFinSerHelper`] type, then `Vec<T>` can be cast to `Vec<T::Helper>`
247unsafe impl<T: FloatWithNonFinSerHelper> FloatWithNonFinSerHelper for Vec<T> {
248    type Helper = Vec<T::Helper>;
249}
250
251/// Safety: If `T` is a [`FloatWithNonFinSerHelper`] type, then `BTreeMap<K,V>` can be cast to `BTreeMap<K,V::Helper>`
252unsafe impl<K, V: FloatWithNonFinSerHelper> FloatWithNonFinSerHelper for BTreeMap<K, V> {
253    type Helper = BTreeMap<K, V::Helper>;
254}
255
256/// Safety: If `T` is a [`FloatWithNonFinSerHelper`] type, then `HashMap<K,V>` can be cast to `HashMap<K,V::Helper>`
257unsafe impl<K, V: FloatWithNonFinSerHelper, S> FloatWithNonFinSerHelper for HashMap<K, V, S> {
258    type Helper = HashMap<K, V::Helper, S>;
259}
260
261#[allow(private_bounds)]
262pub fn serialize_floats_with_non_finite<V, S>(value: &V, serializer: S) -> Result<S::Ok, S::Error>
263where
264    V: FloatWithNonFinSerHelper,
265    V::Helper: serde::Serialize,
266    S: serde::Serializer,
267{
268    V::cast(value).serialize(serializer)
269}