fix(serde): support enum deserialization

This commit is contained in:
David Knaack 2022-04-28 22:12:27 +02:00
parent c5a783bd71
commit 68b4faf29d

View File

@ -1,5 +1,6 @@
use crate::module::ALL_MODULES;
use serde::de::{
self,
value::{Error as ValueError, MapDeserializer, SeqDeserializer},
Deserializer, Error, IntoDeserializer, Visitor,
};
@ -53,6 +54,20 @@ impl ValueDeserializer<'_> {
_ => ValueError::custom(msg),
}
}
/// Return the fitting `de::Unexpected` type description for the given value.
/// For use with `Error::invalid_type`.
fn serde_unexpected(&self) -> de::Unexpected {
match self.value {
Value::Boolean(b) => de::Unexpected::Bool(*b),
Value::Integer(i) => de::Unexpected::Signed(*i),
Value::Float(f) => de::Unexpected::Float(*f),
Value::String(ref s) => de::Unexpected::Str(s),
Value::Array(_v) => de::Unexpected::Other("array"),
Value::Table(_v) => de::Unexpected::Other("table"),
Value::Datetime(_v) => de::Unexpected::Other("datetime"),
}
}
}
impl<'de> IntoDeserializer<'de> for ValueDeserializer<'de> {
@ -161,10 +176,29 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
visitor.visit_newtype_struct(self)
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.value {
// de::Value::StrDeserializer implements de::EnumAccess, so we can just use it.
Value::String(s) => visitor.visit_enum(s.as_str().into_deserializer()),
_ => Err(Self::Error::invalid_type(
self.serde_unexpected(),
&"string",
)),
}
}
// Handle most deserialization cases by deferring to `deserialize_any`.
serde::forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq
bytes byte_buf map unit_struct tuple_struct enum tuple identifier
bytes byte_buf map unit_struct tuple_struct tuple identifier
}
}
@ -296,6 +330,72 @@ mod test {
assert_eq!(result.foo.0, "bar".to_owned());
}
#[derive(Deserialize, PartialEq, Debug)]
#[serde(rename_all = "snake_case")]
enum SampleEnum {
FirstItem,
Second,
ThirdItem,
}
#[test]
fn test_deserialize_enum() {
#[derive(Deserialize)]
struct Sample {
foo: SampleEnum,
}
let value = toml::toml! {
foo = "first_item"
};
let deserializer = ValueDeserializer::new(&value);
let result = Sample::deserialize(deserializer).unwrap();
assert_eq!(result.foo, SampleEnum::FirstItem);
}
#[test]
fn test_deserialize_enum_unknown() {
#[derive(Deserialize)]
#[allow(dead_code)]
struct Sample {
foo: SampleEnum,
}
let value = toml::toml! {
foo = "unknown"
};
let deserializer = ValueDeserializer::new(&value);
let result = Sample::deserialize(deserializer);
assert!(result.is_err());
}
#[test]
fn test_deserialize_enum_invalid_type() {
#[derive(Deserialize, PartialEq, Debug)]
#[allow(dead_code)]
struct Sample {
foo: SampleEnum,
}
let value = toml::toml! {
foo = 1
};
let deserializer = ValueDeserializer::new(&value);
let result = Sample::deserialize(deserializer);
assert_eq!(
result,
Err(serde::de::Error::custom(
"invalid type: integer `1`, expected string"
))
);
}
#[test]
fn test_deserialize_unknown() {
let value = toml::toml! {