add match guards (#9621)

## description

this pr adds [match
guards](https://doc.rust-lang.org/reference/expressions/match-expr.html#match-guards)
to match patterns
```nushell
match $x {
   _ if $x starts-with 'nu' => {},
   $x => {}
}
```

these work pretty much like rust's match guards, with few limitations:

1. multiple matches using the `|` are not (yet?) supported
 
```nushell
match $num {
    0 | _ if (is-odd $num) => {},
    _ => {}
}
```

2. blocks cannot be used as guards, (yet?)

```nushell
match $num {
    $x if { $x ** $x == inf } => {},
     _ => {}
}
```

## checklist
- [x] syntax
- [x] syntax highlighting[^1]
- [x] semantics
- [x] tests
- [x] clean up

[^1]: defered for another pr
This commit is contained in:
mike 2023-07-16 03:25:12 +03:00 committed by GitHub
parent 57d96c09fa
commit 5bfec20244
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 26 deletions

View File

@ -1,4 +1,4 @@
use nu_engine::{eval_block, eval_expression_with_input, CallExt}; use nu_engine::{eval_block, eval_expression, eval_expression_with_input, CallExt};
use nu_protocol::ast::{Call, Expr, Expression}; use nu_protocol::ast::{Call, Expr, Expression};
use nu_protocol::engine::{Command, EngineState, Matcher, Stack}; use nu_protocol::engine::{Command, EngineState, Matcher, Stack};
use nu_protocol::{ use nu_protocol::{
@ -52,18 +52,29 @@ impl Command for Match {
stack.add_var(match_variable.0, match_variable.1); stack.add_var(match_variable.0, match_variable.1);
} }
if let Some(block_id) = match_.1.as_block() { let guard_matches = if let Some(guard) = &match_.0.guard {
let Value::Bool { val, .. } = eval_expression(engine_state, stack, guard)? else {
return Err(ShellError::MatchGuardNotBool { span: guard.span});
};
val
} else {
true
};
if guard_matches {
return if let Some(block_id) = match_.1.as_block() {
let block = engine_state.get_block(block_id); let block = engine_state.get_block(block_id);
return eval_block( eval_block(
engine_state, engine_state,
stack, stack,
block, block,
input, input,
call.redirect_stdout, call.redirect_stdout,
call.redirect_stderr, call.redirect_stderr,
); )
} else { } else {
return eval_expression_with_input( eval_expression_with_input(
engine_state, engine_state,
stack, stack,
&match_.1, &match_.1,
@ -71,7 +82,8 @@ impl Command for Match {
call.redirect_stdout, call.redirect_stdout,
call.redirect_stderr, call.redirect_stderr,
) )
.map(|x| x.0); .map(|x| x.0)
};
} }
} }
} }
@ -107,6 +119,16 @@ impl Command for Match {
example: "{a: {b: 3}} | match $in {{a: { $b }} => ($b + 10) }", example: "{a: {b: 3}} | match $in {{a: { $b }} => ($b + 10) }",
result: Some(Value::test_int(13)), result: Some(Value::test_int(13)),
}, },
Example {
description: "Match with a guard",
example: "
match [1 2 3] {
[$x, ..$y] if $x == 1 => { 'good list' },
_ => { 'not a very good list' }
}
",
result: Some(Value::test_string("good list")),
},
] ]
} }
} }

View File

@ -197,3 +197,54 @@ fn match_doesnt_overwrite_variable() {
// As we do not auto-print loops anymore // As we do not auto-print loops anymore
assert_eq!(actual.out, "100"); assert_eq!(actual.out, "100");
} }
#[test]
fn match_with_guard() {
let actual = nu!(
cwd: ".",
"match [1 2 3] { [$x, ..] if $x mod 2 == 0 => { $x }, $x => { 2 } }"
);
assert_eq!(actual.out, "2");
}
#[test]
fn match_with_guard_block_as_guard() {
// this should work?
let actual = nu!(
cwd: ".",
"match 4 { $x if { $x + 20 > 25 } => { 'good num' }, _ => { 'terrible num' } }"
);
assert!(actual.err.contains("Match guard not bool"));
}
#[test]
fn match_with_guard_parens_expr_as_guard() {
let actual = nu!(
cwd: ".",
"match 4 { $x if ($x + 20 > 25) => { 'good num' }, _ => { 'terrible num' } }"
);
assert_eq!(actual.out, "terrible num");
}
#[test]
fn match_with_guard_not_bool() {
let actual = nu!(
cwd: ".",
"match 4 { $x if $x + 1 => { 'err!()' }, _ => { 'unreachable!()' } }"
);
assert!(actual.err.contains("Match guard not bool"));
}
#[test]
fn match_with_guard_no_expr_after_if() {
let actual = nu!(
cwd: ".",
"match 4 { $x if => { 'err!()' }, _ => { 'unreachable!()' } }"
);
assert!(actual.err.contains("Match guard without an expression"));
}

View File

