diff --git a/Cargo.lock b/Cargo.lock index 76e1cbe..57a744e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,6 +140,7 @@ dependencies = [ name = "kalk" version = "0.1.10" dependencies = [ + "lazy_static", "phf", "regex", "rug", diff --git a/kalk/Cargo.toml b/kalk/Cargo.toml index 9fcd679..f89239b 100644 --- a/kalk/Cargo.toml +++ b/kalk/Cargo.toml @@ -15,3 +15,4 @@ phf = { version = "0.8", features = ["macros"] } rug = "1.9.0" test-case = "1.0.0" regex = "1" +lazy_static = "1.4.0" diff --git a/kalk/src/interpreter.rs b/kalk/src/interpreter.rs index 6959929..61fe871 100644 --- a/kalk/src/interpreter.rs +++ b/kalk/src/interpreter.rs @@ -265,10 +265,38 @@ mod tests { const PRECISION: u32 = 53; + lazy_static::lazy_static! { + static ref DEG_RAD_UNIT: Stmt = unit_decl( + "deg", + "rad", + binary( + binary( + var(crate::parser::DECL_UNIT), + TokenKind::Star, + literal("180"), + ), + TokenKind::Slash, + var("pi"), + ), + ); + static ref RAD_DEG_UNIT: Stmt = unit_decl( + "rad", + "deg", + binary( + binary(var(crate::parser::DECL_UNIT), TokenKind::Star, var("pi")), + TokenKind::Slash, + literal("180"), + ), + ); + } + fn interpret(stmt: Stmt) -> Result, CalcError> { let mut symbol_table = SymbolTable::new(); - let mut context = Context::new(&mut symbol_table, "rad", PRECISION); + symbol_table + .insert(DEG_RAD_UNIT.clone()) + .insert(RAD_DEG_UNIT.clone()); + let mut context = Context::new(&mut symbol_table, "rad", PRECISION); context.interpret(vec![stmt]) } @@ -315,14 +343,19 @@ mod tests { fn test_angle_units() { let rad_explicit = Stmt::Expr(fn_call("sin", vec![*unit("rad", literal("1"))])); let deg_explicit = Stmt::Expr(fn_call("sin", vec![*unit("deg", literal("1"))])); - //let implicit = Stmt::Expr(fn_call("sin", vec![*literal("1")])); + let implicit = Stmt::Expr(fn_call("sin", vec![*literal("1")])); assert!(cmp(interpret(rad_explicit).unwrap().unwrap(), 0.84147098)); assert!(cmp(interpret(deg_explicit).unwrap().unwrap(), 0.01745240)); - // TODO: Get this to work. - /*let mut rad_symbol_table = SymbolTable::new(); + let mut rad_symbol_table = SymbolTable::new(); + rad_symbol_table + .insert(DEG_RAD_UNIT.clone()) + .insert(RAD_DEG_UNIT.clone()); let mut deg_symbol_table = SymbolTable::new(); + deg_symbol_table + .insert(DEG_RAD_UNIT.clone()) + .insert(RAD_DEG_UNIT.clone()); let mut rad_context = Context::new(&mut rad_symbol_table, "rad", PRECISION); let mut deg_context = Context::new(&mut deg_symbol_table, "deg", PRECISION); @@ -336,7 +369,7 @@ mod tests { assert!(cmp( deg_context.interpret(vec![implicit]).unwrap().unwrap(), 0.01745240 - ));*/ + )); } #[test] diff --git a/kalk/src/inverter.rs b/kalk/src/inverter.rs index 7dc3966..353deed 100644 --- a/kalk/src/inverter.rs +++ b/kalk/src/inverter.rs @@ -51,7 +51,7 @@ fn invert_binary( symbol_table, left, op, - &multiply_in(&Expr::Literal(String::from("-1")), inside_group)?, + &multiply_into(&Expr::Literal(String::from("-1")), inside_group)?, ); } @@ -64,13 +64,17 @@ fn invert_binary( return invert( target_expr, symbol_table, - &multiply_in(right, inside_group)?, + &multiply_into(right, inside_group)?, ); } // Same as above but left/right switched. if let Expr::Group(inside_group) = right { - return invert(target_expr, symbol_table, &multiply_in(left, inside_group)?); + return invert( + target_expr, + symbol_table, + &multiply_into(left, inside_group)?, + ); } TokenKind::Slash @@ -130,14 +134,15 @@ fn invert_binary( fn invert_unary(target_expr: Expr, op: &TokenKind, expr: &Expr) -> Result<(Expr, Expr), CalcError> { match op { TokenKind::Minus => Ok(( + // Make the target expression negative Expr::Unary(TokenKind::Minus, Box::new(target_expr)), - expr.clone(), + expr.clone(), // And then continue inverting the inner-expression. )), _ => unimplemented!(), } } -// Not necessary yet +// TODO: Implement fn invert_unit( _target_expr: Expr, _identifier: &str, @@ -152,6 +157,7 @@ fn invert_fn_call( identifier: &str, arguments: &Vec, ) -> Result<(Expr, Expr), CalcError> { + // Get the function definition from the symbol table. let (parameters, body) = if let Some(Stmt::FnDecl(_, parameters, body)) = symbol_table.get_fn(identifier).cloned() { (parameters, body) @@ -159,6 +165,7 @@ fn invert_fn_call( return Err(CalcError::UndefinedFn(identifier.into())); }; + // Make sure the input-expression is valid. if parameters.len() != arguments.len() { return Err(CalcError::IncorrectAmountOfArguments( parameters.len(), @@ -167,6 +174,7 @@ fn invert_fn_call( )); } + // Make the parameters usable as variables inside the function. let mut parameters_iter = parameters.iter(); for argument in arguments { symbol_table.insert(Stmt::VarDecl( @@ -175,10 +183,12 @@ fn invert_fn_call( )); } + // Invert everything in the function body. invert(target_expr, symbol_table, &body) } fn contains_the_unit(expr: &Expr) -> bool { + // Recursively scan the expression for the unit. match expr { Expr::Binary(left, _, right) => contains_the_unit(left) || contains_the_unit(right), Expr::Unary(_, expr) => contains_the_unit(expr), @@ -198,21 +208,25 @@ fn contains_the_unit(expr: &Expr) -> bool { } } -fn multiply_in(expr: &Expr, base_expr: &Expr) -> Result { +/// Multiply an expression into a group. +fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result { match base_expr { Expr::Binary(left, op, right) => match op { + // If + or -, multiply the expression with each term. TokenKind::Plus | TokenKind::Minus => Ok(Expr::Binary( - Box::new(multiply_in(expr, &left)?), + Box::new(multiply_into(expr, &left)?), op.clone(), - Box::new(multiply_in(expr, &right)?), + Box::new(multiply_into(expr, &right)?), )), + // If * or /, only multiply with the first factor. TokenKind::Star | TokenKind::Slash => Ok(Expr::Binary( - Box::new(multiply_in(expr, &left)?), + Box::new(multiply_into(expr, &left)?), op.clone(), right.clone(), )), _ => unimplemented!(), }, + // If it's a literal, just multiply them together. Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary( Box::new(expr.clone()), TokenKind::Star, diff --git a/kalk/src/parser.rs b/kalk/src/parser.rs index 5858316..e78aa28 100644 --- a/kalk/src/parser.rs +++ b/kalk/src/parser.rs @@ -414,6 +414,7 @@ mod tests { fn parse_with_context(context: &mut Context, tokens: Vec) -> Result { context.tokens = tokens; + context.pos = 0; parse_stmt(context) } @@ -421,6 +422,7 @@ mod tests { fn parse(tokens: Vec) -> Result { let mut context = Context::new(); context.tokens = tokens; + context.pos = 0; parse_stmt(&mut context) } @@ -448,6 +450,7 @@ mod tests { token(Slash, ""), token(Literal, "5"), token(ClosedParenthesis, ""), + token(EOF, ""), ]; assert_eq!( @@ -480,6 +483,7 @@ mod tests { token(Literal, "4"), token(Plus, ""), token(Literal, "5"), + token(EOF, ""), ]; assert_eq!( @@ -500,20 +504,20 @@ mod tests { ); } - /*#[test_case(Deg)] - #[test_case(Rad)] - fn test_unary(angle_unit: TokenKind) { - let tokens = vec![ - token(Minus, ""), - token(Literal, "1"), - token(angle_unit.clone(), ""), - ]; + #[test] + fn test_unit() { + let tokens = vec![token(Literal, "1"), token(Identifier, "a")]; + + let mut context = Context::new(); + context + .symbol_table + .insert(unit_decl("a", "b", var(super::DECL_UNIT))); assert_eq!( - parse(tokens).unwrap(), - Stmt::Expr(unary(Minus, Box::new(Expr::Unit(literal("1"), angle_unit)))) + parse_with_context(&mut context, tokens).unwrap(), + Stmt::Expr(unit("a", literal("1"))) ); - }*/ + } #[test] fn test_var_decl() { @@ -523,6 +527,7 @@ mod tests { token(Literal, "1"), token(Plus, ""), token(Literal, "2"), + token(EOF, ""), ]; assert_eq!( @@ -542,6 +547,7 @@ mod tests { token(Literal, "1"), token(Plus, ""), token(Literal, "2"), + token(EOF, ""), ]; assert_eq!( @@ -565,6 +571,7 @@ mod tests { token(ClosedParenthesis, ""), token(Plus, ""), token(Literal, "3"), + token(EOF, ""), ]; let mut context = Context::new(); diff --git a/kalk/src/symbol_table.rs b/kalk/src/symbol_table.rs index 4ab0e9d..aead965 100644 --- a/kalk/src/symbol_table.rs +++ b/kalk/src/symbol_table.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; #[derive(Debug)] pub struct SymbolTable { - hashmap: HashMap, - unit_types: HashMap, + pub(crate) hashmap: HashMap, + pub(crate) unit_types: HashMap, } impl SymbolTable { @@ -15,22 +15,24 @@ impl SymbolTable { } } - pub fn insert(&mut self, value: Stmt) -> Option { + pub fn insert(&mut self, value: Stmt) -> &mut Self { match &value { Stmt::VarDecl(identifier, _) => { - self.hashmap.insert(format!("var.{}", identifier), value) + self.hashmap.insert(format!("var.{}", identifier), value); } Stmt::UnitDecl(identifier, to_unit, _) => { self.unit_types.insert(identifier.to_string(), ()); self.unit_types.insert(to_unit.to_string(), ()); self.hashmap - .insert(format!("unit.{}.{}", identifier, to_unit), value) + .insert(format!("unit.{}.{}", identifier, to_unit), value); } Stmt::FnDecl(identifier, _, _) => { - self.hashmap.insert(format!("fn.{}", identifier), value) + self.hashmap.insert(format!("fn.{}", identifier), value); } _ => panic!("Can only insert VarDecl, UnitDecl and FnDecl into symbol table."), } + + self } pub fn get_var(&self, key: &str) -> Option<&Stmt> {