mirror of
https://github.com/PaddiM8/kalker.git
synced 2025-01-22 13:08:35 +01:00
Fixed inversion for (function) variables and added unit tests for the inverter.
This commit is contained in:
parent
22816bcdc3
commit
643509ce4a
@ -24,7 +24,7 @@ fn invert(
|
||||
}
|
||||
Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr),
|
||||
Expr::Unit(identifier, expr) => invert_unit(target_expr, &identifier, &expr),
|
||||
Expr::Var(_) => Ok((target_expr, expr.clone())),
|
||||
Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier),
|
||||
Expr::Group(expr) => Ok((target_expr, *expr.clone())),
|
||||
Expr::FnCall(identifier, arguments) => {
|
||||
invert_fn_call(target_expr, symbol_table, &identifier, arguments)
|
||||
@ -50,7 +50,7 @@ fn invert_binary(
|
||||
target_expr,
|
||||
symbol_table,
|
||||
left,
|
||||
op,
|
||||
&TokenKind::Plus,
|
||||
&multiply_into(&Expr::Literal(String::from("-1")), inside_group)?,
|
||||
);
|
||||
}
|
||||
@ -107,7 +107,7 @@ fn invert_binary(
|
||||
|
||||
// If the left expression contains the unit, invert the right one instead,
|
||||
// since the unit should not be moved.
|
||||
if contains_the_unit(left) {
|
||||
if contains_the_unit(symbol_table, left) {
|
||||
return Ok(invert(
|
||||
Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())),
|
||||
symbol_table,
|
||||
@ -151,6 +151,18 @@ fn invert_unit(
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn invert_var(
|
||||
target_expr: Expr,
|
||||
symbol_table: &mut SymbolTable,
|
||||
identifier: &str,
|
||||
) -> Result<(Expr, Expr), CalcError> {
|
||||
if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() {
|
||||
invert(target_expr, symbol_table, &var_expr)
|
||||
} else {
|
||||
Ok((target_expr, Expr::Var(identifier.into())))
|
||||
}
|
||||
}
|
||||
|
||||
fn invert_fn_call(
|
||||
target_expr: Expr,
|
||||
symbol_table: &mut SymbolTable,
|
||||
@ -165,7 +177,7 @@ fn invert_fn_call(
|
||||
return Err(CalcError::UndefinedFn(identifier.into()));
|
||||
};
|
||||
|
||||
// Make sure the input-expression is valid.
|
||||
// Make sure the input is valid.
|
||||
if parameters.len() != arguments.len() {
|
||||
return Err(CalcError::IncorrectAmountOfArguments(
|
||||
parameters.len(),
|
||||
@ -187,17 +199,26 @@ fn invert_fn_call(
|
||||
invert(target_expr, symbol_table, &body)
|
||||
}
|
||||
|
||||
fn contains_the_unit(expr: &Expr) -> bool {
|
||||
fn contains_the_unit(symbol_table: &SymbolTable, expr: &Expr) -> bool {
|
||||
// Recursively scan the expression for the unit.
|
||||
match expr {
|
||||
Expr::Binary(left, _, right) => contains_the_unit(left) || contains_the_unit(right),
|
||||
Expr::Unary(_, expr) => contains_the_unit(expr),
|
||||
Expr::Unit(_, expr) => contains_the_unit(expr),
|
||||
Expr::Var(identifier) => identifier == DECL_UNIT,
|
||||
Expr::Group(expr) => contains_the_unit(expr),
|
||||
Expr::Binary(left, _, right) => {
|
||||
contains_the_unit(symbol_table, left) || contains_the_unit(symbol_table, right)
|
||||
}
|
||||
Expr::Unary(_, expr) => contains_the_unit(symbol_table, expr),
|
||||
Expr::Unit(_, expr) => contains_the_unit(symbol_table, expr),
|
||||
Expr::Var(identifier) => {
|
||||
identifier == DECL_UNIT
|
||||
|| if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) {
|
||||
contains_the_unit(symbol_table, var_expr)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Expr::Group(expr) => contains_the_unit(symbol_table, expr),
|
||||
Expr::FnCall(_, args) => {
|
||||
for arg in args {
|
||||
if contains_the_unit(arg) {
|
||||
if contains_the_unit(symbol_table, arg) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -235,3 +256,200 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_imports, dead_code)] // Getting warnings for some reason
|
||||
mod tests {
|
||||
use crate::ast::Expr;
|
||||
use crate::lexer::TokenKind::*;
|
||||
use crate::symbol_table::SymbolTable;
|
||||
use crate::test_helpers::*;
|
||||
|
||||
fn decl_unit() -> Box<Expr> {
|
||||
Box::new(Expr::Var(crate::parser::DECL_UNIT.into()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary() {
|
||||
let ladd = binary(decl_unit(), Plus, literal("1"));
|
||||
let lsub = binary(decl_unit(), Minus, literal("1"));
|
||||
let lmul = binary(decl_unit(), Star, literal("1"));
|
||||
let ldiv = binary(decl_unit(), Slash, literal("1"));
|
||||
|
||||
let radd = binary(literal("1"), Plus, decl_unit());
|
||||
let rsub = binary(literal("1"), Minus, decl_unit());
|
||||
let rmul = binary(literal("1"), Star, decl_unit());
|
||||
let rdiv = binary(literal("1"), Slash, decl_unit());
|
||||
|
||||
let mut symbol_table = SymbolTable::new();
|
||||
assert_eq!(
|
||||
ladd.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Minus, literal("1"))
|
||||
);
|
||||
assert_eq!(
|
||||
lsub.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Plus, literal("1"))
|
||||
);
|
||||
assert_eq!(
|
||||
lmul.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Slash, literal("1"))
|
||||
);
|
||||
assert_eq!(
|
||||
ldiv.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Star, literal("1"))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
radd.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Minus, literal("1"))
|
||||
);
|
||||
assert_eq!(
|
||||
rsub.invert(&mut symbol_table).unwrap(),
|
||||
*unary(Minus, binary(decl_unit(), Plus, literal("1")))
|
||||
);
|
||||
assert_eq!(
|
||||
rmul.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Slash, literal("1"))
|
||||
);
|
||||
assert_eq!(
|
||||
rdiv.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Star, literal("1"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary() {
|
||||
let neg = unary(Minus, decl_unit());
|
||||
|
||||
let mut symbol_table = SymbolTable::new();
|
||||
assert_eq!(neg.invert(&mut symbol_table).unwrap(), *neg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_call() {
|
||||
let call_with_literal = binary(fn_call("f", vec![*literal("2")]), Plus, decl_unit());
|
||||
let call_with_decl_unit = fn_call("f", vec![*decl_unit()]);
|
||||
let call_with_decl_unit_and_literal =
|
||||
fn_call("f", vec![*binary(decl_unit(), Plus, literal("2"))]);
|
||||
let decl = fn_decl(
|
||||
"f",
|
||||
vec![String::from("x")],
|
||||
binary(var("x"), Plus, literal("1")),
|
||||
);
|
||||
|
||||
let mut symbol_table = SymbolTable::new();
|
||||
symbol_table.insert(decl);
|
||||
assert_eq!(
|
||||
call_with_literal.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Minus, fn_call("f", vec![*literal("2")])),
|
||||
);
|
||||
assert_eq!(
|
||||
call_with_decl_unit.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Minus, literal("1"))
|
||||
);
|
||||
assert_eq!(
|
||||
call_with_decl_unit_and_literal
|
||||
.invert(&mut symbol_table)
|
||||
.unwrap(),
|
||||
*binary(
|
||||
binary(decl_unit(), Minus, literal("1")),
|
||||
Minus,
|
||||
literal("2")
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_group() {
|
||||
let group_x = binary(
|
||||
group(binary(decl_unit(), Plus, literal("3"))),
|
||||
Star,
|
||||
literal("2"),
|
||||
);
|
||||
let group_unary_minus = binary(
|
||||
literal("2"),
|
||||
Minus,
|
||||
group(binary(decl_unit(), Plus, literal("3"))),
|
||||
);
|
||||
let x_group_add = binary(
|
||||
literal("2"),
|
||||
Star,
|
||||
group(binary(decl_unit(), Plus, literal("3"))),
|
||||
);
|
||||
let x_group_sub = binary(
|
||||
literal("2"),
|
||||
Star,
|
||||
group(binary(decl_unit(), Minus, literal("3"))),
|
||||
);
|
||||
let x_group_mul = binary(
|
||||
literal("2"),
|
||||
Star,
|
||||
group(binary(decl_unit(), Star, literal("3"))),
|
||||
);
|
||||
let x_group_div = binary(
|
||||
literal("2"),
|
||||
Star,
|
||||
group(binary(decl_unit(), Slash, literal("3"))),
|
||||
);
|
||||
|
||||
let mut symbol_table = SymbolTable::new();
|
||||
assert_eq!(
|
||||
group_x.invert(&mut symbol_table).unwrap(),
|
||||
*binary(
|
||||
binary(decl_unit(), Minus, binary(literal("2"), Star, literal("3"))),
|
||||
Slash,
|
||||
literal("2")
|
||||
)
|
||||
);
|
||||
assert_eq!(
|
||||
group_unary_minus.invert(&mut symbol_table).unwrap(),
|
||||
*binary(
|
||||
binary(
|
||||
binary(decl_unit(), Minus, literal("2")),
|
||||
Minus,
|
||||
binary(literal("-1"), Star, literal("3"))
|
||||
),
|
||||
Slash,
|
||||
literal("-1")
|
||||
)
|
||||
);
|
||||
assert_eq!(
|
||||
x_group_add.invert(&mut symbol_table).unwrap(),
|
||||
*binary(
|
||||
binary(decl_unit(), Minus, binary(literal("2"), Star, literal("3"))),
|
||||
Slash,
|
||||
literal("2")
|
||||
)
|
||||
);
|
||||
assert_eq!(
|
||||
x_group_sub.invert(&mut symbol_table).unwrap(),
|
||||
*binary(
|
||||
binary(decl_unit(), Plus, binary(literal("2"), Star, literal("3"))),
|
||||
Slash,
|
||||
literal("2")
|
||||
)
|
||||
);
|
||||
assert_eq!(
|
||||
x_group_mul.invert(&mut symbol_table).unwrap(),
|
||||
*binary(
|
||||
binary(decl_unit(), Slash, literal("3")),
|
||||
Slash,
|
||||
literal("2")
|
||||
)
|
||||
);
|
||||
assert_eq!(
|
||||
x_group_div.invert(&mut symbol_table).unwrap(),
|
||||
*binary(binary(decl_unit(), Star, literal("3")), Slash, literal("2"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_decl_units() {
|
||||
let add_two = binary(decl_unit(), Plus, decl_unit());
|
||||
|
||||
let mut symbol_table = SymbolTable::new();
|
||||
assert_eq!(
|
||||
add_two.invert(&mut symbol_table).unwrap(),
|
||||
*binary(decl_unit(), Slash, literal("2"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user