@ -13,6 +13,7 @@ use crate::{
pub fn garbage(span: Span) -> MatchPattern { pub fn garbage(span: Span) -> MatchPattern {
MatchPattern { MatchPattern {
pattern: Pattern::Garbage, pattern: Pattern::Garbage,
guard: None,
span, span,
} }
} }
@ -45,6 +46,7 @@ pub fn parse_pattern(working_set: &mut StateWorkingSet, span: Span) -> MatchPatt
} else if bytes == b"_" { } else if bytes == b"_" {
MatchPattern { MatchPattern {
pattern: Pattern::IgnoreValue, pattern: Pattern::IgnoreValue,
guard: None,
span, span,
} }
} else { } else {
@ -53,6 +55,7 @@ pub fn parse_pattern(working_set: &mut StateWorkingSet, span: Span) -> MatchPatt
MatchPattern { MatchPattern {
pattern: Pattern::Value(value), pattern: Pattern::Value(value),
guard: None,
span, span,
} }
} }
@ -78,6 +81,7 @@ pub fn parse_variable_pattern(working_set: &mut StateWorkingSet, span: Span) ->
if let Some(var_id) = parse_variable_pattern_helper(working_set, span) { if let Some(var_id) = parse_variable_pattern_helper(working_set, span) {
MatchPattern { MatchPattern {
pattern: Pattern::Variable(var_id), pattern: Pattern::Variable(var_id),
guard: None,
span, span,
} }
} else { } else {
@ -126,6 +130,7 @@ pub fn parse_list_pattern(working_set: &mut StateWorkingSet, span: Span) -> Matc
if contents == b".." { if contents == b".." {
args.push(MatchPattern { args.push(MatchPattern {
pattern: Pattern::IgnoreRest, pattern: Pattern::IgnoreRest,
guard: None,
span: command.parts[spans_idx], span: command.parts[spans_idx],
}); });
break; break;
@ -139,6 +144,7 @@ pub fn parse_list_pattern(working_set: &mut StateWorkingSet, span: Span) -> Matc
) { ) {
args.push(MatchPattern { args.push(MatchPattern {
pattern: Pattern::Rest(var_id), pattern: Pattern::Rest(var_id),
guard: None,
span: command.parts[spans_idx], span: command.parts[spans_idx],
}); });
break; break;
@ -163,6 +169,7 @@ pub fn parse_list_pattern(working_set: &mut StateWorkingSet, span: Span) -> Matc
MatchPattern { MatchPattern {
pattern: Pattern::List(args), pattern: Pattern::List(args),
guard: None,
span, span,
} }
} }
@ -232,6 +239,7 @@ pub fn parse_record_pattern(working_set: &mut StateWorkingSet, span: Span) -> Ma
MatchPattern { MatchPattern {
pattern: Pattern::Record(output), pattern: Pattern::Record(output),
guard: None,
span, span,
} }
} }

View File

@ -4270,8 +4270,9 @@ pub fn parse_match_block_expression(working_set: &mut StateWorkingSet, span: Spa
break; break;
} }
// Multiple patterns connected by '|'
let mut connector = working_set.get_span_contents(output[position].span); let mut connector = working_set.get_span_contents(output[position].span);
// Multiple patterns connected by '|'
if connector == b"|" && position < output.len() { if connector == b"|" && position < output.len() {
let mut or_pattern = vec![pattern]; let mut or_pattern = vec![pattern];
@ -4322,10 +4323,56 @@ pub fn parse_match_block_expression(working_set: &mut StateWorkingSet, span: Spa
pattern = MatchPattern { pattern = MatchPattern {
pattern: Pattern::Or(or_pattern), pattern: Pattern::Or(or_pattern),
guard: None,
span: Span::new(start, end), span: Span::new(start, end),
} }
// A match guard
} else if connector == b"if" {
let if_end = {
let end = output[position].span.end;
Span::new(end, end)
};
position += 1;
let mk_err = || ParseError::LabeledErrorWithHelp {
error: "Match guard without an expression".into(),
label: "expected an expression".into(),
help: "The `if` keyword must be followed with an expression".into(),
span: if_end,
};
if output.get(position).is_none() {
working_set.error(mk_err());
return garbage(span);
};
let (tokens, found) = if let Some((pos, _)) = output[position..]
.iter()
.find_position(|t| working_set.get_span_contents(t.span) == b"=>")
{
if position + pos == position {
working_set.error(mk_err());
return garbage(span);
} }
(&output[position..position + pos], true)
} else {
(&output[position..], false)
};
let mut start = 0;
let guard = parse_multispan_value(
working_set,
&tokens.iter().map(|tok| tok.span).collect_vec(),
&mut start,
&SyntaxShape::MathExpression,
);
pattern.guard = Some(guard);
position += if found { start + 1 } else { start };
connector = working_set.get_span_contents(output[position].span);
}
// Then the `=>` arrow // Then the `=>` arrow
if connector != b"=>" { if connector != b"=>" {
working_set.error(ParseError::Mismatch( working_set.error(ParseError::Mismatch(

View File

@ -1,12 +1,11 @@
use serde::{Deserialize, Serialize};
use crate::{Span, VarId};
use super::Expression; use super::Expression;
use crate::{Span, VarId};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatchPattern { pub struct MatchPattern {
pub pattern: Pattern, pub pattern: Pattern,
pub guard: Option<Expression>,
pub span: Span, pub span: Span,
} }

View File

@ -1073,6 +1073,18 @@ pub enum ShellError {
#[label("This operation was interrupted")] #[label("This operation was interrupted")]
span: Option<Span>, span: Option<Span>,
}, },
/// An attempt to use, as a match guard, an expression that
/// does not resolve into a boolean
#[error("Match guard not bool")]
#[diagnostic(
code(nu::shell::match_guard_not_bool),
help("Match guards should evaluate to a boolean")
)]
MatchGuardNotBool {
#[label("not a boolean expression")]
span: Span,
},
} }
impl From<std::io::Error> for ShellError { impl From<std::io::Error> for ShellError {