diff --git a/crates/nu-command/src/generators/generate.rs b/crates/nu-command/src/generators/generate.rs index ead6bdd7c9..05a34dcbeb 100644 --- a/crates/nu-command/src/generators/generate.rs +++ b/crates/nu-command/src/generators/generate.rs @@ -11,10 +11,15 @@ impl Command for Generate { fn signature(&self) -> Signature { Signature::build("generate") - .input_output_types(vec![(Type::Nothing, Type::List(Box::new(Type::Any)))]) + .input_output_types(vec![ + (Type::Nothing, Type::list(Type::Any)), + (Type::list(Type::Any), Type::list(Type::Any)), + (Type::table(), Type::list(Type::Any)), + (Type::Range, Type::list(Type::Any)), + ]) .required( "closure", - SyntaxShape::Closure(Some(vec![SyntaxShape::Any])), + SyntaxShape::Closure(Some(vec![SyntaxShape::Any, SyntaxShape::Any])), "Generator function.", ) .optional("initial", SyntaxShape::Any, "Initial value.") @@ -31,11 +36,15 @@ impl Command for Generate { containing two optional keys: 'out' and 'next'. Each invocation, the 'out' value, if present, is added to the stream. If a 'next' key is present, it is used as the next argument to the closure, otherwise generation stops. -"# + +Additionally, if an input stream is provided, the generator closure accepts two +arguments. On each invocation an element of the input stream is provided as the +first argument. The second argument is the `next` value from the last invocation. +In this case, generation also stops when the input stream stops."# } fn search_terms(&self) -> Vec<&str> { - vec!["unfold", "stream", "yield", "expand"] + vec!["unfold", "stream", "yield", "expand", "state", "scan"] } fn examples(&self) -> Vec { @@ -68,6 +77,18 @@ used as the next argument to the closure, otherwise generation stops. "Generate a continuous stream of Fibonacci numbers, using default parameters", result: None, }, + Example { + example: + "1..5 | generate {|e, sum=0| let sum = $e + $sum; {out: $sum, next: $sum} }", + description: "Generate a running sum of the inputs", + result: Some(Value::test_list(vec![ + Value::test_int(1), + Value::test_int(3), + Value::test_int(6), + Value::test_int(10), + Value::test_int(15), + ])), + }, ] } @@ -76,7 +97,7 @@ used as the next argument to the closure, otherwise generation stops. engine_state: &EngineState, stack: &mut Stack, call: &Call, - _input: PipelineData, + input: PipelineData, ) -> Result { let head = call.head; let closure: Closure = call.req(engine_state, stack, 0)?; @@ -84,96 +105,55 @@ used as the next argument to the closure, otherwise generation stops. let block = engine_state.get_block(closure.block_id); let mut closure = ClosureEval::new(engine_state, stack, closure); - // A type of Option is used to represent state. Invocation - // will stop on None. Using Option allows functions to output - // one final value before stopping. - let mut state = Some(get_initial_state(initial, &block.signature, call.head)?); - let iter = std::iter::from_fn(move || { - let arg = state.take()?; + match input { + PipelineData::Empty => { + // A type of Option is used to represent state. Invocation + // will stop on None. Using Option allows functions to output + // one final value before stopping. + let mut state = Some(get_initial_state(initial, &block.signature, call.head)?); + let iter = std::iter::from_fn(move || { + let state_arg = state.take()?; - let (output, next_input) = match closure.run_with_value(arg) { - // no data -> output nothing and stop. - Ok(PipelineData::Empty) => (None, None), + let closure_result = closure + .add_arg(state_arg) + .run_with_input(PipelineData::Empty); + let (output, next_input) = parse_closure_result(closure_result, head); - Ok(PipelineData::Value(value, ..)) => { - let span = value.span(); - match value { - // {out: ..., next: ...} -> output and continue - Value::Record { val, .. } => { - let iter = val.into_owned().into_iter(); - let mut out = None; - let mut next = None; - let mut err = None; + // We use `state` to control when to stop, not `output`. By wrapping + // it in a `Some`, we allow the generator to output `None` as a valid output + // value. + state = next_input; + Some(output) + }); - for (k, v) in iter { - if k.eq_ignore_ascii_case("out") { - out = Some(v); - } else if k.eq_ignore_ascii_case("next") { - next = Some(v); - } else { - let error = ShellError::GenericError { - error: "Invalid block return".into(), - msg: format!("Unexpected record key '{}'", k), - span: Some(span), - help: None, - inner: vec![], - }; - err = Some(Value::error(error, head)); - break; - } - } - - if err.is_some() { - (err, None) - } else { - (out, next) - } - } - - // some other value -> error and stop - _ => { - let error = ShellError::GenericError { - error: "Invalid block return".into(), - msg: format!("Expected record, found {}", value.get_type()), - span: Some(span), - help: None, - inner: vec![], - }; - - (Some(Value::error(error, head)), None) - } - } - } - - Ok(other) => { - let error = other - .into_value(head) - .map(|val| ShellError::GenericError { - error: "Invalid block return".into(), - msg: format!("Expected record, found {}", val.get_type()), - span: Some(val.span()), - help: None, - inner: vec![], - }) - .unwrap_or_else(|err| err); - - (Some(Value::error(error, head)), None) - } - - // error -> error and stop - Err(error) => (Some(Value::error(error, head)), None), - }; - - // We use `state` to control when to stop, not `output`. By wrapping - // it in a `Some`, we allow the generator to output `None` as a valid output - // value. - state = next_input; - Some(output) - }); - - Ok(iter - .flatten() - .into_pipeline_data(call.head, engine_state.signals().clone())) + Ok(iter + .flatten() + .into_pipeline_data(call.head, engine_state.signals().clone())) + } + input @ (PipelineData::Value(Value::Range { .. }, ..) + | PipelineData::Value(Value::List { .. }, ..) + | PipelineData::ListStream(..)) => { + let mut state = Some(get_initial_state(initial, &block.signature, call.head)?); + let iter = input.into_iter().map_while(move |item| { + let state_arg = state.take()?; + let closure_result = closure + .add_arg(item) + .add_arg(state_arg) + .run_with_input(PipelineData::Empty); + let (output, next_input) = parse_closure_result(closure_result, head); + state = next_input; + Some(output) + }); + Ok(iter + .flatten() + .into_pipeline_data(call.head, engine_state.signals().clone())) + } + _ => Err(ShellError::PipelineMismatch { + exp_input_type: "nothing".to_string(), + dst_span: head, + src_span: input.span().unwrap_or(head), + }), + } } } @@ -209,6 +189,84 @@ fn get_initial_state( } } +fn parse_closure_result( + closure_result: Result, + head: Span, +) -> (Option, Option) { + match closure_result { + // no data -> output nothing and stop. + Ok(PipelineData::Empty) => (None, None), + + Ok(PipelineData::Value(value, ..)) => { + let span = value.span(); + match value { + // {out: ..., next: ...} -> output and continue + Value::Record { val, .. } => { + let iter = val.into_owned().into_iter(); + let mut out = None; + let mut next = None; + let mut err = None; + + for (k, v) in iter { + if k.eq_ignore_ascii_case("out") { + out = Some(v); + } else if k.eq_ignore_ascii_case("next") { + next = Some(v); + } else { + let error = ShellError::GenericError { + error: "Invalid block return".into(), + msg: format!("Unexpected record key '{}'", k), + span: Some(span), + help: None, + inner: vec![], + }; + err = Some(Value::error(error, head)); + break; + } + } + + if err.is_some() { + (err, None) + } else { + (out, next) + } + } + + // some other value -> error and stop + _ => { + let error = ShellError::GenericError { + error: "Invalid block return".into(), + msg: format!("Expected record, found {}", value.get_type()), + span: Some(span), + help: None, + inner: vec![], + }; + + (Some(Value::error(error, head)), None) + } + } + } + + Ok(other) => { + let error = other + .into_value(head) + .map(|val| ShellError::GenericError { + error: "Invalid block return".into(), + msg: format!("Expected record, found {}", val.get_type()), + span: Some(val.span()), + help: None, + inner: vec![], + }) + .unwrap_or_else(|err| err); + + (Some(Value::error(error, head)), None) + } + + // error -> error and stop + Err(error) => (Some(Value::error(error, head)), None), + } +} + #[cfg(test)] mod test { use super::*; diff --git a/crates/nu-command/tests/commands/generate.rs b/crates/nu-command/tests/commands/generate.rs index 9ae3b8823e..866d587dd6 100644 --- a/crates/nu-command/tests/commands/generate.rs +++ b/crates/nu-command/tests/commands/generate.rs @@ -148,3 +148,25 @@ fn generate_raise_error_on_no_default_parameter_closure_and_init_val() { )); assert!(actual.err.contains("The initial value is missing")); } + +#[test] +fn generate_allows_pipeline_input() { + let actual = nu!(r#"[1 2 3] | generate {|e, x=null| {out: $e, next: null}} | to nuon"#); + assert_eq!(actual.out, "[1, 2, 3]"); +} + +#[test] +fn generate_with_input_is_streaming() { + let actual = nu!(pipeline( + r#" + 1..10 + | each {|x| print -en $x; $x} + | generate {|e, sum=0| let sum = $e + $sum; {out: $sum, next: $sum}} + | first 5 + | to nuon + "# + )); + + assert_eq!(actual.out, "[1, 3, 6, 10, 15]"); + assert_eq!(actual.err, "12345"); +} diff --git a/crates/nu-std/std/iter/mod.nu b/crates/nu-std/std/iter/mod.nu index ca0aaafa4d..df5e03bd7b 100644 --- a/crates/nu-std/std/iter/mod.nu +++ b/crates/nu-std/std/iter/mod.nu @@ -107,16 +107,12 @@ export def scan [ # -> list init: any # initial value to seed the initial state fn: closure # the closure to perform the scan --noinit(-n) # remove the initial value from the result -] { - reduce --fold [$init] {|e, acc| - let acc_last = $acc | last - $acc ++ [($acc_last | do $fn $e $acc_last)] - } - | if $noinit { - $in | skip - } else { - $in - } +] { + generate {|e, acc| + let out = $acc | do $fn $e $acc + {next: $out, out: $out} + } $init + | if not $noinit { prepend $init } else { } } # Returns a list of values for which the supplied closure does not