diff --git a/crates/nu-derive-value/src/from.rs b/crates/nu-derive-value/src/from.rs index 033026c149..783a22920e 100644 --- a/crates/nu-derive-value/src/from.rs +++ b/crates/nu-derive-value/src/from.rs @@ -3,6 +3,7 @@ use proc_macro2::TokenStream as TokenStream2; use quote::{quote, ToTokens}; use syn::{ spanned::Spanned, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Generics, Ident, + Type, }; use crate::attributes::{self, ContainerAttributes}; @@ -116,15 +117,11 @@ fn derive_struct_from_value( /// src_span: span /// })?, /// )?, -/// favorite_toy: as nu_protocol::FromValue>::from_value( -/// record -/// .remove("favorite_toy") -/// .ok_or_else(|| nu_protocol::ShellError::CantFindColumn { -/// col_name: std::string::ToString::to_string("favorite_toy"), -/// span: std::option::Option::None, -/// src_span: span -/// })?, -/// )?, +/// favorite_toy: record +/// .remove("favorite_toy") +/// .map(|v| <#ty as nu_protocol::FromValue>::from_value(v)) +/// .transpose()? +/// .flatten(), /// }) /// } /// } @@ -480,20 +477,29 @@ fn parse_value_via_fields(fields: &Fields, self_ident: impl ToTokens) -> TokenSt match fields { Fields::Named(fields) => { let fields = fields.named.iter().map(|field| { - // TODO: handle missing fields for Options as None let ident = field.ident.as_ref().expect("named has idents"); let ident_s = ident.to_string(); let ty = &field.ty; - quote! { - #ident: <#ty as nu_protocol::FromValue>::from_value( - record + match type_is_option(ty) { + true => quote! { + #ident: record .remove(#ident_s) - .ok_or_else(|| nu_protocol::ShellError::CantFindColumn { - col_name: std::string::ToString::to_string(#ident_s), - span: std::option::Option::None, - src_span: span - })?, - )? + .map(|v| <#ty as nu_protocol::FromValue>::from_value(v)) + .transpose()? + .flatten() + }, + + false => quote! { + #ident: <#ty as nu_protocol::FromValue>::from_value( + record + .remove(#ident_s) + .ok_or_else(|| nu_protocol::ShellError::CantFindColumn { + col_name: std::string::ToString::to_string(#ident_s), + span: std::option::Option::None, + src_span: span + })?, + )? + }, } }); quote! { @@ -537,3 +543,25 @@ fn parse_value_via_fields(fields: &Fields, self_ident: impl ToTokens) -> TokenSt }, } } + +const FULLY_QUALIFIED_OPTION: &str = "std::option::Option"; +const PARTIALLY_QUALIFIED_OPTION: &str = "option::Option"; +const PRELUDE_OPTION: &str = "Option"; + +/// Check if the field type is an `Option`. +/// +/// This function checks if a given type is an `Option`. +/// We assume that an `Option` is [`std::option::Option`] because we can't see the whole code and +/// can't ask the compiler itself. +/// If the `Option` type isn't `std::option::Option`, the user will get a compile error due to a +/// type mismatch. +/// It's very unusual for people to override `Option`, so this should rarely be an issue. +/// +/// When [rust#63084](https://github.com/rust-lang/rust/issues/63084) is resolved, we can use +/// [`std::any::type_name`] for a static assertion check to get a more direct error messages. +fn type_is_option(ty: &Type) -> bool { + let s = ty.to_token_stream().to_string(); + s.starts_with(PRELUDE_OPTION) + || s.starts_with(PARTIALLY_QUALIFIED_OPTION) + || s.starts_with(FULLY_QUALIFIED_OPTION) +} diff --git a/crates/nu-protocol/src/value/test_derive.rs b/crates/nu-protocol/src/value/test_derive.rs index 1865a05418..56bccabd2a 100644 --- a/crates/nu-protocol/src/value/test_derive.rs +++ b/crates/nu-protocol/src/value/test_derive.rs @@ -171,6 +171,62 @@ fn named_fields_struct_incorrect_type() { assert!(res.is_err()); } +#[derive(IntoValue, FromValue, Debug, PartialEq, Default)] +struct ALotOfOptions { + required: bool, + float: Option, + int: Option, + value: Option, + nested: Option, +} + +#[test] +fn missing_options() { + let value = Value::test_record(Record::new()); + let res: Result = ALotOfOptions::from_value(value); + assert!(res.is_err()); + + let value = Value::test_record(record! {"required" => Value::test_bool(true)}); + let expected = ALotOfOptions { + required: true, + ..Default::default() + }; + let actual = ALotOfOptions::from_value(value).unwrap(); + assert_eq!(expected, actual); + + let value = Value::test_record(record! { + "required" => Value::test_bool(true), + "float" => Value::test_float(std::f64::consts::PI), + }); + let expected = ALotOfOptions { + required: true, + float: Some(std::f64::consts::PI), + ..Default::default() + }; + let actual = ALotOfOptions::from_value(value).unwrap(); + assert_eq!(expected, actual); + + let value = Value::test_record(record! { + "required" => Value::test_bool(true), + "int" => Value::test_int(12), + "nested" => Value::test_record(record! { + "u32" => Value::test_int(34), + }), + }); + let expected = ALotOfOptions { + required: true, + int: Some(12), + nested: Some(Nestee { + u32: 34, + some: None, + none: None, + }), + ..Default::default() + }; + let actual = ALotOfOptions::from_value(value).unwrap(); + assert_eq!(expected, actual); +} + #[derive(IntoValue, FromValue, Debug, PartialEq)] struct UnnamedFieldsStruct(u32, String, T) where