diff --git a/src/eval.rs b/src/eval.rs index 6531c3b75c..b6bcc5e00b 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; -use crate::{parser::Operator, Block, Call, Expr, Expression, ParserState, Span, Statement, VarId}; +use crate::{ + parser::Operator, Block, BlockId, Call, Expr, Expression, ParserState, Span, Statement, VarId, +}; #[derive(Debug)] pub enum ShellError { @@ -12,6 +14,9 @@ pub enum ShellError { #[derive(Debug, Clone)] pub enum Value { Int { val: i64, span: Span }, + String { val: String, span: Span }, + List(Vec), + Block(BlockId), Unknown, } impl Value { @@ -38,14 +43,21 @@ impl Stack { pub fn get_var(&self, var_id: VarId) -> Result { match self.vars.get(&var_id) { Some(v) => Ok(v.clone()), - _ => Err(ShellError::InternalError("variable not found".into())), + _ => { + println!("var_id: {}", var_id); + Err(ShellError::InternalError("variable not found".into())) + } } } + + pub fn add_var(&mut self, var_id: VarId, value: Value) { + self.vars.insert(var_id, value); + } } pub fn eval_operator( - state: &State, - stack: &mut Stack, + _state: &State, + _stack: &mut Stack, op: &Expression, ) -> Result { match op { @@ -59,8 +71,19 @@ pub fn eval_operator( fn eval_call(state: &State, stack: &mut Stack, call: &Call) -> Result { let decl = state.parser_state.get_decl(call.decl_id); - if let Some(block_id) = decl.body { + for (arg, param) in call + .positional + .iter() + .zip(decl.signature.required_positional.iter()) + { + let result = eval_expression(state, stack, arg)?; + let var_id = param + .var_id + .expect("internal error: all custom parameters must have var_ids"); + + stack.add_var(var_id, result); + } let block = state.parser_state.get_block(block_id); eval_block(state, stack, block) } else { @@ -98,13 +121,22 @@ pub fn eval_expression( eval_block(state, stack, block) } - Expr::Block(_) => Err(ShellError::Unsupported(expr.span)), - Expr::List(_) => Err(ShellError::Unsupported(expr.span)), + Expr::Block(block_id) => Ok(Value::Block(*block_id)), + Expr::List(x) => { + let mut output = vec![]; + for expr in x { + output.push(eval_expression(state, stack, expr)?); + } + Ok(Value::List(output)) + } Expr::Table(_, _) => Err(ShellError::Unsupported(expr.span)), - Expr::Literal(_) => Err(ShellError::Unsupported(expr.span)), - Expr::String(_) => Err(ShellError::Unsupported(expr.span)), - Expr::Signature(_) => Err(ShellError::Unsupported(expr.span)), - Expr::Garbage => Err(ShellError::Unsupported(expr.span)), + Expr::Literal(_) => Ok(Value::Unknown), + Expr::String(s) => Ok(Value::String { + val: s.clone(), + span: expr.span, + }), + Expr::Signature(_) => Ok(Value::Unknown), + Expr::Garbage => Ok(Value::Unknown), } } diff --git a/src/flatten.rs b/src/flatten.rs index 5f228132bb..f994f6778c 100644 --- a/src/flatten.rs +++ b/src/flatten.rs @@ -26,6 +26,7 @@ impl<'a> ParserWorkingSet<'a> { match stmt { Statement::Expression(expr) => self.flatten_expression(expr), Statement::Pipeline(pipeline) => self.flatten_pipeline(pipeline), + _ => vec![], } } diff --git a/src/main.rs b/src/main.rs index 5f2b226618..abcc8d2767 100644 --- a/src/main.rs +++ b/src/main.rs @@ -65,6 +65,14 @@ fn main() -> std::io::Result<()> { // .named("--jazz", SyntaxShape::Int, "jazz!!", Some('j')) // .switch("--rock", "rock!!", Some('r')); // working_set.add_decl(sig.into()); + let sig = Signature::build("exit"); + working_set.add_decl(sig.into()); + let sig = Signature::build("vars"); + working_set.add_decl(sig.into()); + let sig = Signature::build("decls"); + working_set.add_decl(sig.into()); + let sig = Signature::build("blocks"); + working_set.add_decl(sig.into()); let sig = Signature::build("add"); working_set.add_decl(sig.into()); @@ -89,7 +97,7 @@ fn main() -> std::io::Result<()> { let file = std::fs::read(&path)?; - let (block, err) = working_set.parse_file(&path, &file, false); + let (block, _err) = working_set.parse_file(&path, &file, false); println!("{}", block.len()); // println!("{:#?}", output); // println!("error: {:?}", err); @@ -130,6 +138,15 @@ fn main() -> std::io::Result<()> { Signal::Success(s) => { if s.trim() == "exit" { break; + } else if s.trim() == "vars" { + parser_state.borrow().print_vars(); + continue; + } else if s.trim() == "decls" { + parser_state.borrow().print_decls(); + continue; + } else if s.trim() == "blocks" { + parser_state.borrow().print_blocks(); + continue; } // println!("input: '{}'", s); diff --git a/src/parse_error.rs b/src/parse_error.rs index d126eca336..5ff90ff60b 100644 --- a/src/parse_error.rs +++ b/src/parse_error.rs @@ -13,6 +13,7 @@ pub enum ParseError { UnknownCommand(Span), NonUtf8(Span), UnknownFlag(Span), + UnknownType(Span), MissingFlagParam(Span), ShortFlagBatchCantTakeArg(Span), MissingPositional(String, Span), diff --git a/src/parser.rs b/src/parser.rs index fda1bcce3c..2346c06efc 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -78,6 +78,38 @@ pub enum SyntaxShape { Expression, } +impl SyntaxShape { + pub fn to_type(&self) -> Type { + match self { + SyntaxShape::Any => Type::Unknown, + SyntaxShape::Block => Type::Block, + SyntaxShape::ColumnPath => Type::Unknown, + SyntaxShape::Duration => Type::Duration, + SyntaxShape::Expression => Type::Unknown, + SyntaxShape::FilePath => Type::FilePath, + SyntaxShape::Filesize => Type::Filesize, + SyntaxShape::FullColumnPath => Type::Unknown, + SyntaxShape::GlobPattern => Type::String, + SyntaxShape::Int => Type::Int, + SyntaxShape::List(x) => { + let contents = x.to_type(); + Type::List(Box::new(contents)) + } + SyntaxShape::Literal(..) => Type::Unknown, + SyntaxShape::MathExpression => Type::Unknown, + SyntaxShape::Number => Type::Number, + SyntaxShape::Operator => Type::Unknown, + SyntaxShape::Range => Type::Unknown, + SyntaxShape::RowCondition => Type::Bool, + SyntaxShape::Signature => Type::Unknown, + SyntaxShape::String => Type::String, + SyntaxShape::Table => Type::Table, + SyntaxShape::VarWithOptType => Type::Unknown, + SyntaxShape::Variable => Type::Unknown, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum Operator { Equal, @@ -148,13 +180,14 @@ pub enum Expr { pub struct Expression { pub expr: Expr, pub span: Span, + pub ty: Type, } impl Expression { pub fn garbage(span: Span) -> Expression { Expression { expr: Expr::Garbage, span, - //ty: Type::Unknown, + ty: Type::Unknown, } } pub fn precedence(&self) -> usize { @@ -272,6 +305,7 @@ pub struct VarDecl { #[derive(Debug, Clone)] pub enum Statement { + Declaration(DeclId), Pipeline(Pipeline), Expression(Expression), } @@ -359,6 +393,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::ExternalCall(name, args), span: span(spans), + ty: Type::Unknown, }, None, ) @@ -533,6 +568,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Literal(literal), span: arg_span, + ty: Type::Unknown, }, error, ) @@ -678,6 +714,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Call(call), span: span(spans), + ty: Type::Unknown, // FIXME }, err, ) @@ -693,6 +730,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Int(v), span, + ty: Type::Int, }, None, ) @@ -708,6 +746,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Int(v), span, + ty: Type::Int, }, None, ) @@ -723,6 +762,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Int(v), span, + ty: Type::Int, }, None, ) @@ -737,6 +777,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Int(x), span, + ty: Type::Int, }, None, ) @@ -767,6 +808,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Var(var_id), span, + ty: self.get_variable(var_id).clone(), }, None, ) @@ -784,6 +826,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Var(id), span, + ty: self.get_variable(id).clone(), }, None, ) @@ -796,6 +839,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Var(id), span, + ty: Type::Unknown, }, None, ) @@ -849,6 +893,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Subexpression(block_id), span, + ty: Type::Unknown, // FIXME }, error, ) @@ -862,6 +907,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::String(token), span, + ty: Type::String, }, None, ) @@ -874,8 +920,8 @@ impl<'a> ParserWorkingSet<'a> { } //TODO: Handle error case - pub fn parse_shape_name(&self, bytes: &[u8]) -> SyntaxShape { - match bytes { + pub fn parse_shape_name(&self, bytes: &[u8], span: Span) -> (SyntaxShape, Option) { + let result = match bytes { b"any" => SyntaxShape::Any, b"string" => SyntaxShape::String, b"column-path" => SyntaxShape::ColumnPath, @@ -891,8 +937,10 @@ impl<'a> ParserWorkingSet<'a> { b"variable" => SyntaxShape::Variable, b"signature" => SyntaxShape::Signature, b"expr" => SyntaxShape::Expression, - _ => SyntaxShape::Any, - } + _ => return (SyntaxShape::Any, Some(ParseError::UnknownType(span))), + }; + + (result, None) } pub fn parse_type(&self, bytes: &[u8]) -> Type { @@ -919,12 +967,13 @@ impl<'a> ParserWorkingSet<'a> { let ty = self.parse_type(type_bytes); *spans_idx += 1; - let id = self.add_variable(bytes[0..(bytes.len() - 1)].to_vec(), ty); + let id = self.add_variable(bytes[0..(bytes.len() - 1)].to_vec(), ty.clone()); ( Expression { expr: Expr::Var(id), span: span(&spans[*spans_idx - 2..*spans_idx]), + ty, }, None, ) @@ -935,6 +984,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Var(id), span: spans[*spans_idx], + ty: Type::Unknown, }, Some(ParseError::MissingType(spans[*spans_idx])), ) @@ -947,6 +997,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Var(id), span: span(&spans[*spans_idx - 1..*spans_idx]), + ty: Type::Unknown, }, None, ) @@ -1023,18 +1074,26 @@ impl<'a> ParserWorkingSet<'a> { ParseMode::ArgMode => { if contents.starts_with(b"--") && contents.len() > 2 { // Long flag - let flags: Vec<_> = contents.split(|x| x == &b'(').collect(); + let flags: Vec<_> = contents + .split(|x| x == &b'(') + .map(|x| x.to_vec()) + .collect(); + + let long = String::from_utf8_lossy(&flags[0]).to_string(); + let variable_name = flags[0][2..].to_vec(); + let var_id = self.add_variable(variable_name, Type::Unknown); if flags.len() == 1 { args.push(Arg::Flag(Flag { arg: None, desc: String::new(), - long: String::from_utf8_lossy(flags[0]).to_string(), + long, short: None, required: false, + var_id: Some(var_id), })); } else { - let short_flag = flags[1]; + let short_flag = &flags[1]; let short_flag = if !short_flag.starts_with(b"-") || !short_flag.ends_with(b")") { @@ -1048,16 +1107,21 @@ impl<'a> ParserWorkingSet<'a> { }; let short_flag = - String::from_utf8_lossy(short_flag).to_string(); + String::from_utf8_lossy(&short_flag).to_string(); let chars: Vec = short_flag.chars().collect(); + let long = String::from_utf8_lossy(&flags[0]).to_string(); + let variable_name = flags[0][2..].to_vec(); + let var_id = + self.add_variable(variable_name, Type::Unknown); if chars.len() == 1 { args.push(Arg::Flag(Flag { arg: None, desc: String::new(), - long: String::from_utf8_lossy(flags[0]).to_string(), + long, short: Some(chars[0]), required: false, + var_id: Some(var_id), })); } else { error = error.or(Some(ParseError::Mismatch( @@ -1086,14 +1150,22 @@ impl<'a> ParserWorkingSet<'a> { long: String::new(), short: None, required: false, + var_id: None, })); } else { + let mut encoded_var_name = vec![0u8; 4]; + let len = chars[0].encode_utf8(&mut encoded_var_name).len(); + let variable_name = encoded_var_name[0..len].to_vec(); + let var_id = + self.add_variable(variable_name, Type::Unknown); + args.push(Arg::Flag(Flag { arg: None, desc: String::new(), long: String::new(), short: Some(chars[0]), required: false, + var_id: Some(var_id), })); } } else if contents.starts_with(b"(-") { @@ -1140,24 +1212,35 @@ impl<'a> ParserWorkingSet<'a> { } } else { if contents.ends_with(b"?") { - let contents = &contents[..(contents.len() - 1)]; + let contents: Vec<_> = + contents[..(contents.len() - 1)].into(); + let name = String::from_utf8_lossy(&contents).to_string(); + + let var_id = + self.add_variable(contents.into(), Type::Unknown); // Positional arg, optional args.push(Arg::Positional( PositionalArg { desc: String::new(), - name: String::from_utf8_lossy(contents).to_string(), + name, shape: SyntaxShape::Any, + var_id: Some(var_id), }, false, )) } else { + let name = String::from_utf8_lossy(contents).to_string(); + let contents_vec = contents.to_vec(); + let var_id = self.add_variable(contents_vec, Type::Unknown); + // Positional arg, required args.push(Arg::Positional( PositionalArg { desc: String::new(), - name: String::from_utf8_lossy(contents).to_string(), + name, shape: SyntaxShape::Any, + var_id: Some(var_id), }, true, )) @@ -1166,13 +1249,21 @@ impl<'a> ParserWorkingSet<'a> { } ParseMode::TypeMode => { if let Some(last) = args.last_mut() { - let syntax_shape = self.parse_shape_name(contents); + let (syntax_shape, err) = self.parse_shape_name(contents, span); + error = error.or(err); //TODO check if we're replacing one already match last { - Arg::Positional(PositionalArg { shape, .. }, ..) => { + Arg::Positional( + PositionalArg { shape, var_id, .. }, + .., + ) => { + self.set_variable_type(var_id.expect("internal error: all custom parameters must have var_ids"), syntax_shape.to_type()); *shape = syntax_shape; } - Arg::Flag(Flag { arg, .. }) => *arg = Some(syntax_shape), + Arg::Flag(Flag { arg, var_id, .. }) => { + self.set_variable_type(var_id.expect("internal error: all custom parameters must have var_ids"), syntax_shape.to_type()); + *arg = Some(syntax_shape) + } } } parse_mode = ParseMode::ArgMode; @@ -1242,6 +1333,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Signature(sig), span, + ty: Type::Unknown, }, error, ) @@ -1310,6 +1402,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::List(args), span, + ty: Type::List(Box::new(Type::Unknown)), // FIXME }, error, ) @@ -1354,6 +1447,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::List(vec![]), span, + ty: Type::Table, }, None, ), @@ -1393,6 +1487,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Table(table_headers, rows), span, + ty: Type::Table, }, error, ) @@ -1448,6 +1543,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Block(block_id), span, + ty: Type::Block, }, error, ) @@ -1514,6 +1610,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Literal(literal), span, + ty: Type::Unknown, }, None, ) @@ -1633,6 +1730,7 @@ impl<'a> ParserWorkingSet<'a> { Expression { expr: Expr::Operator(operator), span, + ty: Type::Unknown, }, None, ) @@ -1686,20 +1784,24 @@ impl<'a> ParserWorkingSet<'a> { while expr_stack.len() > 1 { // Collapse the right associated operations first // so that we can get back to a stack with a lower precedence - let rhs = expr_stack + let mut rhs = expr_stack .pop() .expect("internal error: expression stack empty"); - let op = expr_stack + let mut op = expr_stack .pop() .expect("internal error: expression stack empty"); - let lhs = expr_stack + let mut lhs = expr_stack .pop() .expect("internal error: expression stack empty"); + let (result_ty, err) = self.math_result_type(&mut lhs, &mut op, &mut rhs); + error = error.or(err); + let op_span = span(&[lhs.span, rhs.span]); expr_stack.push(Expression { expr: Expr::BinaryOp(Box::new(lhs), Box::new(op), Box::new(rhs)), span: op_span, + ty: result_ty, }); } } @@ -1712,20 +1814,24 @@ impl<'a> ParserWorkingSet<'a> { } while expr_stack.len() != 1 { - let rhs = expr_stack + let mut rhs = expr_stack .pop() .expect("internal error: expression stack empty"); - let op = expr_stack + let mut op = expr_stack .pop() .expect("internal error: expression stack empty"); - let lhs = expr_stack + let mut lhs = expr_stack .pop() .expect("internal error: expression stack empty"); + let (result_ty, err) = self.math_result_type(&mut lhs, &mut op, &mut rhs); + error = error.or(err); + let binary_op_span = span(&[lhs.span, rhs.span]); expr_stack.push(Expression { expr: Expr::BinaryOp(Box::new(lhs), Box::new(op), Box::new(rhs)), span: binary_op_span, + ty: result_ty, }); } @@ -1736,6 +1842,58 @@ impl<'a> ParserWorkingSet<'a> { (output, error) } + pub fn math_result_type( + &self, + lhs: &mut Expression, + op: &mut Expression, + rhs: &mut Expression, + ) -> (Type, Option) { + match &op.expr { + Expr::Operator(operator) => match operator { + Operator::Plus => match (&lhs.ty, &rhs.ty) { + (Type::Int, Type::Int) => (Type::Int, None), + (Type::Unknown, _) => (Type::Unknown, None), + (_, Type::Unknown) => (Type::Unknown, None), + (Type::Int, _) => { + *rhs = Expression::garbage(rhs.span); + ( + Type::Unknown, + Some(ParseError::Mismatch("int".into(), rhs.span)), + ) + } + (_, Type::Int) => { + *lhs = Expression::garbage(lhs.span); + ( + Type::Unknown, + Some(ParseError::Mismatch("int".into(), lhs.span)), + ) + } + _ => { + *op = Expression::garbage(op.span); + ( + Type::Unknown, + Some(ParseError::Mismatch("math".into(), op.span)), + ) + } + }, + _ => { + *op = Expression::garbage(op.span); + ( + Type::Unknown, + Some(ParseError::Mismatch("math".into(), op.span)), + ) + } + }, + _ => { + *op = Expression::garbage(op.span); + ( + Type::Unknown, + Some(ParseError::Mismatch("operator".into(), op.span)), + ) + } + } + } + pub fn parse_expression(&mut self, spans: &[Span]) -> (Expression, Option) { let bytes = self.get_span_contents(spans[0]); @@ -1772,60 +1930,69 @@ impl<'a> ParserWorkingSet<'a> { } pub fn parse_def(&mut self, spans: &[Span]) -> (Statement, Option) { + let mut error = None; let name = self.get_span_contents(spans[0]); - if name == b"def" { - if let Some(decl_id) = self.find_decl(b"def") { - let (call, call_span, err) = - self.parse_internal_call(spans[0], &spans[1..], decl_id); + if name == b"def" && spans.len() >= 4 { + //FIXME: don't use expect here + let (name_expr, err) = self.parse_string(spans[1]); + let name = name_expr + .as_string() + .expect("internal error: expected def name"); + error = error.or(err); - if err.is_some() { - return ( - Statement::Expression(Expression { - expr: Expr::Call(call), - span: call_span, - }), - err, - ); - } else { - let name = call.positional[0] - .as_string() - .expect("internal error: expected def name"); - let mut signature = call.positional[1] - .as_signature() - .expect("internal error: expected param list"); - let block_id = call.positional[2] - .as_block() - .expect("internal error: expected block"); + self.enter_scope(); + let (sig, err) = self.parse_signature(spans[2]); + let mut signature = sig + .as_signature() + .expect("internal error: expected param list"); + error = error.or(err); - signature.name = name; - let decl = Declaration { - signature, - body: Some(block_id), - }; + let (block, err) = self.parse_block_expression(spans[3]); + self.exit_scope(); - self.add_decl(decl); + let block_id = block.as_block().expect("internal error: expected block"); + error = error.or(err); - return ( - Statement::Expression(Expression { - expr: Expr::Call(call), - span: call_span, - }), - None, - ); - } - } + signature.name = name; + let decl = Declaration { + signature, + body: Some(block_id), + }; + + self.add_decl(decl); + let def_decl_id = self + .find_decl(b"def") + .expect("internal error: missing def command"); + + let call = Box::new(Call { + head: spans[0], + decl_id: def_decl_id, + positional: vec![name_expr, sig, block], + named: vec![], + }); + + ( + Statement::Expression(Expression { + expr: Expr::Call(call), + span: span(spans), + ty: Type::Unknown, + }), + error, + ) + } else { + ( + Statement::Expression(Expression { + expr: Expr::Garbage, + span: span(spans), + ty: Type::Unknown, + }), + Some(ParseError::UnknownState( + "internal error: let statement unparseable".into(), + span(spans), + )), + ) } - ( - Statement::Expression(Expression { - expr: Expr::Garbage, - span: span(spans), - }), - Some(ParseError::UnknownState( - "internal error: let statement unparseable".into(), - span(spans), - )), - ) } pub fn parse_let(&mut self, spans: &[Span]) -> (Statement, Option) { @@ -1840,6 +2007,7 @@ impl<'a> ParserWorkingSet<'a> { Statement::Expression(Expression { expr: Expr::Call(call), span: call_span, + ty: Type::Unknown, }), err, ); @@ -1849,6 +2017,7 @@ impl<'a> ParserWorkingSet<'a> { Statement::Expression(Expression { expr: Expr::Garbage, span: span(spans), + ty: Type::Unknown, }), Some(ParseError::UnknownState( "internal error: let statement unparseable".into(), diff --git a/src/parser_state.rs b/src/parser_state.rs index 0a4f08d434..8d73e7528d 100644 --- a/src/parser_state.rs +++ b/src/parser_state.rs @@ -1,4 +1,5 @@ use crate::{parser::Block, Declaration, Span}; +use core::panic; use std::collections::HashMap; #[derive(Debug)] @@ -11,9 +12,19 @@ pub struct ParserState { scope: Vec, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] pub enum Type { Int, + Bool, + String, + Block, + ColumnPath, + Duration, + FilePath, + Filesize, + List(Box), + Number, + Table, Unknown, } @@ -89,6 +100,24 @@ impl ParserState { self.blocks.len() } + pub fn print_vars(&self) { + for var in self.vars.iter().enumerate() { + println!("var{}: {:?}", var.0, var.1); + } + } + + pub fn print_decls(&self) { + for decl in self.decls.iter().enumerate() { + println!("decl{}: {:?}", decl.0, decl.1); + } + } + + pub fn print_blocks(&self) { + for block in self.blocks.iter().enumerate() { + println!("block{}: {:?}", block.0, block.1); + } + } + pub fn get_var(&self, var_id: VarId) -> &Type { self.vars .get(var_id) @@ -319,11 +348,20 @@ impl<'a> ParserWorkingSet<'a> { last.vars.insert(name, next_id); - self.delta.vars.insert(next_id, ty); + self.delta.vars.push(ty); next_id } + pub fn set_variable_type(&mut self, var_id: VarId, ty: Type) { + let num_permanent_vars = self.permanent_state.num_vars(); + if var_id < num_permanent_vars { + panic!("Internal error: attempted to set into permanent state from working set") + } else { + self.delta.vars[var_id - num_permanent_vars] = ty; + } + } + pub fn get_variable(&self, var_id: VarId) -> &Type { let num_permanent_vars = self.permanent_state.num_vars(); if var_id < num_permanent_vars { diff --git a/src/signature.rs b/src/signature.rs index f83df6871e..0633dc3be9 100644 --- a/src/signature.rs +++ b/src/signature.rs @@ -1,4 +1,4 @@ -use crate::{parser::SyntaxShape, Declaration}; +use crate::{parser::SyntaxShape, Declaration, VarId}; #[derive(Debug, Clone)] pub struct Flag { @@ -7,6 +7,8 @@ pub struct Flag { pub arg: Option, pub required: bool, pub desc: String, + // For custom commands + pub var_id: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -14,6 +16,8 @@ pub struct PositionalArg { pub name: String, pub desc: String, pub shape: SyntaxShape, + // For custom commands + pub var_id: Option, } #[derive(Clone, Debug)] @@ -75,6 +79,7 @@ impl Signature { name: name.into(), desc: desc.into(), shape: shape.into(), + var_id: None, }); self @@ -91,6 +96,7 @@ impl Signature { name: name.into(), desc: desc.into(), shape: shape.into(), + var_id: None, }); self @@ -114,6 +120,7 @@ impl Signature { arg: Some(shape.into()), required: false, desc: desc.into(), + var_id: None, }); self @@ -137,6 +144,7 @@ impl Signature { arg: Some(shape.into()), required: true, desc: desc.into(), + var_id: None, }); self @@ -163,6 +171,7 @@ impl Signature { arg: None, required: false, desc: desc.into(), + var_id: None, }); self }