diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index e5ec8578d7..7353436fa5 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -344,6 +344,37 @@ pub fn parse_external_call( } } +fn ensure_flag_arg_type( + working_set: &mut StateWorkingSet, + arg_name: String, + arg: Expression, + arg_shape: &SyntaxShape, + long_name_span: Span, +) -> (Spanned, Expression) { + if !type_compatible(&arg.ty, &arg_shape.to_type()) { + working_set.error(ParseError::TypeMismatch( + arg_shape.to_type(), + arg.ty, + arg.span, + )); + ( + Spanned { + item: arg_name, + span: long_name_span, + }, + Expression::garbage(arg.span), + ) + } else { + ( + Spanned { + item: arg_name, + span: long_name_span, + }, + arg, + ) + } +} + fn parse_long_flag( working_set: &mut StateWorkingSet, spans: &[Span], @@ -368,25 +399,21 @@ fn parse_long_flag( span.start += long_name_len + 3; //offset by long flag and '=' let arg = parse_value(working_set, span, arg_shape); - - ( - Some(Spanned { - item: long_name, - span: Span::new(arg_span.start, arg_span.start + long_name_len + 2), - }), - Some(arg), - ) + let (arg_name, val_expression) = ensure_flag_arg_type( + working_set, + long_name, + arg, + arg_shape, + Span::new(arg_span.start, arg_span.start + long_name_len + 2), + ); + (Some(arg_name), Some(val_expression)) } else if let Some(arg) = spans.get(*spans_idx + 1) { let arg = parse_value(working_set, *arg, arg_shape); *spans_idx += 1; - ( - Some(Spanned { - item: long_name, - span: arg_span, - }), - Some(arg), - ) + let (arg_name, val_expression) = + ensure_flag_arg_type(working_set, long_name, arg, arg_shape, arg_span); + (Some(arg_name), Some(val_expression)) } else { working_set.error(ParseError::MissingFlagParam( arg_shape.to_string(), @@ -411,13 +438,14 @@ fn parse_long_flag( let arg = parse_value(working_set, span, &SyntaxShape::Boolean); - ( - Some(Spanned { - item: long_name, - span: Span::new(arg_span.start, arg_span.start + long_name_len + 2), - }), - Some(arg), - ) + let (arg_name, val_expression) = ensure_flag_arg_type( + working_set, + long_name, + arg, + &SyntaxShape::Boolean, + Span::new(arg_span.start, arg_span.start + long_name_len + 2), + ); + (Some(arg_name), Some(val_expression)) } else { ( Some(Spanned { diff --git a/src/tests/test_custom_commands.rs b/src/tests/test_custom_commands.rs index 3447d80f11..bc62006812 100644 --- a/src/tests/test_custom_commands.rs +++ b/src/tests/test_custom_commands.rs @@ -73,6 +73,14 @@ fn custom_switch1() -> TestResult { ) } +#[test] +fn custom_flag_with_type_checking() -> TestResult { + fail_test( + r#"def florb [--dry-run: int] { $dry_run }; let y = "3"; florb --dry-run=$y"#, + "type_mismatch", + ) +} + #[test] fn custom_switch2() -> TestResult { run_test( @@ -116,7 +124,7 @@ fn custom_flag1() -> TestResult { r#"def florb [ --age: int = 0 --name = "foobar" - ] { + ] { ($age | into string) + $name } florb"#,