diff --git a/kalk/src/inverter.rs b/kalk/src/inverter.rs index 353deed..ffdd22e 100644 --- a/kalk/src/inverter.rs +++ b/kalk/src/inverter.rs @@ -24,7 +24,7 @@ fn invert( } Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr), Expr::Unit(identifier, expr) => invert_unit(target_expr, &identifier, &expr), - Expr::Var(_) => Ok((target_expr, expr.clone())), + Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier), Expr::Group(expr) => Ok((target_expr, *expr.clone())), Expr::FnCall(identifier, arguments) => { invert_fn_call(target_expr, symbol_table, &identifier, arguments) @@ -50,7 +50,7 @@ fn invert_binary( target_expr, symbol_table, left, - op, + &TokenKind::Plus, &multiply_into(&Expr::Literal(String::from("-1")), inside_group)?, ); } @@ -107,7 +107,7 @@ fn invert_binary( // If the left expression contains the unit, invert the right one instead, // since the unit should not be moved. - if contains_the_unit(left) { + if contains_the_unit(symbol_table, left) { return Ok(invert( Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())), symbol_table, @@ -151,6 +151,18 @@ fn invert_unit( unimplemented!() } +fn invert_var( + target_expr: Expr, + symbol_table: &mut SymbolTable, + identifier: &str, +) -> Result<(Expr, Expr), CalcError> { + if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() { + invert(target_expr, symbol_table, &var_expr) + } else { + Ok((target_expr, Expr::Var(identifier.into()))) + } +} + fn invert_fn_call( target_expr: Expr, symbol_table: &mut SymbolTable, @@ -165,7 +177,7 @@ fn invert_fn_call( return Err(CalcError::UndefinedFn(identifier.into())); }; - // Make sure the input-expression is valid. + // Make sure the input is valid. if parameters.len() != arguments.len() { return Err(CalcError::IncorrectAmountOfArguments( parameters.len(), @@ -187,17 +199,26 @@ fn invert_fn_call( invert(target_expr, symbol_table, &body) } -fn contains_the_unit(expr: &Expr) -> bool { +fn contains_the_unit(symbol_table: &SymbolTable, expr: &Expr) -> bool { // Recursively scan the expression for the unit. match expr { - Expr::Binary(left, _, right) => contains_the_unit(left) || contains_the_unit(right), - Expr::Unary(_, expr) => contains_the_unit(expr), - Expr::Unit(_, expr) => contains_the_unit(expr), - Expr::Var(identifier) => identifier == DECL_UNIT, - Expr::Group(expr) => contains_the_unit(expr), + Expr::Binary(left, _, right) => { + contains_the_unit(symbol_table, left) || contains_the_unit(symbol_table, right) + } + Expr::Unary(_, expr) => contains_the_unit(symbol_table, expr), + Expr::Unit(_, expr) => contains_the_unit(symbol_table, expr), + Expr::Var(identifier) => { + identifier == DECL_UNIT + || if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) { + contains_the_unit(symbol_table, var_expr) + } else { + false + } + } + Expr::Group(expr) => contains_the_unit(symbol_table, expr), Expr::FnCall(_, args) => { for arg in args { - if contains_the_unit(arg) { + if contains_the_unit(symbol_table, arg) { return true; } } @@ -235,3 +256,200 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result { _ => unimplemented!(), } } + +#[allow(unused_imports, dead_code)] // Getting warnings for some reason +mod tests { + use crate::ast::Expr; + use crate::lexer::TokenKind::*; + use crate::symbol_table::SymbolTable; + use crate::test_helpers::*; + + fn decl_unit() -> Box { + Box::new(Expr::Var(crate::parser::DECL_UNIT.into())) + } + + #[test] + fn test_binary() { + let ladd = binary(decl_unit(), Plus, literal("1")); + let lsub = binary(decl_unit(), Minus, literal("1")); + let lmul = binary(decl_unit(), Star, literal("1")); + let ldiv = binary(decl_unit(), Slash, literal("1")); + + let radd = binary(literal("1"), Plus, decl_unit()); + let rsub = binary(literal("1"), Minus, decl_unit()); + let rmul = binary(literal("1"), Star, decl_unit()); + let rdiv = binary(literal("1"), Slash, decl_unit()); + + let mut symbol_table = SymbolTable::new(); + assert_eq!( + ladd.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, literal("1")) + ); + assert_eq!( + lsub.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Plus, literal("1")) + ); + assert_eq!( + lmul.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Slash, literal("1")) + ); + assert_eq!( + ldiv.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Star, literal("1")) + ); + + assert_eq!( + radd.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, literal("1")) + ); + assert_eq!( + rsub.invert(&mut symbol_table).unwrap(), + *unary(Minus, binary(decl_unit(), Plus, literal("1"))) + ); + assert_eq!( + rmul.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Slash, literal("1")) + ); + assert_eq!( + rdiv.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Star, literal("1")) + ); + } + + #[test] + fn test_unary() { + let neg = unary(Minus, decl_unit()); + + let mut symbol_table = SymbolTable::new(); + assert_eq!(neg.invert(&mut symbol_table).unwrap(), *neg); + } + + #[test] + fn test_fn_call() { + let call_with_literal = binary(fn_call("f", vec![*literal("2")]), Plus, decl_unit()); + let call_with_decl_unit = fn_call("f", vec![*decl_unit()]); + let call_with_decl_unit_and_literal = + fn_call("f", vec![*binary(decl_unit(), Plus, literal("2"))]); + let decl = fn_decl( + "f", + vec![String::from("x")], + binary(var("x"), Plus, literal("1")), + ); + + let mut symbol_table = SymbolTable::new(); + symbol_table.insert(decl); + assert_eq!( + call_with_literal.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, fn_call("f", vec![*literal("2")])), + ); + assert_eq!( + call_with_decl_unit.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, literal("1")) + ); + assert_eq!( + call_with_decl_unit_and_literal + .invert(&mut symbol_table) + .unwrap(), + *binary( + binary(decl_unit(), Minus, literal("1")), + Minus, + literal("2") + ) + ); + } + + #[test] + fn test_group() { + let group_x = binary( + group(binary(decl_unit(), Plus, literal("3"))), + Star, + literal("2"), + ); + let group_unary_minus = binary( + literal("2"), + Minus, + group(binary(decl_unit(), Plus, literal("3"))), + ); + let x_group_add = binary( + literal("2"), + Star, + group(binary(decl_unit(), Plus, literal("3"))), + ); + let x_group_sub = binary( + literal("2"), + Star, + group(binary(decl_unit(), Minus, literal("3"))), + ); + let x_group_mul = binary( + literal("2"), + Star, + group(binary(decl_unit(), Star, literal("3"))), + ); + let x_group_div = binary( + literal("2"), + Star, + group(binary(decl_unit(), Slash, literal("3"))), + ); + + let mut symbol_table = SymbolTable::new(); + assert_eq!( + group_x.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Minus, binary(literal("2"), Star, literal("3"))), + Slash, + literal("2") + ) + ); + assert_eq!( + group_unary_minus.invert(&mut symbol_table).unwrap(), + *binary( + binary( + binary(decl_unit(), Minus, literal("2")), + Minus, + binary(literal("-1"), Star, literal("3")) + ), + Slash, + literal("-1") + ) + ); + assert_eq!( + x_group_add.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Minus, binary(literal("2"), Star, literal("3"))), + Slash, + literal("2") + ) + ); + assert_eq!( + x_group_sub.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Plus, binary(literal("2"), Star, literal("3"))), + Slash, + literal("2") + ) + ); + assert_eq!( + x_group_mul.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Slash, literal("3")), + Slash, + literal("2") + ) + ); + assert_eq!( + x_group_div.invert(&mut symbol_table).unwrap(), + *binary(binary(decl_unit(), Star, literal("3")), Slash, literal("2")) + ); + } + + #[test] + fn test_multiple_decl_units() { + let add_two = binary(decl_unit(), Plus, decl_unit()); + + let mut symbol_table = SymbolTable::new(); + assert_eq!( + add_two.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Slash, literal("2")) + ); + } +}