Simple equation solving, mostly using pre-existing logic from the inverter

This commit is contained in:
PaddiM8 2020-12-14 19:21:30 +01:00
parent 34927585c5
commit 2f7af7de90
5 changed files with 217 additions and 94 deletions

View File

@ -1,7 +1,6 @@
use crate::ast::{Expr, Stmt};
use crate::lexer::TokenKind;
use crate::parser::CalcError;
use crate::parser::DECL_UNIT;
use crate::prelude;
use crate::symbol_table::SymbolTable;
use lazy_static::lazy_static;
@ -40,30 +39,51 @@ lazy_static! {
}
impl Expr {
pub fn invert(&self, symbol_table: &mut SymbolTable) -> Result<Self, CalcError> {
let target_expr = Expr::Var(DECL_UNIT.into());
let result = invert(target_expr, symbol_table, self);
pub fn invert(
&self,
symbol_table: &mut SymbolTable,
unknown_var: &str,
) -> Result<Self, CalcError> {
let target_expr = Expr::Var(unknown_var.into());
let result = invert(target_expr, symbol_table, self, unknown_var);
Ok(result?.0)
}
pub fn invert_to_target(
&self,
symbol_table: &mut SymbolTable,
target_expr: Expr,
unknown_var: &str,
) -> Result<Self, CalcError> {
let x = invert(target_expr, symbol_table, self, unknown_var)?;
Ok(x.0)
}
}
fn invert(
target_expr: Expr,
symbol_table: &mut SymbolTable,
expr: &Expr,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> {
match expr {
Expr::Binary(left, op, right) => {
invert_binary(target_expr, symbol_table, &left, op, &right)
invert_binary(target_expr, symbol_table, &left, op, &right, unknown_var)
}
Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr),
Expr::Unit(identifier, expr) => invert_unit(target_expr, &identifier, &expr),
Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier),
Expr::Group(expr) => Ok((target_expr, *expr.clone())),
Expr::FnCall(identifier, arguments) => {
invert_fn_call(target_expr, symbol_table, &identifier, arguments)
Expr::Unit(identifier, expr) => {
invert_unit(target_expr, symbol_table, &identifier, &expr, unknown_var)
}
Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier, unknown_var),
Expr::Group(expr) => Ok((target_expr, *expr.clone())),
Expr::FnCall(identifier, arguments) => invert_fn_call(
target_expr,
symbol_table,
&identifier,
arguments,
unknown_var,
),
Expr::Literal(_) => Ok((target_expr, expr.clone())),
}
}
@ -74,6 +94,7 @@ fn invert_binary(
left: &Expr,
op: &TokenKind,
right: &Expr,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> {
let op_inv = match op {
TokenKind::Plus => TokenKind::Minus,
@ -87,6 +108,7 @@ fn invert_binary(
left,
&TokenKind::Plus,
&multiply_into(&Expr::Literal(-1f64), inside_group)?,
unknown_var,
);
}
@ -100,6 +122,7 @@ fn invert_binary(
target_expr,
symbol_table,
&multiply_into(right, inside_group)?,
unknown_var,
);
}
@ -109,6 +132,7 @@ fn invert_binary(
target_expr,
symbol_table,
&multiply_into(left, inside_group)?,
unknown_var,
);
}
@ -122,6 +146,7 @@ fn invert_binary(
target_expr,
symbol_table,
&Expr::Binary(inside_group.clone(), op.clone(), Box::new(right.clone())),
unknown_var,
);
}
@ -132,20 +157,38 @@ fn invert_binary(
target_expr,
symbol_table,
&Expr::Binary(Box::new(left.clone()), op.clone(), inside_group.clone()),
unknown_var,
);
}
TokenKind::Star
}
_ => unreachable!(),
TokenKind::Power => {
return if contains_var(symbol_table, left, unknown_var) {
invert(
Expr::FnCall("root".into(), vec![target_expr, right.clone()]),
symbol_table,
right,
unknown_var,
)
} else {
invert(
Expr::FnCall("log".into(), vec![target_expr, left.clone()]),
symbol_table,
right,
unknown_var,
)
};
}
_ => return Err(CalcError::UnableToInvert(String::new())),
};
// If the left expression contains the unit, invert the right one instead,
// since the unit should not be moved.
if contains_the_unit(symbol_table, left) {
if contains_var(symbol_table, left, unknown_var) {
// But if the right expression *also* contains the unit,
// throw an error, since it can't handle this yet.
if contains_the_unit(symbol_table, right) {
if contains_var(symbol_table, right, unknown_var) {
return Err(CalcError::UnableToInvert(String::from(
"Expressions with several instances of an unknown variable (this might be supported in the future). Try simplifying the expression.",
)));
@ -155,6 +198,7 @@ fn invert_binary(
Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())),
symbol_table,
left,
unknown_var,
)?);
}
@ -171,6 +215,7 @@ fn invert_binary(
},
symbol_table,
right, // Then invert the right expression.
unknown_var,
)?)
}
@ -181,29 +226,35 @@ fn invert_unary(target_expr: Expr, op: &TokenKind, expr: &Expr) -> Result<(Expr,
Expr::Unary(TokenKind::Minus, Box::new(target_expr)),
expr.clone(), // And then continue inverting the inner-expression.
)),
_ => unimplemented!(),
_ => return Err(CalcError::UnableToInvert(String::new())),
}
}
fn invert_unit(
_target_expr: Expr,
_identifier: &str,
_expr: &Expr,
target_expr: Expr,
symbol_table: &mut SymbolTable,
identifier: &str,
expr: &Expr,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> {
Err(CalcError::UnableToInvert(String::from(
"Expressions containing other units (this should be supported in the future).",
)))
let x = Expr::Binary(
Box::new(target_expr),
TokenKind::ToKeyword,
Box::new(Expr::Var(identifier.into())),
);
invert(x, symbol_table, expr, unknown_var)
}
fn invert_var(
target_expr: Expr,
symbol_table: &mut SymbolTable,
identifier: &str,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> {
if identifier == DECL_UNIT {
if identifier == unknown_var {
Ok((target_expr, Expr::Var(identifier.into())))
} else if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() {
invert(target_expr, symbol_table, &var_expr)
invert(target_expr, symbol_table, &var_expr, unknown_var)
} else {
Ok((target_expr, Expr::Var(identifier.into())))
}
@ -214,27 +265,32 @@ fn invert_fn_call(
symbol_table: &mut SymbolTable,
identifier: &str,
arguments: &Vec<Expr>,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> {
// If prelude function
match arguments.len() {
1 => {
if prelude::UNARY_FUNCS.contains_key(identifier) {
if let Some(fn_inv) = INVERSE_UNARY_FUNCS.get(identifier) {
return Ok((
return invert(
Expr::FnCall(fn_inv.to_string(), vec![target_expr]),
arguments[0].clone(),
));
symbol_table,
&arguments[0],
unknown_var,
);
} else {
match identifier {
"sqrt" => {
return Ok((
return invert(
Expr::Binary(
Box::new(target_expr),
TokenKind::Power,
Box::new(Expr::Literal(2f64)),
),
arguments[0].clone(),
));
symbol_table,
&arguments[0],
unknown_var,
);
}
_ => {
return Err(CalcError::UnableToInvert(format!(
@ -284,29 +340,30 @@ fn invert_fn_call(
}
// Invert everything in the function body.
invert(target_expr, symbol_table, &body)
invert(target_expr, symbol_table, &body, unknown_var)
}
fn contains_the_unit(symbol_table: &SymbolTable, expr: &Expr) -> bool {
// Recursively scan the expression for the unit.
pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) -> bool {
// Recursively scan the expression for the variable.
match expr {
Expr::Binary(left, _, right) => {
contains_the_unit(symbol_table, left) || contains_the_unit(symbol_table, right)
contains_var(symbol_table, left, var_name)
|| contains_var(symbol_table, right, var_name)
}
Expr::Unary(_, expr) => contains_the_unit(symbol_table, expr),
Expr::Unit(_, expr) => contains_the_unit(symbol_table, expr),
Expr::Unary(_, expr) => contains_var(symbol_table, expr, var_name),
Expr::Unit(_, expr) => contains_var(symbol_table, expr, var_name),
Expr::Var(identifier) => {
identifier == DECL_UNIT
identifier == var_name
|| if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) {
contains_the_unit(symbol_table, var_expr)
contains_var(symbol_table, var_expr, var_name)
} else {
false
}
}
Expr::Group(expr) => contains_the_unit(symbol_table, expr),
Expr::Group(expr) => contains_var(symbol_table, expr, var_name),
Expr::FnCall(_, args) => {
for arg in args {
if contains_the_unit(symbol_table, arg) {
if contains_var(symbol_table, arg, var_name) {
return true;
}
}
@ -333,7 +390,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
op.clone(),
right.clone(),
)),
_ => unimplemented!(),
_ => return Err(CalcError::UnableToInvert(String::new())),
},
// If it's a literal, just multiply them together.
Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary(
@ -344,7 +401,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
Expr::Group(_) => Err(CalcError::UnableToInvert(String::from(
"Parenthesis multiplied with parenthesis (this should be possible in the future).",
))),
_ => unimplemented!(),
_ => return Err(CalcError::UnableToInvert(String::new())),
}
}
@ -352,6 +409,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
mod tests {
use crate::ast::Expr;
use crate::lexer::TokenKind::*;
use crate::parser::DECL_UNIT;
use crate::symbol_table::SymbolTable;
use crate::test_helpers::*;
@ -373,36 +431,36 @@ mod tests {
let mut symbol_table = SymbolTable::new();
assert_eq!(
ladd.invert(&mut symbol_table).unwrap(),
ladd.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Minus, literal(1f64))
);
assert_eq!(
lsub.invert(&mut symbol_table).unwrap(),
lsub.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Plus, literal(1f64))
);
assert_eq!(
lmul.invert(&mut symbol_table).unwrap(),
lmul.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Slash, literal(1f64))
);
assert_eq!(
ldiv.invert(&mut symbol_table).unwrap(),
ldiv.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Star, literal(1f64))
);
assert_eq!(
radd.invert(&mut symbol_table).unwrap(),
radd.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Minus, literal(1f64))
);
assert_eq!(
rsub.invert(&mut symbol_table).unwrap(),
rsub.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*unary(Minus, binary(decl_unit(), Plus, literal(1f64)))
);
assert_eq!(
rmul.invert(&mut symbol_table).unwrap(),
rmul.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Slash, literal(1f64))
);
assert_eq!(
rdiv.invert(&mut symbol_table).unwrap(),
rdiv.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Star, literal(1f64))
);
}
@ -412,7 +470,7 @@ mod tests {
let neg = unary(Minus, decl_unit());
let mut symbol_table = SymbolTable::new();
assert_eq!(neg.invert(&mut symbol_table).unwrap(), *neg);
assert_eq!(neg.invert(&mut symbol_table, DECL_UNIT).unwrap(), *neg);
}
#[test]
@ -430,16 +488,20 @@ mod tests {
let mut symbol_table = SymbolTable::new();
symbol_table.insert(decl);
assert_eq!(
call_with_literal.invert(&mut symbol_table).unwrap(),
call_with_literal
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary(decl_unit(), Minus, fn_call("f", vec![*literal(2f64)])),
);
assert_eq!(
call_with_decl_unit.invert(&mut symbol_table).unwrap(),
call_with_decl_unit
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary(decl_unit(), Minus, literal(1f64))
);
assert_eq!(
call_with_decl_unit_and_literal
.invert(&mut symbol_table)
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary(
binary(decl_unit(), Minus, literal(1f64)),
@ -484,7 +546,7 @@ mod tests {
let mut symbol_table = SymbolTable::new();
assert_eq!(
group_x.invert(&mut symbol_table).unwrap(),
group_x.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(
binary(
decl_unit(),
@ -496,7 +558,9 @@ mod tests {
)
);
assert_eq!(
group_unary_minus.invert(&mut symbol_table).unwrap(),
group_unary_minus
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary(
binary(
binary(decl_unit(), Minus, literal(2f64)),
@ -508,7 +572,7 @@ mod tests {
)
);
assert_eq!(
x_group_add.invert(&mut symbol_table).unwrap(),
x_group_add.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(
binary(
decl_unit(),
@ -520,7 +584,7 @@ mod tests {
)
);
assert_eq!(
x_group_sub.invert(&mut symbol_table).unwrap(),
x_group_sub.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(
binary(
decl_unit(),
@ -532,7 +596,7 @@ mod tests {
)
);
assert_eq!(
x_group_mul.invert(&mut symbol_table).unwrap(),
x_group_mul.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(
binary(decl_unit(), Slash, literal(3f64)),
Slash,
@ -540,7 +604,7 @@ mod tests {
)
);
assert_eq!(
x_group_div.invert(&mut symbol_table).unwrap(),
x_group_div.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(
binary(decl_unit(), Star, literal(3f64)),
Slash,

View File

@ -1,8 +1,9 @@
use crate::kalk_num::KalkNum;
use crate::{
ast::{Expr, Stmt},
interpreter,
interpreter, inverter,
lexer::{Lexer, Token, TokenKind},
prelude,
symbol_table::SymbolTable,
};
@ -31,6 +32,8 @@ pub struct Context {
/// whenever a unit in the expression is found. Eg. unit a = 3b, it will be set to Some("b")
unit_decl_base_unit: Option<String>,
parsing_identifier_stmt: bool,
equation_variable: Option<String>,
contains_equal_sign: bool,
}
impl Context {
@ -43,6 +46,8 @@ impl Context {
parsing_unit_decl: false,
unit_decl_base_unit: None,
parsing_identifier_stmt: false,
equation_variable: None,
contains_equal_sign: false,
};
parse(&mut context, crate::prelude::INIT).unwrap();
@ -74,6 +79,7 @@ pub enum CalcError {
UndefinedFn(String),
UndefinedVar(String),
UnableToInvert(String),
UnableToSolveEquation,
UnableToParseExpression,
Unknown,
}
@ -86,6 +92,7 @@ pub fn eval(
input: &str,
precision: u32,
) -> Result<Option<KalkNum>, CalcError> {
context.contains_equal_sign = input.contains("=");
let statements = parse(context, input)?;
let mut interpreter =
@ -134,13 +141,14 @@ fn parse_identifier_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
let primary = parse_primary(context)?; // Since function declarations and function calls look the same at first, simply parse a "function call", and re-use the data.
context.parsing_identifier_stmt = false;
// If `primary` is followed by an equal sign, it is a function declaration.
// If `primary` is followed by an equal sign and is not a prelude function,
// treat it as a function declaration
if let TokenKind::Equals = peek(context).kind {
advance(context);
let expr = parse_expr(context)?;
// Use the "function call" expression that was parsed, and put its values into a function declaration statement instead.
if let Expr::FnCall(identifier, parameters) = primary {
if !prelude::is_prelude_func(&identifier) {
let expr = parse_expr(context)?;
advance(context);
let mut parameter_identifiers = Vec::new();
// All the "arguments" are expected to be parsed as variables,
@ -152,7 +160,8 @@ fn parse_identifier_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
}
}
let fn_decl = Stmt::FnDecl(identifier.clone(), parameter_identifiers, Box::new(expr));
let fn_decl =
Stmt::FnDecl(identifier.clone(), parameter_identifiers, Box::new(expr));
// Insert the function declaration into the symbol table during parsing
// so that the parser can find out if particular functions exist.
@ -160,15 +169,14 @@ fn parse_identifier_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
return Ok(fn_decl);
}
}
}
Err(CalcError::Unknown)
} else {
// It is a function call or eg. x(x + 3), not a function declaration.
// Redo the parsing for this specific part.
context.pos = began_at;
Ok(Stmt::Expr(Box::new(parse_expr(context)?)))
}
}
fn parse_var_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
let identifier = advance(context).clone();
@ -201,7 +209,7 @@ fn parse_unit_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
let stmt_inv = Stmt::UnitDecl(
base_unit.clone(),
identifier.value.clone(),
Box::new(def.invert(&mut context.symbol_table)?),
Box::new(def.invert(&mut context.symbol_table, DECL_UNIT)?),
);
let stmt = Stmt::UnitDecl(identifier.value, base_unit, Box::new(def));
@ -212,17 +220,54 @@ fn parse_unit_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
}
fn parse_expr(context: &mut Context) -> Result<Expr, CalcError> {
Ok(parse_to(context)?)
Ok(parse_equation(context)?)
}
fn parse_equation(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_to(context)?;
if match_token(context, TokenKind::Equals) {
advance(context);
let right = parse_to(context)?;
let var_name = if let Some(var_name) = &context.equation_variable {
var_name
} else {
return Err(CalcError::UnableToSolveEquation);
};
let inverted = if inverter::contains_var(&mut context.symbol_table, &left, var_name) {
left.invert_to_target(&mut context.symbol_table, right, var_name)?
} else {
right.invert_to_target(&mut context.symbol_table, left, var_name)?
};
// If the inverted expression still contains the variable,
// the equation solving failed.
if inverter::contains_var(&mut context.symbol_table, &inverted, var_name) {
return Err(CalcError::UnableToSolveEquation);
}
context
.symbol_table
.insert(Stmt::VarDecl(var_name.into(), Box::new(inverted.clone())));
return Ok(inverted);
}
Ok(left)
}
fn parse_to(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_sum(context)?;
if match_token(context, TokenKind::ToKeyword) {
let op = advance(context).kind;
advance(context);
let right = Expr::Var(advance(context).value.clone()); // Parse this as a variable for now.
return Ok(Expr::Binary(Box::new(left), op, Box::new(right)));
return Ok(Expr::Binary(
Box::new(left),
TokenKind::ToKeyword,
Box::new(right),
));
}
Ok(left)
@ -367,7 +412,7 @@ fn parse_group_fn(context: &mut Context) -> Result<Expr, CalcError> {
TokenKind::Pipe => "abs",
TokenKind::OpenCeil => "ceil",
TokenKind::OpenFloor => "floor",
_ => panic!("Unexpected parsing error."),
_ => unreachable!(),
};
let expr = parse_expr(context)?;
@ -419,6 +464,15 @@ fn parse_identifier(context: &mut Context) -> Result<Expr, CalcError> {
context.unit_decl_base_unit = Some(identifier.value);
Ok(Expr::Var(DECL_UNIT.into()))
} else {
if let Some(equation_var) = &context.equation_variable {
if &identifier.value == equation_var {
return Ok(Expr::Var(identifier.value));
}
} else if context.contains_equal_sign {
context.equation_variable = Some(identifier.value.clone());
return Ok(Expr::Var(identifier.value));
}
let mut chars = identifier.value.chars();
let mut left = Expr::Var(chars.next().unwrap().to_string());

View File

@ -117,6 +117,13 @@ impl BinaryFuncInfo {
}
}
pub fn is_prelude_func(identifier: &str) -> bool {
identifier == "sum"
|| identifier == "Σ"
|| UNARY_FUNCS.contains_key(identifier)
|| BINARY_FUNCS.contains_key(identifier)
}
pub fn call_unary_func(
context: &mut interpreter::Context,
name: &str,

View File

@ -74,10 +74,7 @@ impl SymbolTable {
}
pub fn contains_fn(&self, identifier: &str) -> bool {
identifier == "sum"
|| identifier == "Σ"
|| prelude::UNARY_FUNCS.contains_key(identifier)
|| prelude::BINARY_FUNCS.contains_key(identifier)
prelude::is_prelude_func(identifier)
|| self.hashmap.contains_key(&format!("fn.{}", identifier))
}
}

View File

@ -42,6 +42,7 @@ fn print_calc_err(err: CalcError) {
UndefinedFn(name) => format!("Undefined function: '{}'.", name),
UndefinedVar(name) => format!("Undefined variable: '{}'.", name),
UnableToParseExpression => format!("Unable to parse expression."),
UnableToSolveEquation => format!("Unable to solve equation."),
Unknown => format!("Unknown error."),
});
}