fix(serde): support enum deserialization

This commit is contained in:
David Knaack 2022-04-28 22:12:27 +02:00
parent c5a783bd71
commit 68b4faf29d

View File

@ -1,5 +1,6 @@
use crate::module::ALL_MODULES; use crate::module::ALL_MODULES;
use serde::de::{ use serde::de::{
self,
value::{Error as ValueError, MapDeserializer, SeqDeserializer}, value::{Error as ValueError, MapDeserializer, SeqDeserializer},
Deserializer, Error, IntoDeserializer, Visitor, Deserializer, Error, IntoDeserializer, Visitor,
}; };
@ -53,6 +54,20 @@ impl ValueDeserializer<'_> {
_ => ValueError::custom(msg), _ => ValueError::custom(msg),
} }
} }
/// Return the fitting `de::Unexpected` type description for the given value.
/// For use with `Error::invalid_type`.
fn serde_unexpected(&self) -> de::Unexpected {
match self.value {
Value::Boolean(b) => de::Unexpected::Bool(*b),
Value::Integer(i) => de::Unexpected::Signed(*i),
Value::Float(f) => de::Unexpected::Float(*f),
Value::String(ref s) => de::Unexpected::Str(s),
Value::Array(_v) => de::Unexpected::Other("array"),
Value::Table(_v) => de::Unexpected::Other("table"),
Value::Datetime(_v) => de::Unexpected::Other("datetime"),
}
}
} }
impl<'de> IntoDeserializer<'de> for ValueDeserializer<'de> { impl<'de> IntoDeserializer<'de> for ValueDeserializer<'de> {
@ -161,10 +176,29 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
visitor.visit_newtype_struct(self) visitor.visit_newtype_struct(self)
} }
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.value {
// de::Value::StrDeserializer implements de::EnumAccess, so we can just use it.
Value::String(s) => visitor.visit_enum(s.as_str().into_deserializer()),
_ => Err(Self::Error::invalid_type(
self.serde_unexpected(),
&"string",
)),
}
}
// Handle most deserialization cases by deferring to `deserialize_any`. // Handle most deserialization cases by deferring to `deserialize_any`.
serde::forward_to_deserialize_any! { serde::forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq
bytes byte_buf map unit_struct tuple_struct enum tuple identifier bytes byte_buf map unit_struct tuple_struct tuple identifier
} }
} }
@ -296,6 +330,72 @@ mod test {
assert_eq!(result.foo.0, "bar".to_owned()); assert_eq!(result.foo.0, "bar".to_owned());
} }
#[derive(Deserialize, PartialEq, Debug)]
#[serde(rename_all = "snake_case")]
enum SampleEnum {
FirstItem,
Second,
ThirdItem,
}
#[test]
fn test_deserialize_enum() {
#[derive(Deserialize)]
struct Sample {
foo: SampleEnum,
}
let value = toml::toml! {
foo = "first_item"
};
let deserializer = ValueDeserializer::new(&value);
let result = Sample::deserialize(deserializer).unwrap();
assert_eq!(result.foo, SampleEnum::FirstItem);
}
#[test]
fn test_deserialize_enum_unknown() {
#[derive(Deserialize)]
#[allow(dead_code)]
struct Sample {
foo: SampleEnum,
}
let value = toml::toml! {
foo = "unknown"
};
let deserializer = ValueDeserializer::new(&value);
let result = Sample::deserialize(deserializer);
assert!(result.is_err());
}
#[test]
fn test_deserialize_enum_invalid_type() {
#[derive(Deserialize, PartialEq, Debug)]
#[allow(dead_code)]
struct Sample {
foo: SampleEnum,
}
let value = toml::toml! {
foo = 1
};
let deserializer = ValueDeserializer::new(&value);
let result = Sample::deserialize(deserializer);
assert_eq!(
result,
Err(serde::de::Error::custom(
"invalid type: integer `1`, expected string"
))
);
}
#[test] #[test]
fn test_deserialize_unknown() { fn test_deserialize_unknown() {
let value = toml::toml! { let value = toml::toml! {