2021-05-17 20:17:34 +02:00
|
|
|
use crate::ast::Identifier;
|
2020-06-13 20:06:21 +02:00
|
|
|
use crate::ast::{Expr, Stmt};
|
2022-04-23 23:43:06 +02:00
|
|
|
use crate::errors::KalkError;
|
2020-06-13 19:01:33 +02:00
|
|
|
use crate::lexer::TokenKind;
|
2020-06-18 16:20:18 +02:00
|
|
|
use crate::prelude;
|
2020-06-13 20:06:21 +02:00
|
|
|
use crate::symbol_table::SymbolTable;
|
2020-12-10 23:40:29 +01:00
|
|
|
use lazy_static::lazy_static;
|
|
|
|
use std::collections::HashMap;
|
2020-06-13 19:01:33 +02:00
|
|
|
|
2020-12-10 23:40:29 +01:00
|
|
|
lazy_static! {
    /// Maps each invertible single-argument prelude function to its inverse and each inverse
    /// back to the original function, e.g. `sin` <-> `asin`.
    ///
    /// Used by `invert_fn_call`: `f(e) = t` is rewritten as `e = f_inv(t)`.
    pub static ref INVERSE_UNARY_FUNCS: HashMap<&'static str, &'static str> = {
        // (function, inverse) pairs. Both directions are inserted from this single table so
        // the forward and reverse mappings cannot drift apart (previously `csch` was mapped
        // to itself instead of to `acsch`).
        const PAIRS: [(&str, &str); 12] = [
            ("cos", "acos"),
            ("csc", "acsc"),
            ("csch", "acsch"),
            ("cosh", "acosh"),
            ("cot", "acot"),
            ("coth", "acoth"),
            ("sec", "asec"),
            ("sech", "asech"),
            ("sin", "asin"),
            ("sinh", "asinh"),
            ("tan", "atan"),
            ("tanh", "atanh"),
        ];

        let mut m = HashMap::with_capacity(PAIRS.len() * 2);
        for &(func, inverse) in PAIRS.iter() {
            m.insert(func, inverse);
            m.insert(inverse, func);
        }

        m
    };
}
|
2020-06-18 16:20:18 +02:00
|
|
|
|
2020-06-13 19:01:33 +02:00
|
|
|
impl Expr {
|
2020-12-14 19:21:30 +01:00
|
|
|
pub fn invert(
|
|
|
|
&self,
|
|
|
|
symbol_table: &mut SymbolTable,
|
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<Self, KalkError> {
|
2021-05-17 20:17:34 +02:00
|
|
|
let target_expr = Expr::Var(Identifier::from_full_name(unknown_var));
|
2020-12-14 19:21:30 +01:00
|
|
|
let result = invert(target_expr, symbol_table, self, unknown_var);
|
2020-06-14 21:35:56 +02:00
|
|
|
|
2020-06-14 19:23:02 +02:00
|
|
|
Ok(result?.0)
|
|
|
|
}
|
2020-12-14 19:21:30 +01:00
|
|
|
|
|
|
|
pub fn invert_to_target(
|
|
|
|
&self,
|
|
|
|
symbol_table: &mut SymbolTable,
|
|
|
|
target_expr: Expr,
|
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<Self, KalkError> {
|
2020-12-14 19:21:30 +01:00
|
|
|
let x = invert(target_expr, symbol_table, self, unknown_var)?;
|
|
|
|
Ok(x.0)
|
|
|
|
}
|
2020-06-14 19:23:02 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
fn invert(
|
|
|
|
target_expr: Expr,
|
|
|
|
symbol_table: &mut SymbolTable,
|
|
|
|
expr: &Expr,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<(Expr, Expr), KalkError> {
|
2020-06-14 19:23:02 +02:00
|
|
|
match expr {
|
|
|
|
Expr::Binary(left, op, right) => {
|
2022-01-16 20:58:00 +01:00
|
|
|
invert_binary(target_expr, symbol_table, left, op, right, unknown_var)
|
2020-06-14 19:23:02 +02:00
|
|
|
}
|
2022-01-16 20:58:00 +01:00
|
|
|
Expr::Unary(op, expr) => invert_unary(target_expr, op, expr),
|
2020-12-14 19:21:30 +01:00
|
|
|
Expr::Unit(identifier, expr) => {
|
2022-01-16 20:58:00 +01:00
|
|
|
invert_unit(target_expr, symbol_table, identifier, expr, unknown_var)
|
2020-06-13 19:01:33 +02:00
|
|
|
}
|
2020-12-14 19:21:30 +01:00
|
|
|
Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier, unknown_var),
|
|
|
|
Expr::Group(expr) => Ok((target_expr, *expr.clone())),
|
|
|
|
Expr::FnCall(identifier, arguments) => invert_fn_call(
|
|
|
|
target_expr,
|
|
|
|
symbol_table,
|
2021-05-17 20:17:34 +02:00
|
|
|
identifier,
|
2020-12-14 19:21:30 +01:00
|
|
|
arguments,
|
|
|
|
unknown_var,
|
|
|
|
),
|
2020-06-14 19:23:02 +02:00
|
|
|
Expr::Literal(_) => Ok((target_expr, expr.clone())),
|
2022-03-27 21:32:50 +02:00
|
|
|
Expr::Piecewise(_) => Err(KalkError::UnableToInvert(String::from("Piecewise"))),
|
|
|
|
Expr::Vector(_) => Err(KalkError::UnableToInvert(String::from("Vector"))),
|
|
|
|
Expr::Matrix(_) => Err(KalkError::UnableToInvert(String::from("Matrix"))),
|
|
|
|
Expr::Indexer(_, _) => Err(KalkError::UnableToInvert(String::from("Inverter"))),
|
2022-01-16 00:33:26 +01:00
|
|
|
Expr::Comprehension(_, _, _) => {
|
2022-03-27 21:32:50 +02:00
|
|
|
Err(KalkError::UnableToInvert(String::from("Comprehension")))
|
2022-01-16 00:33:26 +01:00
|
|
|
}
|
2022-04-24 21:23:29 +02:00
|
|
|
Expr::Equation(_, _, _) => Err(KalkError::UnableToInvert(String::from("Equation"))),
|
2020-06-13 19:01:33 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-13 20:06:21 +02:00
|
|
|
fn invert_binary(
|
2020-06-14 19:23:02 +02:00
|
|
|
target_expr: Expr,
|
2020-06-13 20:06:21 +02:00
|
|
|
symbol_table: &mut SymbolTable,
|
|
|
|
left: &Expr,
|
|
|
|
op: &TokenKind,
|
|
|
|
right: &Expr,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<(Expr, Expr), KalkError> {
|
2020-06-13 19:01:33 +02:00
|
|
|
let op_inv = match op {
|
|
|
|
TokenKind::Plus => TokenKind::Minus,
|
2020-06-14 21:35:56 +02:00
|
|
|
TokenKind::Minus => {
|
2020-06-14 21:54:39 +02:00
|
|
|
// Eg. a-(b+c)
|
|
|
|
// Multiply "-1" into the group, resulting in it becoming a normal expression. Then invert it normally.
|
2020-06-14 21:35:56 +02:00
|
|
|
if let Expr::Group(inside_group) = right {
|
|
|
|
return invert_binary(
|
|
|
|
target_expr,
|
|
|
|
symbol_table,
|
|
|
|
left,
|
2020-06-17 17:45:46 +02:00
|
|
|
&TokenKind::Plus,
|
2020-12-13 15:52:22 +01:00
|
|
|
&multiply_into(&Expr::Literal(-1f64), inside_group)?,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2020-06-14 21:35:56 +02:00
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
TokenKind::Plus
|
|
|
|
}
|
2020-06-14 19:23:02 +02:00
|
|
|
TokenKind::Star => {
|
2020-06-14 21:54:39 +02:00
|
|
|
// If the left expression is a group, multiply the right expression into it, dissolving the group.
|
|
|
|
// It can then be inverted normally.
|
2020-06-14 19:23:02 +02:00
|
|
|
if let Expr::Group(inside_group) = left {
|
|
|
|
return invert(
|
|
|
|
target_expr,
|
|
|
|
symbol_table,
|
2020-06-15 21:27:47 +02:00
|
|
|
&multiply_into(right, inside_group)?,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2020-06-14 19:23:02 +02:00
|
|
|
);
|
|
|
|
}
|
|
|
|
|
2020-06-14 21:54:39 +02:00
|
|
|
// Same as above but left/right switched.
|
2020-06-14 19:23:02 +02:00
|
|
|
if let Expr::Group(inside_group) = right {
|
2020-06-15 21:27:47 +02:00
|
|
|
return invert(
|
|
|
|
target_expr,
|
|
|
|
symbol_table,
|
|
|
|
&multiply_into(left, inside_group)?,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2020-06-15 21:27:47 +02:00
|
|
|
);
|
2020-06-14 19:23:02 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
TokenKind::Slash
|
|
|
|
}
|
|
|
|
TokenKind::Slash => {
|
2020-06-14 21:54:39 +02:00
|
|
|
// Eg. (a+b)/c
|
|
|
|
// Just dissolve the group. Nothing more needs to be done mathematically.
|
2020-06-14 19:23:02 +02:00
|
|
|
if let Expr::Group(inside_group) = left {
|
|
|
|
return invert(
|
|
|
|
target_expr,
|
|
|
|
symbol_table,
|
2022-01-16 20:58:00 +01:00
|
|
|
&Expr::Binary(inside_group.clone(), *op, Box::new(right.clone())),
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2020-06-14 19:23:02 +02:00
|
|
|
);
|
|
|
|
}
|
|
|
|
|
2020-06-14 21:54:39 +02:00
|
|
|
// Eg. a/(b+c)
|
|
|
|
// Same as above.
|
2020-06-14 19:23:02 +02:00
|
|
|
if let Expr::Group(inside_group) = right {
|
|
|
|
return invert(
|
|
|
|
target_expr,
|
|
|
|
symbol_table,
|
2022-01-16 20:58:00 +01:00
|
|
|
&Expr::Binary(Box::new(left.clone()), *op, inside_group.clone()),
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2020-06-14 19:23:02 +02:00
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
TokenKind::Star
|
|
|
|
}
|
2020-12-14 19:21:30 +01:00
|
|
|
TokenKind::Power => {
|
|
|
|
return if contains_var(symbol_table, left, unknown_var) {
|
|
|
|
invert(
|
2021-05-17 20:17:34 +02:00
|
|
|
Expr::FnCall(
|
|
|
|
Identifier::from_full_name("root"),
|
|
|
|
vec![target_expr, right.clone()],
|
|
|
|
),
|
2020-12-14 19:21:30 +01:00
|
|
|
symbol_table,
|
|
|
|
right,
|
|
|
|
unknown_var,
|
|
|
|
)
|
|
|
|
} else {
|
|
|
|
invert(
|
2021-05-17 20:17:34 +02:00
|
|
|
Expr::FnCall(
|
|
|
|
Identifier::from_full_name("log"),
|
|
|
|
vec![target_expr, left.clone()],
|
|
|
|
),
|
2020-12-14 19:21:30 +01:00
|
|
|
symbol_table,
|
|
|
|
right,
|
|
|
|
unknown_var,
|
|
|
|
)
|
|
|
|
};
|
|
|
|
}
|
2022-03-27 21:32:50 +02:00
|
|
|
_ => return Err(KalkError::UnableToInvert(String::new())),
|
2020-06-13 19:01:33 +02:00
|
|
|
};
|
|
|
|
|
2020-06-14 21:54:39 +02:00
|
|
|
// If the left expression contains the unit, invert the right one instead,
|
|
|
|
// since the unit should not be moved.
|
2020-12-14 19:21:30 +01:00
|
|
|
if contains_var(symbol_table, left, unknown_var) {
|
2020-06-17 21:28:54 +02:00
|
|
|
// But if the right expression *also* contains the unit,
|
|
|
|
// throw an error, since it can't handle this yet.
|
2020-12-14 19:21:30 +01:00
|
|
|
if contains_var(symbol_table, right, unknown_var) {
|
2022-03-27 21:32:50 +02:00
|
|
|
return Err(KalkError::UnableToInvert(String::from(
|
2020-06-18 16:20:18 +02:00
|
|
|
"Expressions with several instances of an unknown variable (this might be supported in the future). Try simplifying the expression.",
|
2020-06-17 21:28:54 +02:00
|
|
|
)));
|
|
|
|
}
|
|
|
|
|
2022-01-16 20:58:00 +01:00
|
|
|
return invert(
|
2020-06-14 19:23:02 +02:00
|
|
|
Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())),
|
|
|
|
symbol_table,
|
|
|
|
left,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2022-01-16 20:58:00 +01:00
|
|
|
);
|
2020-06-14 19:23:02 +02:00
|
|
|
}
|
2020-06-13 19:01:33 +02:00
|
|
|
|
2020-06-14 21:54:39 +02:00
|
|
|
// Otherwise, invert the left side.
|
2020-06-14 21:35:56 +02:00
|
|
|
let final_target_expr = Expr::Binary(Box::new(target_expr), op_inv, Box::new(left.clone()));
|
2022-01-16 20:58:00 +01:00
|
|
|
invert(
|
2020-06-14 21:54:39 +02:00
|
|
|
// Eg. 2-a
|
|
|
|
// If the operator is minus (and the left expression is being inverted),
|
|
|
|
// make the target expression negative to keep balance.
|
|
|
|
if let TokenKind::Minus = op {
|
2020-06-14 21:35:56 +02:00
|
|
|
Expr::Unary(TokenKind::Minus, Box::new(final_target_expr))
|
|
|
|
} else {
|
|
|
|
final_target_expr
|
|
|
|
},
|
2020-06-14 19:23:02 +02:00
|
|
|
symbol_table,
|
2020-06-14 21:54:39 +02:00
|
|
|
right, // Then invert the right expression.
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var,
|
2022-01-16 20:58:00 +01:00
|
|
|
)
|
2020-06-13 19:01:33 +02:00
|
|
|
}
|
|
|
|
|
2022-03-27 21:32:50 +02:00
|
|
|
fn invert_unary(target_expr: Expr, op: &TokenKind, expr: &Expr) -> Result<(Expr, Expr), KalkError> {
|
2020-06-14 21:35:56 +02:00
|
|
|
match op {
|
|
|
|
TokenKind::Minus => Ok((
|
2020-06-15 21:27:47 +02:00
|
|
|
// Make the target expression negative
|
2020-06-14 21:35:56 +02:00
|
|
|
Expr::Unary(TokenKind::Minus, Box::new(target_expr)),
|
2020-06-15 21:27:47 +02:00
|
|
|
expr.clone(), // And then continue inverting the inner-expression.
|
2020-06-14 21:35:56 +02:00
|
|
|
)),
|
2022-03-27 21:32:50 +02:00
|
|
|
_ => Err(KalkError::UnableToInvert(String::new())),
|
2020-06-14 21:35:56 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-14 19:23:02 +02:00
|
|
|
fn invert_unit(
|
2020-12-14 19:21:30 +01:00
|
|
|
target_expr: Expr,
|
|
|
|
symbol_table: &mut SymbolTable,
|
|
|
|
identifier: &str,
|
|
|
|
expr: &Expr,
|
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<(Expr, Expr), KalkError> {
|
2020-12-14 19:21:30 +01:00
|
|
|
let x = Expr::Binary(
|
|
|
|
Box::new(target_expr),
|
|
|
|
TokenKind::ToKeyword,
|
2021-05-17 20:17:34 +02:00
|
|
|
Box::new(Expr::Var(Identifier::from_full_name(identifier))),
|
2020-12-14 19:21:30 +01:00
|
|
|
);
|
|
|
|
invert(x, symbol_table, expr, unknown_var)
|
2020-06-13 19:01:33 +02:00
|
|
|
}
|
|
|
|
|
2020-06-17 17:45:46 +02:00
|
|
|
fn invert_var(
|
|
|
|
target_expr: Expr,
|
|
|
|
symbol_table: &mut SymbolTable,
|
2021-05-17 20:17:34 +02:00
|
|
|
identifier: &Identifier,
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<(Expr, Expr), KalkError> {
|
2021-05-17 20:17:34 +02:00
|
|
|
if identifier.full_name == unknown_var {
|
2021-05-17 20:36:53 +02:00
|
|
|
Ok((target_expr, Expr::Var(identifier.clone())))
|
2021-05-17 20:17:34 +02:00
|
|
|
} else if let Some(Stmt::VarDecl(_, var_expr)) =
|
|
|
|
symbol_table.get_var(&identifier.full_name).cloned()
|
|
|
|
{
|
2020-12-14 19:21:30 +01:00
|
|
|
invert(target_expr, symbol_table, &var_expr, unknown_var)
|
2020-06-17 17:45:46 +02:00
|
|
|
} else {
|
2021-05-17 20:36:53 +02:00
|
|
|
Ok((target_expr, Expr::Var(identifier.clone())))
|
2020-06-17 17:45:46 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-13 22:18:37 +02:00
|
|
|
fn invert_fn_call(
|
2020-06-14 19:23:02 +02:00
|
|
|
target_expr: Expr,
|
2020-06-13 22:18:37 +02:00
|
|
|
symbol_table: &mut SymbolTable,
|
2021-05-17 20:17:34 +02:00
|
|
|
identifier: &Identifier,
|
2022-01-16 20:58:00 +01:00
|
|
|
arguments: &[Expr],
|
2020-12-14 19:21:30 +01:00
|
|
|
unknown_var: &str,
|
2022-03-27 21:32:50 +02:00
|
|
|
) -> Result<(Expr, Expr), KalkError> {
|
2020-06-18 16:20:18 +02:00
|
|
|
// If prelude function
|
|
|
|
match arguments.len() {
|
|
|
|
1 => {
|
2021-05-17 20:17:34 +02:00
|
|
|
if prelude::UNARY_FUNCS.contains_key(identifier.full_name.as_ref() as &str) {
|
|
|
|
if let Some(fn_inv) = INVERSE_UNARY_FUNCS.get(identifier.full_name.as_ref() as &str)
|
|
|
|
{
|
2020-12-14 19:21:30 +01:00
|
|
|
return invert(
|
2021-05-17 20:17:34 +02:00
|
|
|
Expr::FnCall(Identifier::from_full_name(fn_inv), vec![target_expr]),
|
2020-12-14 19:21:30 +01:00
|
|
|
symbol_table,
|
|
|
|
&arguments[0],
|
|
|
|
unknown_var,
|
|
|
|
);
|
2020-06-18 16:20:18 +02:00
|
|
|
} else {
|
2021-05-17 20:17:34 +02:00
|
|
|
match identifier.full_name.as_ref() {
|
2020-06-18 16:20:18 +02:00
|
|
|
"sqrt" => {
|
2020-12-14 19:21:30 +01:00
|
|
|
return invert(
|
2020-06-18 16:20:18 +02:00
|
|
|
Expr::Binary(
|
|
|
|
Box::new(target_expr),
|
|
|
|
TokenKind::Power,
|
2020-12-13 15:52:22 +01:00
|
|
|
Box::new(Expr::Literal(2f64)),
|
2020-06-18 16:20:18 +02:00
|
|
|
),
|
2020-12-14 19:21:30 +01:00
|
|
|
symbol_table,
|
|
|
|
&arguments[0],
|
|
|
|
unknown_var,
|
|
|
|
);
|
2020-06-18 16:20:18 +02:00
|
|
|
}
|
|
|
|
_ => {
|
2022-03-27 21:32:50 +02:00
|
|
|
return Err(KalkError::UnableToInvert(format!(
|
2020-06-18 16:20:18 +02:00
|
|
|
"Function '{}'",
|
2021-05-17 20:17:34 +02:00
|
|
|
identifier.full_name
|
2020-06-18 16:20:18 +02:00
|
|
|
)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
2 => {
|
2021-05-17 20:17:34 +02:00
|
|
|
if prelude::BINARY_FUNCS.contains_key(identifier.full_name.as_ref() as &str) {
|
2022-03-27 21:32:50 +02:00
|
|
|
return Err(KalkError::UnableToInvert(format!(
|
2020-06-18 16:20:18 +02:00
|
|
|
"Function '{}'",
|
2021-05-17 20:17:34 +02:00
|
|
|
identifier.full_name
|
2020-06-18 16:20:18 +02:00
|
|
|
)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
_ => (),
|
|
|
|
}
|
|
|
|
|
2020-06-15 21:27:47 +02:00
|
|
|
// Get the function definition from the symbol table.
|
2021-05-17 20:17:34 +02:00
|
|
|
let (parameters, body) = if let Some(Stmt::FnDecl(_, parameters, body)) =
|
|
|
|
symbol_table.get_fn(&identifier.full_name).cloned()
|
|
|
|
{
|
|
|
|
(parameters, body)
|
|
|
|
} else {
|
2022-03-27 21:32:50 +02:00
|
|
|
return Err(KalkError::UndefinedFn(identifier.full_name.clone()));
|
2021-05-17 20:17:34 +02:00
|
|
|
};
|
2020-06-13 20:06:21 +02:00
|
|
|
|
2020-06-17 17:45:46 +02:00
|
|
|
// Make sure the input is valid.
|
2020-06-13 22:18:37 +02:00
|
|
|
if parameters.len() != arguments.len() {
|
2022-03-27 21:32:50 +02:00
|
|
|
return Err(KalkError::IncorrectAmountOfArguments(
|
2020-06-13 22:18:37 +02:00
|
|
|
parameters.len(),
|
2021-05-17 20:36:53 +02:00
|
|
|
identifier.full_name.clone(),
|
2020-06-13 22:18:37 +02:00
|
|
|
arguments.len(),
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
2020-06-15 21:27:47 +02:00
|
|
|
// Make the parameters usable as variables inside the function.
|
2020-06-13 20:06:21 +02:00
|
|
|
let mut parameters_iter = parameters.iter();
|
|
|
|
for argument in arguments {
|
|
|
|
symbol_table.insert(Stmt::VarDecl(
|
2021-05-17 20:17:34 +02:00
|
|
|
Identifier::from_full_name(¶meters_iter.next().unwrap().to_string()),
|
2020-06-13 20:06:21 +02:00
|
|
|
Box::new(argument.clone()),
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
2020-06-15 21:27:47 +02:00
|
|
|
// Invert everything in the function body.
|
2020-12-14 19:21:30 +01:00
|
|
|
invert(target_expr, symbol_table, &body, unknown_var)
|
2020-06-13 19:01:33 +02:00
|
|
|
}
|
|
|
|
|
2020-12-14 19:21:30 +01:00
|
|
|
pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) -> bool {
|
|
|
|
// Recursively scan the expression for the variable.
|
2020-06-14 19:23:02 +02:00
|
|
|
match expr {
|
2020-06-17 17:45:46 +02:00
|
|
|
Expr::Binary(left, _, right) => {
|
2020-12-14 19:21:30 +01:00
|
|
|
contains_var(symbol_table, left, var_name)
|
|
|
|
|| contains_var(symbol_table, right, var_name)
|
2020-06-17 17:45:46 +02:00
|
|
|
}
|
2020-12-14 19:21:30 +01:00
|
|
|
Expr::Unary(_, expr) => contains_var(symbol_table, expr, var_name),
|
|
|
|
Expr::Unit(_, expr) => contains_var(symbol_table, expr, var_name),
|
2020-06-17 17:45:46 +02:00
|
|
|
Expr::Var(identifier) => {
|
2021-05-17 20:17:34 +02:00
|
|
|
identifier.full_name == var_name
|
|
|
|
|| if let Some(Stmt::VarDecl(_, var_expr)) =
|
|
|
|
symbol_table.get_var(&identifier.full_name)
|
|
|
|
{
|
2020-12-14 19:21:30 +01:00
|
|
|
contains_var(symbol_table, var_expr, var_name)
|
2020-06-17 17:45:46 +02:00
|
|
|
} else {
|
|
|
|
false
|
|
|
|
}
|
|
|
|
}
|
2020-12-14 19:21:30 +01:00
|
|
|
Expr::Group(expr) => contains_var(symbol_table, expr, var_name),
|
2020-06-14 19:23:02 +02:00
|
|
|
Expr::FnCall(_, args) => {
|
|
|
|
for arg in args {
|
2020-12-14 19:21:30 +01:00
|
|
|
if contains_var(symbol_table, arg, var_name) {
|
2020-06-14 19:23:02 +02:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
false
|
|
|
|
}
|
|
|
|
Expr::Literal(_) => false,
|
2021-05-31 18:55:37 +02:00
|
|
|
Expr::Piecewise(_) => true, // Let it try to invert this. It will just display the error message.
|
2022-01-05 02:49:12 +01:00
|
|
|
Expr::Vector(items) => items
|
|
|
|
.iter()
|
|
|
|
.any(|x| contains_var(symbol_table, x, var_name)),
|
2022-01-07 00:39:29 +01:00
|
|
|
Expr::Matrix(rows) => rows
|
|
|
|
.iter()
|
|
|
|
.any(|row| row.iter().any(|x| contains_var(symbol_table, x, var_name))),
|
2022-01-05 22:41:41 +01:00
|
|
|
Expr::Indexer(_, _) => false,
|
2022-01-16 00:33:26 +01:00
|
|
|
Expr::Comprehension(_, _, _) => false,
|
2022-04-24 21:23:29 +02:00
|
|
|
Expr::Equation(_, _, _) => false,
|
2020-06-14 19:23:02 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-15 21:27:47 +02:00
|
|
|
/// Multiply an expression into a group.
|
2022-03-27 21:32:50 +02:00
|
|
|
fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, KalkError> {
|
2020-06-14 19:23:02 +02:00
|
|
|
match base_expr {
|
|
|
|
Expr::Binary(left, op, right) => match op {
|
2020-06-15 21:27:47 +02:00
|
|
|
// If + or -, multiply the expression with each term.
|
2020-06-14 19:23:02 +02:00
|
|
|
TokenKind::Plus | TokenKind::Minus => Ok(Expr::Binary(
|
2022-01-16 20:58:00 +01:00
|
|
|
Box::new(multiply_into(expr, left)?),
|
|
|
|
*op,
|
|
|
|
Box::new(multiply_into(expr, right)?),
|
2020-06-14 19:23:02 +02:00
|
|
|
)),
|
2020-06-15 21:27:47 +02:00
|
|
|
// If * or /, only multiply with the first factor.
|
2020-06-14 19:23:02 +02:00
|
|
|
TokenKind::Star | TokenKind::Slash => Ok(Expr::Binary(
|
2022-01-16 20:58:00 +01:00
|
|
|
Box::new(multiply_into(expr, left)?),
|
|
|
|
*op,
|
2020-06-14 19:23:02 +02:00
|
|
|
right.clone(),
|
|
|
|
)),
|
2022-03-27 21:32:50 +02:00
|
|
|
_ => Err(KalkError::UnableToInvert(String::new())),
|
2020-06-14 19:23:02 +02:00
|
|
|
},
|
2020-06-15 21:27:47 +02:00
|
|
|
// If it's a literal, just multiply them together.
|
2020-06-14 19:23:02 +02:00
|
|
|
Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary(
|
|
|
|
Box::new(expr.clone()),
|
|
|
|
TokenKind::Star,
|
|
|
|
Box::new(base_expr.clone()),
|
|
|
|
)),
|
2022-03-27 21:32:50 +02:00
|
|
|
Expr::Group(_) => Err(KalkError::UnableToInvert(String::from(
|
2020-06-18 16:20:18 +02:00
|
|
|
"Parenthesis multiplied with parenthesis (this should be possible in the future).",
|
2020-06-17 21:28:54 +02:00
|
|
|
))),
|
2022-03-27 21:32:50 +02:00
|
|
|
_ => Err(KalkError::UnableToInvert(String::new())),
|
2020-06-14 19:23:02 +02:00
|
|
|
}
|
2020-06-13 19:01:33 +02:00
|
|
|
}
|
2020-06-17 17:45:46 +02:00
|
|
|
|
|
|
|
#[allow(unused_imports, dead_code)] // Silences unused-import/dead-code warnings in this test module.
#[cfg(test)]
mod tests {
    use crate::ast::Expr;
    use crate::ast::Identifier;
    use crate::lexer::TokenKind::*;
    use crate::parser::DECL_UNIT;
    use crate::symbol_table::SymbolTable;
    use crate::test_helpers::*;
    use wasm_bindgen_test::*;

    /// The placeholder variable every test inverts with respect to.
    fn decl_unit() -> Box<Expr> {
        Box::new(Expr::Var(Identifier::from_full_name(DECL_UNIT)))
    }

    /// Inverts `expr` with respect to `DECL_UNIT`, panicking if inversion fails.
    fn inverted(expr: &Expr, symbol_table: &mut SymbolTable) -> Expr {
        expr.invert(symbol_table, DECL_UNIT).unwrap()
    }

    #[test]
    #[wasm_bindgen_test]
    fn test_binary() {
        let mut symbol_table = SymbolTable::new();

        // (expression, expected inverse)
        let cases = vec![
            // Unknown on the left-hand side.
            (
                binary(decl_unit(), Plus, literal(1f64)),
                binary(decl_unit(), Minus, literal(1f64)),
            ),
            (
                binary(decl_unit(), Minus, literal(1f64)),
                binary(decl_unit(), Plus, literal(1f64)),
            ),
            (
                binary(decl_unit(), Star, literal(1f64)),
                binary(decl_unit(), Slash, literal(1f64)),
            ),
            (
                binary(decl_unit(), Slash, literal(1f64)),
                binary(decl_unit(), Star, literal(1f64)),
            ),
            // Unknown on the right-hand side.
            (
                binary(literal(1f64), Plus, decl_unit()),
                binary(decl_unit(), Minus, literal(1f64)),
            ),
            (
                binary(literal(1f64), Minus, decl_unit()),
                unary(Minus, binary(decl_unit(), Plus, literal(1f64))),
            ),
            (
                binary(literal(1f64), Star, decl_unit()),
                binary(decl_unit(), Slash, literal(1f64)),
            ),
            (
                binary(literal(1f64), Slash, decl_unit()),
                binary(decl_unit(), Star, literal(1f64)),
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(inverted(&expr, &mut symbol_table), *expected);
        }
    }

    #[test]
    #[wasm_bindgen_test]
    fn test_unary() {
        let mut symbol_table = SymbolTable::new();
        let negated = unary(Minus, decl_unit());

        // -u inverts to itself.
        assert_eq!(inverted(&negated, &mut symbol_table), *negated);
    }

    #[test]
    #[wasm_bindgen_test]
    fn test_fn_call() {
        let mut symbol_table = SymbolTable::new();

        // f(x) = x + 1
        symbol_table.insert(fn_decl(
            "f",
            vec![String::from("x")],
            binary(var("x"), Plus, literal(1f64)),
        ));

        // f(2) + u  =>  u - f(2)
        let call_with_literal = binary(fn_call("f", vec![*literal(2f64)]), Plus, decl_unit());
        assert_eq!(
            inverted(&call_with_literal, &mut symbol_table),
            *binary(decl_unit(), Minus, fn_call("f", vec![*literal(2f64)])),
        );

        // f(u)  =>  u - 1
        let call_with_decl_unit = fn_call("f", vec![*decl_unit()]);
        assert_eq!(
            inverted(&call_with_decl_unit, &mut symbol_table),
            *binary(decl_unit(), Minus, literal(1f64))
        );

        // f(u + 2)  =>  (u - 1) - 2
        let call_with_decl_unit_and_literal =
            fn_call("f", vec![*binary(decl_unit(), Plus, literal(2f64))]);
        assert_eq!(
            inverted(&call_with_decl_unit_and_literal, &mut symbol_table),
            *binary(
                binary(decl_unit(), Minus, literal(1f64)),
                Minus,
                literal(2f64)
            )
        );
    }

    #[test]
    #[wasm_bindgen_test]
    fn test_group() {
        let mut symbol_table = SymbolTable::new();

        // Shorthand for the group `(u op 3)`.
        let grouped = |op| group(binary(decl_unit(), op, literal(3f64)));

        // (expression, expected inverse)
        let cases = vec![
            // (u + 3) * 2
            (
                binary(grouped(Plus), Star, literal(2f64)),
                binary(
                    binary(
                        decl_unit(),
                        Minus,
                        binary(literal(2f64), Star, literal(3f64)),
                    ),
                    Slash,
                    literal(2f64),
                ),
            ),
            // 2 - (u + 3)
            (
                binary(literal(2f64), Minus, grouped(Plus)),
                binary(
                    binary(
                        binary(decl_unit(), Minus, literal(2f64)),
                        Minus,
                        binary(literal(-1f64), Star, literal(3f64)),
                    ),
                    Slash,
                    literal(-1f64),
                ),
            ),
            // 2 * (u + 3)
            (
                binary(literal(2f64), Star, grouped(Plus)),
                binary(
                    binary(
                        decl_unit(),
                        Minus,
                        binary(literal(2f64), Star, literal(3f64)),
                    ),
                    Slash,
                    literal(2f64),
                ),
            ),
            // 2 * (u - 3)
            (
                binary(literal(2f64), Star, grouped(Minus)),
                binary(
                    binary(
                        decl_unit(),
                        Plus,
                        binary(literal(2f64), Star, literal(3f64)),
                    ),
                    Slash,
                    literal(2f64),
                ),
            ),
            // 2 * (u * 3)
            (
                binary(literal(2f64), Star, grouped(Star)),
                binary(
                    binary(decl_unit(), Slash, literal(3f64)),
                    Slash,
                    literal(2f64),
                ),
            ),
            // 2 * (u / 3)
            (
                binary(literal(2f64), Star, grouped(Slash)),
                binary(
                    binary(decl_unit(), Star, literal(3f64)),
                    Slash,
                    literal(2f64),
                ),
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(inverted(&expr, &mut symbol_table), *expected);
        }
    }

    #[test]
    #[wasm_bindgen_test]
    fn test_multiple_decl_units() {
        // Expressions containing the unknown more than once (e.g. `u + u`) are not supported
        // yet: `invert_binary` returns `UnableToInvert` when both operands contain it. Add
        // assertions here once that case is handled.
    }
}
|