From 68b4faf29df39e372b5e96b5f8c23a9e7b95520f Mon Sep 17 00:00:00 2001 From: David Knaack Date: Thu, 28 Apr 2022 22:12:27 +0200 Subject: [PATCH] fix(serde): support `enum` deserialization --- src/serde_utils.rs | 102 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/src/serde_utils.rs b/src/serde_utils.rs index 62b9ab6f1..51fbb08fb 100644 --- a/src/serde_utils.rs +++ b/src/serde_utils.rs @@ -1,5 +1,6 @@ use crate::module::ALL_MODULES; use serde::de::{ + self, value::{Error as ValueError, MapDeserializer, SeqDeserializer}, Deserializer, Error, IntoDeserializer, Visitor, }; @@ -53,6 +54,20 @@ impl ValueDeserializer<'_> { _ => ValueError::custom(msg), } } + + /// Return the fitting `de::Unexpected` type description for the given value. + /// For use with `Error::invalid_type`. + fn serde_unexpected(&self) -> de::Unexpected { + match self.value { + Value::Boolean(b) => de::Unexpected::Bool(*b), + Value::Integer(i) => de::Unexpected::Signed(*i), + Value::Float(f) => de::Unexpected::Float(*f), + Value::String(ref s) => de::Unexpected::Str(s), + Value::Array(_v) => de::Unexpected::Other("array"), + Value::Table(_v) => de::Unexpected::Other("table"), + Value::Datetime(_v) => de::Unexpected::Other("datetime"), + } + } } impl<'de> IntoDeserializer<'de> for ValueDeserializer<'de> { @@ -161,10 +176,29 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { visitor.visit_newtype_struct(self) } + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + match self.value { + // de::Value::StrDeserializer implements de::EnumAccess, so we can just use it. + Value::String(s) => visitor.visit_enum(s.as_str().into_deserializer()), + _ => Err(Self::Error::invalid_type( + self.serde_unexpected(), + &"string", + )), + } + } + // Handle most deserialization cases by deferring to `deserialize_any`. serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq - bytes byte_buf map unit_struct tuple_struct enum tuple identifier + bytes byte_buf map unit_struct tuple_struct tuple identifier } } @@ -296,6 +330,72 @@ mod test { assert_eq!(result.foo.0, "bar".to_owned()); } + #[derive(Deserialize, PartialEq, Debug)] + #[serde(rename_all = "snake_case")] + enum SampleEnum { + FirstItem, + Second, + ThirdItem, + } + + #[test] + fn test_deserialize_enum() { + #[derive(Deserialize)] + struct Sample { + foo: SampleEnum, + } + + let value = toml::toml! { + foo = "first_item" + }; + + let deserializer = ValueDeserializer::new(&value); + + let result = Sample::deserialize(deserializer).unwrap(); + assert_eq!(result.foo, SampleEnum::FirstItem); + } + + #[test] + fn test_deserialize_enum_unknown() { + #[derive(Deserialize)] + #[allow(dead_code)] + struct Sample { + foo: SampleEnum, + } + + let value = toml::toml! { + foo = "unknown" + }; + + let deserializer = ValueDeserializer::new(&value); + + let result = Sample::deserialize(deserializer); + assert!(result.is_err()); + } + + #[test] + fn test_deserialize_enum_invalid_type() { + #[derive(Deserialize, PartialEq, Debug)] + #[allow(dead_code)] + struct Sample { + foo: SampleEnum, + } + + let value = toml::toml! { + foo = 1 + }; + + let deserializer = ValueDeserializer::new(&value); + + let result = Sample::deserialize(deserializer); + assert_eq!( + result, + Err(serde::de::Error::custom( + "invalid type: integer `1`, expected string" + )) + ); + } + #[test] fn test_deserialize_unknown() { let value = toml::toml! {