diff --git a/crates/nu-command/src/default_context.rs b/crates/nu-command/src/default_context.rs index ae23e56e56..4d89d28ff1 100644 --- a/crates/nu-command/src/default_context.rs +++ b/crates/nu-command/src/default_context.rs @@ -382,6 +382,7 @@ pub fn add_shell_command_context(mut engine_state: EngineState) -> EngineState { Seq, SeqDate, SeqChar, + Unfold, }; // Hash diff --git a/crates/nu-command/src/example_test.rs b/crates/nu-command/src/example_test.rs index abbde34296..ff3edb3dc9 100644 --- a/crates/nu-command/src/example_test.rs +++ b/crates/nu-command/src/example_test.rs @@ -9,9 +9,9 @@ pub fn test_examples(cmd: impl Command + 'static) { #[cfg(test)] mod test_examples { use super::super::{ - Ansi, Date, Enumerate, Filter, Flatten, From, Get, Into, IntoDatetime, IntoString, Math, - MathRound, ParEach, Path, PathParse, Random, Sort, SortBy, Split, SplitColumn, SplitRow, - Str, StrJoin, StrLength, StrReplace, Update, Url, Values, Wrap, + Ansi, Date, Enumerate, Filter, First, Flatten, From, Get, Into, IntoDatetime, IntoString, + Math, MathRound, ParEach, Path, PathParse, Random, Sort, SortBy, Split, SplitColumn, + SplitRow, Str, StrJoin, StrLength, StrReplace, Update, Url, Values, Wrap, }; use crate::{Each, To}; use nu_cmd_lang::example_support::{ @@ -71,6 +71,7 @@ mod test_examples { working_set.add_decl(Box::new(Echo)); working_set.add_decl(Box::new(Enumerate)); working_set.add_decl(Box::new(Filter)); + working_set.add_decl(Box::new(First)); working_set.add_decl(Box::new(Flatten)); working_set.add_decl(Box::new(From)); working_set.add_decl(Box::new(Get)); diff --git a/crates/nu-command/src/generators/mod.rs b/crates/nu-command/src/generators/mod.rs index 0694a9d251..d243ed601c 100644 --- a/crates/nu-command/src/generators/mod.rs +++ b/crates/nu-command/src/generators/mod.rs @@ -2,8 +2,10 @@ mod cal; mod seq; mod seq_char; mod seq_date; +mod unfold; pub use cal::Cal; pub use seq::Seq; pub use seq_char::SeqChar; pub use seq_date::SeqDate; +pub use unfold::Unfold; diff --git a/crates/nu-command/src/generators/unfold.rs b/crates/nu-command/src/generators/unfold.rs new file mode 100644 index 0000000000..b777cf6d82 --- /dev/null +++ b/crates/nu-command/src/generators/unfold.rs @@ -0,0 +1,231 @@ +use itertools::unfold; + +use nu_engine::{eval_block_with_early_return, CallExt}; +use nu_protocol::ast::Call; +use nu_protocol::engine::{Closure, Command, EngineState, Stack}; +use nu_protocol::{ + Category, Example, IntoInterruptiblePipelineData, IntoPipelineData, PipelineData, ShellError, + Signature, Span, Spanned, SyntaxShape, Type, Value, +}; + +#[derive(Clone)] +pub struct Unfold; + +impl Command for Unfold { + fn name(&self) -> &str { + "unfold" + } + + fn signature(&self) -> Signature { + Signature::build("unfold") + .input_output_types(vec![ + (Type::Nothing, Type::List(Box::new(Type::Any))), + ( + Type::List(Box::new(Type::Any)), + Type::List(Box::new(Type::Any)), + ), + ]) + .required("initial", SyntaxShape::Any, "initial value") + .required( + "closure", + SyntaxShape::Closure(Some(vec![SyntaxShape::Any])), + "generator function", + ) + .allow_variants_without_examples(true) + .category(Category::Generators) + } + + fn usage(&self) -> &str { + "Generate a list of values by successively invoking a closure." + } + + fn extra_usage(&self) -> &str { + r#"The generator closure accepts a single argument and returns a record +containing two optional keys: 'out' and 'next'. Each invocation, the 'out' +value, if present, is added to the stream. If a 'next' key is present, it is +used as the next argument to the closure, otherwise generation stops. +"# + } + + fn search_terms(&self) -> Vec<&str> { + vec!["generate", "stream"] + } + + fn examples(&self) -> Vec { + vec![ + Example { + example: "unfold 0 {|i| if $i <= 10 { {out: $i, next: ($i + 2)} }}", + description: "Generate a sequence of numbers", + result: Some(Value::list( + vec![ + Value::test_int(0), + Value::test_int(2), + Value::test_int(4), + Value::test_int(6), + Value::test_int(8), + Value::test_int(10), + ], + Span::test_data(), + )), + }, + Example { + example: "unfold [0, 1] {|fib| {out: $fib.0, next: [$fib.1, ($fib.0 + $fib.1)]} } | first 10", + description: "Generate a stream of fibonacci numbers", + result: Some(Value::list( + vec![ + Value::test_int(0), + Value::test_int(1), + Value::test_int(1), + Value::test_int(2), + Value::test_int(3), + Value::test_int(5), + Value::test_int(8), + Value::test_int(13), + Value::test_int(21), + Value::test_int(34), + ], + Span::test_data(), + )), + }, + ] + } + + fn run( + &self, + engine_state: &EngineState, + stack: &mut Stack, + call: &Call, + _input: PipelineData, + ) -> Result { + let initial: Value = call.req(engine_state, stack, 0)?; + let capture_block: Spanned = call.req(engine_state, stack, 1)?; + let block_span = capture_block.span; + let block = engine_state.get_block(capture_block.item.block_id).clone(); + let ctrlc = engine_state.ctrlc.clone(); + let engine_state = engine_state.clone(); + let mut stack = stack.captures_to_stack(&capture_block.item.captures); + let orig_env_vars = stack.env_vars.clone(); + let orig_env_hidden = stack.env_hidden.clone(); + let redirect_stdout = call.redirect_stdout; + let redirect_stderr = call.redirect_stderr; + + // A type of Option is used to represent state. Invocation + // will stop on None. Using Option allows functions to output + // one final value before stopping. + let iter = unfold(Some(initial), move |state| { + let arg = match state { + Some(state) => state.clone(), + None => return None, + }; + + // with_env() is used here to ensure that each iteration uses + // a different set of environment variables. + // Hence, a 'cd' in the first loop won't affect the next loop. + stack.with_env(&orig_env_vars, &orig_env_hidden); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, arg.clone()); + } + } + + let (output, next_input) = match eval_block_with_early_return( + &engine_state, + &mut stack, + &block, + arg.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + // no data -> output nothing and stop. + Ok(PipelineData::Empty) => (None, None), + + Ok(PipelineData::Value(value, ..)) => { + let span = value.span(); + match value { + // {out: ..., next: ...} -> output and continue + Value::Record { val, .. } => { + let iter = val.into_iter(); + let mut out = None; + let mut next = None; + let mut err = None; + + for (k, v) in iter { + if k.to_lowercase() == "out" { + out = Some(v); + } else if k.to_lowercase() == "next" { + next = Some(v); + } else { + let error = ShellError::GenericError( + "Invalid block return".to_string(), + format!("Unexpected record key '{}'", k), + Some(span), + None, + Vec::new(), + ); + err = Some(Value::error(error, block_span)); + break; + } + } + + if err.is_some() { + (err, None) + } else { + (out, next) + } + } + + // some other value -> error and stop + _ => { + let error = ShellError::GenericError( + "Invalid block return".to_string(), + format!("Expected record, found {}", value.get_type()), + Some(span), + None, + Vec::new(), + ); + + (Some(Value::error(error, block_span)), None) + } + } + } + + Ok(other) => { + let val = other.into_value(block_span); + let error = ShellError::GenericError( + "Invalid block return".to_string(), + format!("Expected record, found {}", val.get_type()), + Some(val.span()), + None, + Vec::new(), + ); + + (Some(Value::error(error, block_span)), None) + } + + // error -> error and stop + Err(error) => (Some(Value::error(error, block_span)), None), + }; + + // We use `state` to control when to stop, not `output`. By wrapping + // it in a `Some`, we allow the generator to output `None` as a valid output + // value. + *state = next_input; + Some(output) + }); + + Ok(iter.flatten().into_pipeline_data(ctrlc)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_examples() { + use crate::test_examples; + + test_examples(Unfold {}) + } +} diff --git a/crates/nu-command/tests/commands/mod.rs b/crates/nu-command/tests/commands/mod.rs index 8c45747284..2a3afaab07 100644 --- a/crates/nu-command/tests/commands/mod.rs +++ b/crates/nu-command/tests/commands/mod.rs @@ -103,6 +103,7 @@ mod touch; mod transpose; mod try_; mod ucp; +mod unfold; mod uniq; mod uniq_by; mod update; diff --git a/crates/nu-command/tests/commands/unfold.rs b/crates/nu-command/tests/commands/unfold.rs new file mode 100644 index 0000000000..acbdfa342c --- /dev/null +++ b/crates/nu-command/tests/commands/unfold.rs @@ -0,0 +1,102 @@ +use nu_test_support::{nu, pipeline}; + +#[test] +fn unfold_no_next_break() { + let actual = + nu!("unfold 1 {|x| if $x == 3 { {out: $x}} else { {out: $x, next: ($x + 1)} }} | to nuon"); + + assert_eq!(actual.out, "[1, 2, 3]"); +} + +#[test] +fn unfold_null_break() { + let actual = nu!("unfold 1 {|x| if $x <= 3 { {out: $x, next: ($x + 1)} }} | to nuon"); + + assert_eq!(actual.out, "[1, 2, 3]"); +} + +#[test] +fn unfold_allows_empty_output() { + let actual = nu!(pipeline( + r#" + unfold 0 {|x| + if $x == 1 { + {next: ($x + 1)} + } else if $x < 3 { + {out: $x, next: ($x + 1)} + } + } | to nuon + "# + )); + + assert_eq!(actual.out, "[0, 2]"); +} + +#[test] +fn unfold_allows_no_output() { + let actual = nu!(pipeline( + r#" + unfold 0 {|x| + if $x < 3 { + {next: ($x + 1)} + } + } | to nuon + "# + )); + + assert_eq!(actual.out, "[]"); +} + +#[test] +fn unfold_allows_null_state() { + let actual = nu!(pipeline( + r#" + unfold 0 {|x| + if $x == null { + {out: "done"} + } else if $x < 1 { + {out: "going", next: ($x + 1)} + } else { + {out: "stopping", next: null} + } + } | to nuon + "# + )); + + assert_eq!(actual.out, "[going, stopping, done]"); +} + +#[test] +fn unfold_allows_null_output() { + let actual = nu!(pipeline( + r#" + unfold 0 {|x| + if $x == 3 { + {out: "done"} + } else { + {out: null, next: ($x + 1)} + } + } | to nuon + "# + )); + + assert_eq!(actual.out, "[null, null, null, done]"); +} + +#[test] +fn unfold_disallows_extra_keys() { + let actual = nu!("unfold 0 {|x| {foo: bar, out: $x}}"); + assert!(actual.err.contains("Invalid block return")); +} + +#[test] +fn unfold_disallows_list() { + let actual = nu!("unfold 0 {|x| [$x, ($x + 1)]}"); + assert!(actual.err.contains("Invalid block return")); +} + +#[test] +fn unfold_disallows_primitive() { + let actual = nu!("unfold 0 {|x| 1}"); + assert!(actual.err.contains("Invalid block return")); +}