From 186da4d7251287c8c3981dfea817e9dc2cb19d59 Mon Sep 17 00:00:00 2001 From: JT <547158+jntrnr@users.noreply.github.com> Date: Wed, 12 Jan 2022 15:06:56 +1100 Subject: [PATCH] Fixing captures (#723) * WIP fixing captures * small fix * WIP * Rewrite to proof-of-concept better parse_def * Add missing file * Finish capture refactor * Fix tests * Add more tests --- crates/nu-command/src/core_commands/do_.rs | 14 +- crates/nu-command/src/core_commands/for_.rs | 12 +- crates/nu-command/src/core_commands/if_.rs | 20 ++- crates/nu-command/src/env/with_env.rs | 12 +- crates/nu-command/src/filters/all.rs | 13 +- crates/nu-command/src/filters/any.rs | 12 +- crates/nu-command/src/filters/collect.rs | 12 +- crates/nu-command/src/filters/each.rs | 20 +-- crates/nu-command/src/filters/keep/until.rs | 13 +- crates/nu-command/src/filters/keep/while_.rs | 13 +- crates/nu-command/src/filters/par_each.rs | 12 +- crates/nu-command/src/filters/reduce.rs | 15 +- crates/nu-command/src/filters/skip/until.rs | 13 +- crates/nu-command/src/filters/skip/while_.rs | 13 +- crates/nu-command/src/filters/update.rs | 13 +- crates/nu-command/src/filters/where_.rs | 16 +- crates/nu-command/src/system/benchmark.rs | 12 +- crates/nu-engine/src/env.rs | 4 +- crates/nu-engine/src/eval.rs | 25 ++- crates/nu-parser/src/flatten.rs | 5 +- crates/nu-parser/src/parse_keywords.rs | 162 +++++++++++++++++- crates/nu-parser/src/parser.rs | 43 ++++- crates/nu-protocol/src/ast/expression.rs | 4 + .../nu-protocol/src/engine/capture_block.rs | 9 + crates/nu-protocol/src/engine/engine_state.rs | 1 + crates/nu-protocol/src/engine/mod.rs | 2 + crates/nu-protocol/src/engine/stack.rs | 24 ++- crates/nu-protocol/src/value/from_value.rs | 40 +++++ crates/nu-protocol/src/value/mod.rs | 10 +- src/tests/test_engine.rs | 24 +++ 30 files changed, 424 insertions(+), 164 deletions(-) create mode 100644 crates/nu-protocol/src/engine/capture_block.rs diff --git a/crates/nu-command/src/core_commands/do_.rs b/crates/nu-command/src/core_commands/do_.rs index 8c0e77d52f..8d2a87e329 100644 --- a/crates/nu-command/src/core_commands/do_.rs +++ b/crates/nu-command/src/core_commands/do_.rs @@ -1,6 +1,6 @@ use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{Category, PipelineData, Signature, SyntaxShape, Value}; #[derive(Clone)] @@ -39,16 +39,12 @@ impl Command for Do { call: &Call, input: PipelineData, ) -> Result { - let block: Value = call.req(engine_state, stack, 0)?; - let block_id = block.as_block()?; - + let block: CaptureBlock = call.req(engine_state, stack, 0)?; + let rest: Vec = call.rest(engine_state, stack, 1)?; let ignore_errors = call.has_flag("ignore-errors"); - let rest: Vec = call.rest(engine_state, stack, 1)?; - - let block = engine_state.get_block(block_id); - - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&block.captures); + let block = engine_state.get_block(block.block_id); let params: Vec<_> = block .signature diff --git a/crates/nu-command/src/core_commands/for_.rs b/crates/nu-command/src/core_commands/for_.rs index 2e71718210..620a374d26 100644 --- a/crates/nu-command/src/core_commands/for_.rs +++ b/crates/nu-command/src/core_commands/for_.rs @@ -1,6 +1,6 @@ -use nu_engine::{eval_block, eval_expression}; +use nu_engine::{eval_block, eval_expression, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{ Category, Example, IntoInterruptiblePipelineData, PipelineData, Signature, Span, SyntaxShape, Value, @@ -61,16 +61,14 @@ impl Command for For { .expect("internal error: missing keyword"); let values = eval_expression(engine_state, stack, keyword_expr)?; - let block_id = call.positional[2] - .as_block() - .expect("internal error: expected block"); + let capture_block: CaptureBlock = call.req(engine_state, stack, 2)?; let numbered = call.has_flag("numbered"); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); - let block = engine_state.get_block(block_id).clone(); - let mut stack = stack.collect_captures(&block.captures); + let block = engine_state.get_block(capture_block.block_id).clone(); + let mut stack = stack.captures_to_stack(&capture_block.captures); let orig_env_vars = stack.env_vars.clone(); let orig_env_hidden = stack.env_hidden.clone(); diff --git a/crates/nu-command/src/core_commands/if_.rs b/crates/nu-command/src/core_commands/if_.rs index df91d7537e..2f607eae66 100644 --- a/crates/nu-command/src/core_commands/if_.rs +++ b/crates/nu-command/src/core_commands/if_.rs @@ -1,8 +1,9 @@ -use nu_engine::{eval_block, eval_expression}; +use nu_engine::{eval_block, eval_expression, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{ - Category, Example, IntoPipelineData, PipelineData, ShellError, Signature, SyntaxShape, Value, + Category, Example, FromValue, IntoPipelineData, PipelineData, ShellError, Signature, + SyntaxShape, Value, }; #[derive(Clone)] @@ -41,23 +42,24 @@ impl Command for If { input: PipelineData, ) -> Result { let cond = &call.positional[0]; - let then_block = call.positional[1] - .as_block() - .expect("internal error: expected block"); + let then_block: CaptureBlock = call.req(engine_state, stack, 1)?; let else_case = call.positional.get(2); let result = eval_expression(engine_state, stack, cond)?; match &result { Value::Bool { val, .. } => { if *val { - let block = engine_state.get_block(then_block); - let mut stack = stack.collect_captures(&block.captures); + let block = engine_state.get_block(then_block.block_id); + let mut stack = stack.captures_to_stack(&then_block.captures); eval_block(engine_state, &mut stack, block, input) } else if let Some(else_case) = else_case { if let Some(else_expr) = else_case.as_keyword() { if let Some(block_id) = else_expr.as_block() { + let result = eval_expression(engine_state, stack, else_expr)?; + let else_block: CaptureBlock = FromValue::from_value(&result)?; + + let mut stack = stack.captures_to_stack(&else_block.captures); let block = engine_state.get_block(block_id); - let mut stack = stack.collect_captures(&block.captures); eval_block(engine_state, &mut stack, block, input) } else { eval_expression(engine_state, stack, else_expr) diff --git a/crates/nu-command/src/env/with_env.rs b/crates/nu-command/src/env/with_env.rs index 2ad0c877e2..265796fc18 100644 --- a/crates/nu-command/src/env/with_env.rs +++ b/crates/nu-command/src/env/with_env.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, PipelineData, ShellError, Signature, SyntaxShape, Value, }; @@ -79,11 +79,9 @@ fn with_env( // let external_redirection = args.call_info.args.external_redirection; let variable: Value = call.req(engine_state, stack, 0)?; - let block_id = call.positional[1] - .as_block() - .expect("internal error: expected block"); - let block = engine_state.get_block(block_id).clone(); - let mut stack = stack.collect_captures(&block.captures); + let capture_block: CaptureBlock = call.req(engine_state, stack, 1)?; + let block = engine_state.get_block(capture_block.block_id); + let mut stack = stack.captures_to_stack(&capture_block.captures); let mut env: HashMap = HashMap::new(); @@ -134,7 +132,7 @@ fn with_env( stack.add_env_var(k, v); } - eval_block(engine_state, &mut stack, &block, input) + eval_block(engine_state, &mut stack, block, input) } #[cfg(test)] diff --git a/crates/nu-command/src/filters/all.rs b/crates/nu-command/src/filters/all.rs index 6d6321d13d..cdc076262b 100644 --- a/crates/nu-command/src/filters/all.rs +++ b/crates/nu-command/src/filters/all.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, IntoPipelineData, PipelineData, ShellError, Signature, SyntaxShape, Value, }; @@ -49,16 +49,15 @@ impl Command for All { call: &Call, input: PipelineData, ) -> Result { - let predicate = &call.positional[0]; + // let predicate = &call.positional[0]; let span = call.head; - let block_id = predicate - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; + let block_id = capture_block.block_id; let block = engine_state.get_block(block_id); let var_id = block.signature.get_positional(0).and_then(|arg| arg.var_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); diff --git a/crates/nu-command/src/filters/any.rs b/crates/nu-command/src/filters/any.rs index 869664d203..b66507f5c5 100644 --- a/crates/nu-command/src/filters/any.rs +++ b/crates/nu-command/src/filters/any.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, IntoPipelineData, PipelineData, ShellError, Signature, SyntaxShape, Value, }; @@ -49,16 +49,14 @@ impl Command for Any { call: &Call, input: PipelineData, ) -> Result { - let predicate = &call.positional[0]; let span = call.head; - let block_id = predicate - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; + let block_id = capture_block.block_id; let block = engine_state.get_block(block_id); let var_id = block.signature.get_positional(0).and_then(|arg| arg.var_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); diff --git a/crates/nu-command/src/filters/collect.rs b/crates/nu-command/src/filters/collect.rs index b7dd98125c..c9cc64dcb6 100644 --- a/crates/nu-command/src/filters/collect.rs +++ b/crates/nu-command/src/filters/collect.rs @@ -1,6 +1,6 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{Category, Example, PipelineData, Signature, SyntaxShape, Value}; #[derive(Clone)] @@ -32,12 +32,10 @@ impl Command for Collect { call: &Call, input: PipelineData, ) -> Result { - let block_id = call.positional[0] - .as_block() - .expect("internal error: expected block"); + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; - let block = engine_state.get_block(block_id).clone(); - let mut stack = stack.collect_captures(&block.captures); + let block = engine_state.get_block(capture_block.block_id).clone(); + let mut stack = stack.captures_to_stack(&capture_block.captures); let input: Value = input.into_value(call.head); diff --git a/crates/nu-command/src/filters/each.rs b/crates/nu-command/src/filters/each.rs index 84ca535961..d67e939904 100644 --- a/crates/nu-command/src/filters/each.rs +++ b/crates/nu-command/src/filters/each.rs @@ -1,6 +1,6 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{ Category, Example, IntoInterruptiblePipelineData, IntoPipelineData, PipelineData, Signature, Span, SyntaxShape, Value, @@ -62,15 +62,13 @@ impl Command for Each { call: &Call, input: PipelineData, ) -> Result { - let block_id = call.positional[0] - .as_block() - .expect("internal error: expected block"); + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; let numbered = call.has_flag("numbered"); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); - let block = engine_state.get_block(block_id).clone(); - let mut stack = stack.collect_captures(&block.captures); + let block = engine_state.get_block(capture_block.block_id).clone(); + let mut stack = stack.captures_to_stack(&capture_block.captures); let orig_env_vars = stack.env_vars.clone(); let orig_env_hidden = stack.env_hidden.clone(); let span = call.head; @@ -198,7 +196,7 @@ impl Command for Each { let mut output_vals = vec![]; for (col, val) in cols.into_iter().zip(vals.into_iter()) { - let block = engine_state.get_block(block_id); + //let block = engine_state.get_block(block_id); stack.with_env(&orig_env_vars, &orig_env_hidden); @@ -221,7 +219,7 @@ impl Command for Each { } } - match eval_block(&engine_state, &mut stack, block, PipelineData::new(span))? { + match eval_block(&engine_state, &mut stack, &block, PipelineData::new(span))? { PipelineData::Value( Value::Record { mut cols, mut vals, .. @@ -247,7 +245,7 @@ impl Command for Each { .into_pipeline_data()) } PipelineData::Value(x, ..) => { - let block = engine_state.get_block(block_id); + //let block = engine_state.get_block(block_id); if let Some(var) = block.signature.get_positional(0) { if let Some(var_id) = &var.var_id { @@ -255,7 +253,7 @@ impl Command for Each { } } - eval_block(&engine_state, &mut stack, block, PipelineData::new(span)) + eval_block(&engine_state, &mut stack, &block, PipelineData::new(span)) } } } diff --git a/crates/nu-command/src/filters/keep/until.rs b/crates/nu-command/src/filters/keep/until.rs index 513ffea2e0..0e1921f0a8 100644 --- a/crates/nu-command/src/filters/keep/until.rs +++ b/crates/nu-command/src/filters/keep/until.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, IntoInterruptiblePipelineData, PipelineData, ShellError, Signature, Span, SyntaxShape, Value, }; @@ -48,15 +48,12 @@ impl Command for KeepUntil { ) -> Result { let span = call.head; - let predicate = &call.positional[0]; - let block_id = predicate - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; - let block = engine_state.get_block(block_id).clone(); + let block = engine_state.get_block(capture_block.block_id).clone(); let var_id = block.signature.get_positional(0).and_then(|arg| arg.var_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); diff --git a/crates/nu-command/src/filters/keep/while_.rs b/crates/nu-command/src/filters/keep/while_.rs index 991eb2ebdc..b8d036dd4f 100644 --- a/crates/nu-command/src/filters/keep/while_.rs +++ b/crates/nu-command/src/filters/keep/while_.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, IntoInterruptiblePipelineData, PipelineData, ShellError, Signature, Span, SyntaxShape, Value, }; @@ -48,15 +48,12 @@ impl Command for KeepWhile { ) -> Result { let span = call.head; - let predicate = &call.positional[0]; - let block_id = predicate - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; - let block = engine_state.get_block(block_id).clone(); + let block = engine_state.get_block(capture_block.block_id).clone(); let var_id = block.signature.get_positional(0).and_then(|arg| arg.var_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); diff --git a/crates/nu-command/src/filters/par_each.rs b/crates/nu-command/src/filters/par_each.rs index 0344a6e937..a3f0b11ec8 100644 --- a/crates/nu-command/src/filters/par_each.rs +++ b/crates/nu-command/src/filters/par_each.rs @@ -1,6 +1,6 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{ Category, Example, IntoInterruptiblePipelineData, IntoPipelineData, PipelineData, Signature, SyntaxShape, Value, @@ -45,15 +45,13 @@ impl Command for ParEach { call: &Call, input: PipelineData, ) -> Result { - let block_id = call.positional[0] - .as_block() - .expect("internal error: expected block"); + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; let numbered = call.has_flag("numbered"); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); - let block = engine_state.get_block(block_id); - let mut stack = stack.collect_captures(&block.captures); + let block_id = capture_block.block_id; + let mut stack = stack.captures_to_stack(&capture_block.captures); let span = call.head; match input { diff --git a/crates/nu-command/src/filters/reduce.rs b/crates/nu-command/src/filters/reduce.rs index 6044e5829b..e7a0c70c87 100644 --- a/crates/nu-command/src/filters/reduce.rs +++ b/crates/nu-command/src/filters/reduce.rs @@ -1,7 +1,7 @@ use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{ Example, IntoPipelineData, PipelineData, ShellError, Signature, Span, SyntaxShape, Value, }; @@ -102,17 +102,10 @@ impl Command for Reduce { let fold: Option = call.get_flag(engine_state, stack, "fold")?; let numbered = call.has_flag("numbered"); - let block = if let Some(block_id) = call.nth(0).and_then(|b| b.as_block()) { - engine_state.get_block(block_id) - } else { - return Err(ShellError::SpannedLabeledError( - "Internal Error".to_string(), - "expected block".to_string(), - span, - )); - }; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; + let mut stack = stack.captures_to_stack(&capture_block.captures); + let block = engine_state.get_block(capture_block.block_id); - let mut stack = stack.collect_captures(&block.captures); let orig_env_vars = stack.env_vars.clone(); let orig_env_hidden = stack.env_hidden.clone(); diff --git a/crates/nu-command/src/filters/skip/until.rs b/crates/nu-command/src/filters/skip/until.rs index 1a65aa3b58..0f21b40197 100644 --- a/crates/nu-command/src/filters/skip/until.rs +++ b/crates/nu-command/src/filters/skip/until.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, IntoInterruptiblePipelineData, PipelineData, ShellError, Signature, Span, SyntaxShape, Value, }; @@ -46,16 +46,13 @@ impl Command for SkipUntil { call: &Call, input: PipelineData, ) -> Result { - let predicate = &call.positional[0]; let span = call.head; - let block_id = predicate - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; - let block = engine_state.get_block(block_id).clone(); + let block = engine_state.get_block(capture_block.block_id).clone(); let var_id = block.signature.get_positional(0).and_then(|arg| arg.var_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); diff --git a/crates/nu-command/src/filters/skip/while_.rs b/crates/nu-command/src/filters/skip/while_.rs index 55bba5caa2..2c0b49cf79 100644 --- a/crates/nu-command/src/filters/skip/while_.rs +++ b/crates/nu-command/src/filters/skip/while_.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::{ ast::Call, - engine::{Command, EngineState, Stack}, + engine::{CaptureBlock, Command, EngineState, Stack}, Category, Example, IntoInterruptiblePipelineData, PipelineData, ShellError, Signature, Span, SyntaxShape, Value, }; @@ -46,16 +46,13 @@ impl Command for SkipWhile { call: &Call, input: PipelineData, ) -> Result { - let predicate = &call.positional[0]; let span = call.head; - let block_id = predicate - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; - let block = engine_state.get_block(block_id).clone(); + let block = engine_state.get_block(capture_block.block_id).clone(); let var_id = block.signature.get_positional(0).and_then(|arg| arg.var_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); diff --git a/crates/nu-command/src/filters/update.rs b/crates/nu-command/src/filters/update.rs index 2ee9cc9f65..9b01be148a 100644 --- a/crates/nu-command/src/filters/update.rs +++ b/crates/nu-command/src/filters/update.rs @@ -1,9 +1,9 @@ use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::{Call, CellPath}; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{ - Category, Example, IntoPipelineData, PipelineData, ShellError, Signature, Span, SyntaxShape, - Value, + Category, Example, FromValue, IntoPipelineData, PipelineData, ShellError, Signature, Span, + SyntaxShape, Value, }; #[derive(Clone)] @@ -70,10 +70,11 @@ fn update( let ctrlc = engine_state.ctrlc.clone(); // Replace is a block, so set it up and run it instead of using it as the replacement - if let Ok(block_id) = replacement.as_block() { - let block = engine_state.get_block(block_id).clone(); + if replacement.as_block().is_ok() { + let capture_block: CaptureBlock = FromValue::from_value(&replacement)?; + let block = engine_state.get_block(capture_block.block_id).clone(); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let orig_env_vars = stack.env_vars.clone(); let orig_env_hidden = stack.env_hidden.clone(); diff --git a/crates/nu-command/src/filters/where_.rs b/crates/nu-command/src/filters/where_.rs index f1f5347da5..b10047c429 100644 --- a/crates/nu-command/src/filters/where_.rs +++ b/crates/nu-command/src/filters/where_.rs @@ -1,7 +1,7 @@ -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; -use nu_protocol::{Category, PipelineData, ShellError, Signature, SyntaxShape}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; +use nu_protocol::{Category, PipelineData, Signature, SyntaxShape}; #[derive(Clone)] pub struct Where; @@ -28,21 +28,17 @@ impl Command for Where { call: &Call, input: PipelineData, ) -> Result { - let cond = &call.positional[0]; let span = call.head; let metadata = input.metadata(); - let block_id = cond - .as_row_condition_block() - .ok_or_else(|| ShellError::TypeMismatch("expected row condition".to_owned(), span))?; + let block: CaptureBlock = call.req(engine_state, stack, 0)?; + let mut stack = stack.captures_to_stack(&block.captures); + let block = engine_state.get_block(block.block_id).clone(); let ctrlc = engine_state.ctrlc.clone(); let engine_state = engine_state.clone(); - let block = engine_state.get_block(block_id).clone(); - let mut stack = stack.collect_captures(&block.captures); - input .filter( move |value| { diff --git a/crates/nu-command/src/system/benchmark.rs b/crates/nu-command/src/system/benchmark.rs index 0b2f0b13e1..bb42f2522e 100644 --- a/crates/nu-command/src/system/benchmark.rs +++ b/crates/nu-command/src/system/benchmark.rs @@ -1,8 +1,8 @@ use std::time::Instant; -use nu_engine::eval_block; +use nu_engine::{eval_block, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::engine::{CaptureBlock, Command, EngineState, Stack}; use nu_protocol::{Category, IntoPipelineData, PipelineData, Signature, SyntaxShape, Value}; #[derive(Clone)] @@ -34,12 +34,10 @@ impl Command for Benchmark { call: &Call, _input: PipelineData, ) -> Result { - let block = call.positional[0] - .as_block() - .expect("internal error: expected block"); - let block = engine_state.get_block(block); + let capture_block: CaptureBlock = call.req(engine_state, stack, 0)?; + let block = engine_state.get_block(capture_block.block_id); - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.captures_to_stack(&capture_block.captures); let start_time = Instant::now(); eval_block( engine_state, diff --git a/crates/nu-engine/src/env.rs b/crates/nu-engine/src/env.rs index c75cba07da..39faad40ed 100644 --- a/crates/nu-engine/src/env.rs +++ b/crates/nu-engine/src/env.rs @@ -40,7 +40,7 @@ pub fn convert_env_values( let block = engine_state.get_block(block_id); if let Some(var) = block.signature.get_positional(0) { - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.gather_captures(&block.captures); if let Some(var_id) = &var.var_id { stack.add_var(*var_id, val.clone()); } @@ -92,7 +92,7 @@ pub fn env_to_string( if let Some(var) = block.signature.get_positional(0) { let val_span = value.span()?; - let mut stack = stack.collect_captures(&block.captures); + let mut stack = stack.gather_captures(&block.captures); if let Some(var_id) = &var.var_id { stack.add_var(*var_id, value); diff --git a/crates/nu-engine/src/eval.rs b/crates/nu-engine/src/eval.rs index ac81b21cc2..c57d45eb1b 100644 --- a/crates/nu-engine/src/eval.rs +++ b/crates/nu-engine/src/eval.rs @@ -1,4 +1,5 @@ use std::cmp::Ordering; +use std::collections::HashMap; use std::io::Write; use nu_protocol::ast::{Block, Call, Expr, Expression, Operator, Statement}; @@ -40,7 +41,7 @@ fn eval_call( } else if let Some(block_id) = decl.get_block_id() { let block = engine_state.get_block(block_id); - let mut callee_stack = caller_stack.collect_captures(&block.captures); + let mut callee_stack = caller_stack.gather_captures(&block.captures); for (param_idx, param) in decl .signature() @@ -286,7 +287,7 @@ pub fn eval_expression( Operator::Pow => lhs.pow(op_span, &rhs), } } - Expr::RowCondition(block_id) | Expr::Subexpression(block_id) => { + Expr::Subexpression(block_id) => { let block = engine_state.get_block(*block_id); // FIXME: protect this collect with ctrl-c @@ -295,10 +296,22 @@ pub fn eval_expression( .into_value(expr.span), ) } - Expr::Block(block_id) => Ok(Value::Block { - val: *block_id, - span: expr.span, - }), + Expr::RowCondition(block_id) | Expr::Block(block_id) => { + let mut captures = HashMap::new(); + let block = engine_state.get_block(*block_id); + + for var_id in &block.captures { + captures.insert( + *var_id, + stack.get_var(*var_id)?, //.map_err(|_| ShellError::VariableNotFoundAtRuntime(expr.span))?, + ); + } + Ok(Value::Block { + val: *block_id, + captures, + span: expr.span, + }) + } Expr::List(x) => { let mut output = vec![]; for expr in x { diff --git a/crates/nu-parser/src/flatten.rs b/crates/nu-parser/src/flatten.rs index 8b61f6cc71..1f10f4d1d9 100644 --- a/crates/nu-parser/src/flatten.rs +++ b/crates/nu-parser/src/flatten.rs @@ -95,7 +95,7 @@ pub fn flatten_expression( output.extend(flatten_expression(working_set, rhs)); output } - Expr::Block(block_id) => { + Expr::Block(block_id) | Expr::RowCondition(block_id) | Expr::Subexpression(block_id) => { let outer_span = expr.span; let mut output = vec![]; @@ -385,9 +385,6 @@ pub fn flatten_expression( Expr::String(_) => { vec![(expr.span, FlatShape::String)] } - Expr::RowCondition(block_id) | Expr::Subexpression(block_id) => { - flatten_block(working_set, working_set.get_block(*block_id)) - } Expr::Table(headers, cells) => { let outer_span = expr.span; let mut last_end = outer_span.start; diff --git a/crates/nu-parser/src/parse_keywords.rs b/crates/nu-parser/src/parse_keywords.rs index 0147d898b3..b170c8438e 100644 --- a/crates/nu-parser/src/parse_keywords.rs +++ b/crates/nu-parser/src/parse_keywords.rs @@ -5,16 +5,16 @@ use nu_protocol::{ Pipeline, Statement, }, engine::StateWorkingSet, - span, Exportable, Overlay, Span, SyntaxShape, Type, CONFIG_VARIABLE_ID, + span, Exportable, Overlay, PositionalArg, Span, SyntaxShape, Type, CONFIG_VARIABLE_ID, }; use std::collections::{HashMap, HashSet}; use crate::{ lex, lite_parse, parser::{ - check_call, check_name, garbage, garbage_statement, parse, parse_block_expression, - parse_internal_call, parse_multispan_value, parse_signature, parse_string, - parse_var_with_opt_type, trim_quotes, + check_call, check_name, find_captures_in_block, garbage, garbage_statement, parse, + parse_block_expression, parse_internal_call, parse_multispan_value, parse_signature, + parse_string, parse_var_with_opt_type, trim_quotes, }, ParseError, }; @@ -57,6 +57,121 @@ pub fn parse_def_predecl(working_set: &mut StateWorkingSet, spans: &[Span]) -> O None } +pub fn parse_for( + working_set: &mut StateWorkingSet, + spans: &[Span], +) -> (Statement, Option) { + // Checking that the function is used with the correct name + // Maybe this is not necessary but it is a sanity check + if working_set.get_span_contents(spans[0]) != b"for" { + return ( + garbage_statement(spans), + Some(ParseError::UnknownState( + "internal error: Wrong call name for 'for' function".into(), + span(spans), + )), + ); + } + + // Parsing the spans and checking that they match the register signature + // Using a parsed call makes more sense than checking for how many spans are in the call + // Also, by creating a call, it can be checked if it matches the declaration signature + let (call, call_span) = match working_set.find_decl(b"for") { + None => { + return ( + garbage_statement(spans), + Some(ParseError::UnknownState( + "internal error: def declaration not found".into(), + span(spans), + )), + ) + } + Some(decl_id) => { + working_set.enter_scope(); + let (call, mut err) = parse_internal_call(working_set, spans[0], &spans[1..], decl_id); + working_set.exit_scope(); + + let call_span = span(spans); + let decl = working_set.get_decl(decl_id); + let sig = decl.signature(); + + // Let's get our block and make sure it has the right signature + if let Some(arg) = call.positional.get(2) { + match arg { + Expression { + expr: Expr::Block(block_id), + .. + } + | Expression { + expr: Expr::RowCondition(block_id), + .. + } => { + let block = working_set.get_block_mut(*block_id); + + block.signature = Box::new(sig.clone()); + } + _ => {} + } + } + + err = check_call(call_span, &sig, &call).or(err); + if err.is_some() || call.has_flag("help") { + return ( + Statement::Pipeline(Pipeline::from_vec(vec![Expression { + expr: Expr::Call(call), + span: call_span, + ty: Type::Unknown, + custom_completion: None, + }])), + err, + ); + } + + (call, call_span) + } + }; + + // All positional arguments must be in the call positional vector by this point + let var_decl = call.positional.get(0).expect("for call already checked"); + let block = call.positional.get(2).expect("for call already checked"); + + let error = None; + if let (Some(var_id), Some(block_id)) = (&var_decl.as_var(), block.as_block()) { + let block = working_set.get_block_mut(block_id); + + block.signature.required_positional.insert( + 0, + PositionalArg { + name: String::new(), + desc: String::new(), + shape: SyntaxShape::Any, + var_id: Some(*var_id), + }, + ); + + let block = working_set.get_block(block_id); + + // Now that we have a signature for the block, we know more about what variables + // will come into scope as params. Because of this, we need to recalculated what + // variables this block will capture from the outside. + let mut seen = vec![]; + let captures = find_captures_in_block(working_set, block, &mut seen); + + let mut block = working_set.get_block_mut(block_id); + block.captures = captures; + } + + ( + Statement::Pipeline(Pipeline::from_vec(vec![Expression { + expr: Expr::Call(call), + span: call_span, + ty: Type::Unknown, + custom_completion: None, + }])), + error, + ) +} + pub fn parse_def( working_set: &mut StateWorkingSet, spans: &[Span], @@ -93,8 +208,28 @@ pub fn parse_def( let call_span = span(spans); let decl = working_set.get_decl(decl_id); + let sig = decl.signature(); - err = check_call(call_span, &decl.signature(), &call).or(err); + // Let's get our block and make sure it has the right signature + if let Some(arg) = call.positional.get(2) { + match arg { + Expression { + expr: Expr::Block(block_id), + .. + } + | Expression { + expr: Expr::RowCondition(block_id), + .. + } => { + let block = working_set.get_block_mut(*block_id); + + block.signature = Box::new(sig.clone()); + } + _ => {} + } + } + + err = check_call(call_span, &sig, &call).or(err); if err.is_some() || call.has_flag("help") { return ( Statement::Pipeline(Pipeline::from_vec(vec![Expression { @@ -124,7 +259,22 @@ pub fn parse_def( let declaration = working_set.get_decl_mut(decl_id); signature.name = name.clone(); - *declaration = signature.into_block_command(block_id); + + *declaration = signature.clone().into_block_command(block_id); + + let mut block = working_set.get_block_mut(block_id); + block.signature = signature; + + let block = working_set.get_block(block_id); + + // Now that we have a signature for the block, we know more about what variables + // will come into scope as params. Because of this, we need to recalculated what + // variables this block will capture from the outside. + let mut seen = vec![]; + let captures = find_captures_in_block(working_set, block, &mut seen); + + let mut block = working_set.get_block_mut(block_id); + block.captures = captures; } else { error = error.or_else(|| { Some(ParseError::InternalError( diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index 72dd36f22f..7291a741fc 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -1,6 +1,6 @@ use crate::{ lex, lite_parse, - parse_keywords::parse_source, + parse_keywords::{parse_for, parse_source}, type_check::{math_result_type, type_compatible}, LiteBlock, ParseError, Token, TokenContents, }; @@ -13,7 +13,7 @@ use nu_protocol::{ }, engine::StateWorkingSet, span, Flag, PositionalArg, Signature, Span, Spanned, SyntaxShape, Type, Unit, VarId, - CONFIG_VARIABLE_ID, + CONFIG_VARIABLE_ID, ENV_VARIABLE_ID, IN_VARIABLE_ID, }; use crate::parse_keywords::{ @@ -3402,6 +3402,7 @@ pub fn parse_statement( match name { b"def" => parse_def(working_set, spans), b"let" => parse_let(working_set, spans), + b"for" => parse_for(working_set, spans), b"alias" => parse_alias(working_set, spans), b"module" => parse_module(working_set, spans), b"use" => parse_use(working_set, spans), @@ -3577,13 +3578,15 @@ pub fn parse_block( (block, error) } -fn find_captures_in_block( +pub fn find_captures_in_block( working_set: &StateWorkingSet, block: &Block, seen: &mut Vec, ) -> Vec { let mut output = vec![]; + // println!("sig: {:#?}", block.signature); + for flag in &block.signature.named { if let Some(var_id) = flag.var_id { seen.push(var_id); @@ -3654,6 +3657,13 @@ pub fn find_captures_in_expr( } Expr::Bool(_) => {} Expr::Call(call) => { + let decl = working_set.get_decl(call.decl_id); + if let Some(block_id) = decl.get_block_id() { + let block = working_set.get_block(block_id); + let result = find_captures_in_block(working_set, block, seen); + output.extend(&result); + } + for named in &call.named { if let Some(arg) = &named.1 { let result = find_captures_in_expr(working_set, arg, seen); @@ -3715,7 +3725,30 @@ pub fn find_captures_in_expr( output.extend(&find_captures_in_expr(working_set, field_value, seen)); } } - Expr::Signature(_) => {} + Expr::Signature(sig) => { + // println!("Signature found! Adding var_ids"); + // Something with a declaration, similar to a var decl, will introduce more VarIds into the stack at eval + for pos in &sig.required_positional { + if let Some(var_id) = pos.var_id { + seen.push(var_id); + } + } + for pos in &sig.optional_positional { + if let Some(var_id) = pos.var_id { + seen.push(var_id); + } + } + if let Some(rest) = &sig.rest_positional { + if let Some(var_id) = rest.var_id { + seen.push(var_id); + } + } + for named in &sig.named { + if let Some(var_id) = named.var_id { + seen.push(var_id); + } + } + } Expr::String(_) => {} Expr::StringInterpolation(exprs) => { for expr in exprs { @@ -3745,7 +3778,7 @@ pub fn find_captures_in_expr( output.extend(&result); } Expr::Var(var_id) => { - if !seen.contains(var_id) { + if (*var_id > ENV_VARIABLE_ID || *var_id == IN_VARIABLE_ID) && !seen.contains(var_id) { output.push(*var_id); } } diff --git a/crates/nu-protocol/src/ast/expression.rs b/crates/nu-protocol/src/ast/expression.rs index 6a424e4612..c4a0e520a0 100644 --- a/crates/nu-protocol/src/ast/expression.rs +++ b/crates/nu-protocol/src/ast/expression.rs @@ -112,6 +112,10 @@ impl Expression { Expr::Block(block_id) => { let block = working_set.get_block(*block_id); + if block.captures.contains(&IN_VARIABLE_ID) { + return true; + } + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get(0) { match pipeline.expressions.get(0) { Some(expr) => expr.has_in_variable(working_set), diff --git a/crates/nu-protocol/src/engine/capture_block.rs b/crates/nu-protocol/src/engine/capture_block.rs new file mode 100644 index 0000000000..447c33e5a3 --- /dev/null +++ b/crates/nu-protocol/src/engine/capture_block.rs @@ -0,0 +1,9 @@ +use std::collections::HashMap; + +use crate::{BlockId, Value, VarId}; + +#[derive(Clone, Debug)] +pub struct CaptureBlock { + pub block_id: BlockId, + pub captures: HashMap, +} diff --git a/crates/nu-protocol/src/engine/engine_state.rs b/crates/nu-protocol/src/engine/engine_state.rs index 9c4ac5ba24..9c0fdf8ded 100644 --- a/crates/nu-protocol/src/engine/engine_state.rs +++ b/crates/nu-protocol/src/engine/engine_state.rs @@ -154,6 +154,7 @@ pub const SCOPE_VARIABLE_ID: usize = 1; pub const IN_VARIABLE_ID: usize = 2; pub const CONFIG_VARIABLE_ID: usize = 3; pub const ENV_VARIABLE_ID: usize = 4; +// NOTE: If you add more to this list, make sure to update the > checks based on the last in the list impl EngineState { pub fn new() -> Self { diff --git a/crates/nu-protocol/src/engine/mod.rs b/crates/nu-protocol/src/engine/mod.rs index 81228717b8..296578b417 100644 --- a/crates/nu-protocol/src/engine/mod.rs +++ b/crates/nu-protocol/src/engine/mod.rs @@ -1,9 +1,11 @@ mod call_info; +mod capture_block; mod command; mod engine_state; mod stack; pub use call_info::*; +pub use capture_block::*; pub use command::*; pub use engine_state::*; pub use stack::*; diff --git a/crates/nu-protocol/src/engine/stack.rs b/crates/nu-protocol/src/engine/stack.rs index 35f65d2029..52919a5646 100644 --- a/crates/nu-protocol/src/engine/stack.rs +++ b/crates/nu-protocol/src/engine/stack.rs @@ -62,7 +62,10 @@ impl Stack { return Ok(v.clone()); } - Err(ShellError::NushellFailed("variable not found".into())) + Err(ShellError::NushellFailed(format!( + "variable (var_id: {}) not found", + var_id + ))) } pub fn add_var(&mut self, var_id: VarId, value: Value) { @@ -80,7 +83,24 @@ impl Stack { } } - pub fn collect_captures(&self, captures: &[VarId]) -> Stack { + pub fn captures_to_stack(&self, captures: &HashMap) -> Stack { + let mut output = Stack::new(); + + output.vars = captures.clone(); + + // FIXME: this is probably slow + output.env_vars = self.env_vars.clone(); + output.env_vars.push(HashMap::new()); + + let config = self + .get_var(CONFIG_VARIABLE_ID) + .expect("internal error: config is missing"); + output.vars.insert(CONFIG_VARIABLE_ID, config); + + output + } + + pub fn gather_captures(&self, captures: &[VarId]) -> Stack { let mut output = Stack::new(); for capture in captures { diff --git a/crates/nu-protocol/src/value/from_value.rs b/crates/nu-protocol/src/value/from_value.rs index 2864054f4b..ea3fd7b0f4 100644 --- a/crates/nu-protocol/src/value/from_value.rs +++ b/crates/nu-protocol/src/value/from_value.rs @@ -6,6 +6,7 @@ use std::str::FromStr; use chrono::{DateTime, FixedOffset}; // use nu_path::expand_path; use crate::ast::{CellPath, PathMember}; +use crate::engine::CaptureBlock; use crate::ShellError; use crate::{Range, Spanned, Value}; @@ -351,3 +352,42 @@ impl FromValue for Vec { } } } + +impl FromValue for CaptureBlock { + fn from_value(v: &Value) -> Result { + match v { + Value::Block { val, captures, .. } => Ok(CaptureBlock { + block_id: *val, + captures: captures.clone(), + }), + v => Err(ShellError::CantConvert( + "Block".into(), + v.get_type().to_string(), + v.span()?, + )), + } + } +} + +impl FromValue for Spanned { + fn from_value(v: &Value) -> Result { + match v { + Value::Block { + val, + captures, + span, + } => Ok(Spanned { + item: CaptureBlock { + block_id: *val, + captures: captures.clone(), + }, + span: *span, + }), + v => Err(ShellError::CantConvert( + "Block".into(), + v.get_type().to_string(), + v.span()?, + )), + } + } +} diff --git a/crates/nu-protocol/src/value/mod.rs b/crates/nu-protocol/src/value/mod.rs index 0b17f766b7..d5934fc485 100644 --- a/crates/nu-protocol/src/value/mod.rs +++ b/crates/nu-protocol/src/value/mod.rs @@ -22,7 +22,7 @@ use std::path::PathBuf; use std::{cmp::Ordering, fmt::Debug}; use crate::ast::{CellPath, PathMember}; -use crate::{did_you_mean, span, BlockId, Config, Span, Spanned, Type}; +use crate::{did_you_mean, span, BlockId, Config, Span, Spanned, Type, VarId}; use crate::ast::Operator; pub use custom_value::CustomValue; @@ -75,6 +75,7 @@ pub enum Value { }, Block { val: BlockId, + captures: HashMap, span: Span, }, Nothing { @@ -141,8 +142,13 @@ impl Clone for Value { vals: vals.clone(), span: *span, }, - Value::Block { val, span } => Value::Block { + Value::Block { + val, + captures, + span, + } => Value::Block { val: *val, + captures: captures.clone(), span: *span, }, Value::Nothing { span } => Value::Nothing { span: *span }, diff --git a/src/tests/test_engine.rs b/src/tests/test_engine.rs index 240f2d1cb0..3325d2f8b8 100644 --- a/src/tests/test_engine.rs +++ b/src/tests/test_engine.rs @@ -119,3 +119,27 @@ fn missing_flags_are_nothing4() -> TestResult { "10003", ) } + +#[test] +fn proper_variable_captures() -> TestResult { + run_test( + r#"def foo [x] { let y = 100; { $y + $x } }; do (foo 23)"#, + "123", + ) +} + +#[test] +fn proper_variable_captures_with_calls() -> TestResult { + run_test( + r#"def foo [] { let y = 60; def bar [] { $y }; { bar } }; do (foo)"#, + "60", + ) +} + +#[test] +fn proper_variable_captures_with_nesting() -> TestResult { + run_test( + r#"def foo [x] { let z = 100; def bar [y] { $y - $x + $z } ; { |z| bar $z } }; do (foo 11) 13"#, + "102", + ) +}