Simple equation solving, mostly using pre-existing logic from the inverter

This commit is contained in:
PaddiM8 2020-12-14 19:21:30 +01:00
parent 34927585c5
commit 2f7af7de90
5 changed files with 217 additions and 94 deletions

View File

@ -1,7 +1,6 @@
use crate::ast::{Expr, Stmt}; use crate::ast::{Expr, Stmt};
use crate::lexer::TokenKind; use crate::lexer::TokenKind;
use crate::parser::CalcError; use crate::parser::CalcError;
use crate::parser::DECL_UNIT;
use crate::prelude; use crate::prelude;
use crate::symbol_table::SymbolTable; use crate::symbol_table::SymbolTable;
use lazy_static::lazy_static; use lazy_static::lazy_static;
@ -40,30 +39,51 @@ lazy_static! {
} }
impl Expr { impl Expr {
pub fn invert(&self, symbol_table: &mut SymbolTable) -> Result<Self, CalcError> { pub fn invert(
let target_expr = Expr::Var(DECL_UNIT.into()); &self,
let result = invert(target_expr, symbol_table, self); symbol_table: &mut SymbolTable,
unknown_var: &str,
) -> Result<Self, CalcError> {
let target_expr = Expr::Var(unknown_var.into());
let result = invert(target_expr, symbol_table, self, unknown_var);
Ok(result?.0) Ok(result?.0)
} }
pub fn invert_to_target(
&self,
symbol_table: &mut SymbolTable,
target_expr: Expr,
unknown_var: &str,
) -> Result<Self, CalcError> {
let x = invert(target_expr, symbol_table, self, unknown_var)?;
Ok(x.0)
}
} }
fn invert( fn invert(
target_expr: Expr, target_expr: Expr,
symbol_table: &mut SymbolTable, symbol_table: &mut SymbolTable,
expr: &Expr, expr: &Expr,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> { ) -> Result<(Expr, Expr), CalcError> {
match expr { match expr {
Expr::Binary(left, op, right) => { Expr::Binary(left, op, right) => {
invert_binary(target_expr, symbol_table, &left, op, &right) invert_binary(target_expr, symbol_table, &left, op, &right, unknown_var)
} }
Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr), Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr),
Expr::Unit(identifier, expr) => invert_unit(target_expr, &identifier, &expr), Expr::Unit(identifier, expr) => {
Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier), invert_unit(target_expr, symbol_table, &identifier, &expr, unknown_var)
Expr::Group(expr) => Ok((target_expr, *expr.clone())),
Expr::FnCall(identifier, arguments) => {
invert_fn_call(target_expr, symbol_table, &identifier, arguments)
} }
Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier, unknown_var),
Expr::Group(expr) => Ok((target_expr, *expr.clone())),
Expr::FnCall(identifier, arguments) => invert_fn_call(
target_expr,
symbol_table,
&identifier,
arguments,
unknown_var,
),
Expr::Literal(_) => Ok((target_expr, expr.clone())), Expr::Literal(_) => Ok((target_expr, expr.clone())),
} }
} }
@ -74,6 +94,7 @@ fn invert_binary(
left: &Expr, left: &Expr,
op: &TokenKind, op: &TokenKind,
right: &Expr, right: &Expr,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> { ) -> Result<(Expr, Expr), CalcError> {
let op_inv = match op { let op_inv = match op {
TokenKind::Plus => TokenKind::Minus, TokenKind::Plus => TokenKind::Minus,
@ -87,6 +108,7 @@ fn invert_binary(
left, left,
&TokenKind::Plus, &TokenKind::Plus,
&multiply_into(&Expr::Literal(-1f64), inside_group)?, &multiply_into(&Expr::Literal(-1f64), inside_group)?,
unknown_var,
); );
} }
@ -100,6 +122,7 @@ fn invert_binary(
target_expr, target_expr,
symbol_table, symbol_table,
&multiply_into(right, inside_group)?, &multiply_into(right, inside_group)?,
unknown_var,
); );
} }
@ -109,6 +132,7 @@ fn invert_binary(
target_expr, target_expr,
symbol_table, symbol_table,
&multiply_into(left, inside_group)?, &multiply_into(left, inside_group)?,
unknown_var,
); );
} }
@ -122,6 +146,7 @@ fn invert_binary(
target_expr, target_expr,
symbol_table, symbol_table,
&Expr::Binary(inside_group.clone(), op.clone(), Box::new(right.clone())), &Expr::Binary(inside_group.clone(), op.clone(), Box::new(right.clone())),
unknown_var,
); );
} }
@ -132,20 +157,38 @@ fn invert_binary(
target_expr, target_expr,
symbol_table, symbol_table,
&Expr::Binary(Box::new(left.clone()), op.clone(), inside_group.clone()), &Expr::Binary(Box::new(left.clone()), op.clone(), inside_group.clone()),
unknown_var,
); );
} }
TokenKind::Star TokenKind::Star
} }
_ => unreachable!(), TokenKind::Power => {
return if contains_var(symbol_table, left, unknown_var) {
invert(
Expr::FnCall("root".into(), vec![target_expr, right.clone()]),
symbol_table,
right,
unknown_var,
)
} else {
invert(
Expr::FnCall("log".into(), vec![target_expr, left.clone()]),
symbol_table,
right,
unknown_var,
)
};
}
_ => return Err(CalcError::UnableToInvert(String::new())),
}; };
// If the left expression contains the unit, invert the right one instead, // If the left expression contains the unit, invert the right one instead,
// since the unit should not be moved. // since the unit should not be moved.
if contains_the_unit(symbol_table, left) { if contains_var(symbol_table, left, unknown_var) {
// But if the right expression *also* contains the unit, // But if the right expression *also* contains the unit,
// throw an error, since it can't handle this yet. // throw an error, since it can't handle this yet.
if contains_the_unit(symbol_table, right) { if contains_var(symbol_table, right, unknown_var) {
return Err(CalcError::UnableToInvert(String::from( return Err(CalcError::UnableToInvert(String::from(
"Expressions with several instances of an unknown variable (this might be supported in the future). Try simplifying the expression.", "Expressions with several instances of an unknown variable (this might be supported in the future). Try simplifying the expression.",
))); )));
@ -155,6 +198,7 @@ fn invert_binary(
Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())), Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())),
symbol_table, symbol_table,
left, left,
unknown_var,
)?); )?);
} }
@ -171,6 +215,7 @@ fn invert_binary(
}, },
symbol_table, symbol_table,
right, // Then invert the right expression. right, // Then invert the right expression.
unknown_var,
)?) )?)
} }
@ -181,29 +226,35 @@ fn invert_unary(target_expr: Expr, op: &TokenKind, expr: &Expr) -> Result<(Expr,
Expr::Unary(TokenKind::Minus, Box::new(target_expr)), Expr::Unary(TokenKind::Minus, Box::new(target_expr)),
expr.clone(), // And then continue inverting the inner-expression. expr.clone(), // And then continue inverting the inner-expression.
)), )),
_ => unimplemented!(), _ => return Err(CalcError::UnableToInvert(String::new())),
} }
} }
fn invert_unit( fn invert_unit(
_target_expr: Expr, target_expr: Expr,
_identifier: &str, symbol_table: &mut SymbolTable,
_expr: &Expr, identifier: &str,
expr: &Expr,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> { ) -> Result<(Expr, Expr), CalcError> {
Err(CalcError::UnableToInvert(String::from( let x = Expr::Binary(
"Expressions containing other units (this should be supported in the future).", Box::new(target_expr),
))) TokenKind::ToKeyword,
Box::new(Expr::Var(identifier.into())),
);
invert(x, symbol_table, expr, unknown_var)
} }
fn invert_var( fn invert_var(
target_expr: Expr, target_expr: Expr,
symbol_table: &mut SymbolTable, symbol_table: &mut SymbolTable,
identifier: &str, identifier: &str,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> { ) -> Result<(Expr, Expr), CalcError> {
if identifier == DECL_UNIT { if identifier == unknown_var {
Ok((target_expr, Expr::Var(identifier.into()))) Ok((target_expr, Expr::Var(identifier.into())))
} else if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() { } else if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() {
invert(target_expr, symbol_table, &var_expr) invert(target_expr, symbol_table, &var_expr, unknown_var)
} else { } else {
Ok((target_expr, Expr::Var(identifier.into()))) Ok((target_expr, Expr::Var(identifier.into())))
} }
@ -214,27 +265,32 @@ fn invert_fn_call(
symbol_table: &mut SymbolTable, symbol_table: &mut SymbolTable,
identifier: &str, identifier: &str,
arguments: &Vec<Expr>, arguments: &Vec<Expr>,
unknown_var: &str,
) -> Result<(Expr, Expr), CalcError> { ) -> Result<(Expr, Expr), CalcError> {
// If prelude function // If prelude function
match arguments.len() { match arguments.len() {
1 => { 1 => {
if prelude::UNARY_FUNCS.contains_key(identifier) { if prelude::UNARY_FUNCS.contains_key(identifier) {
if let Some(fn_inv) = INVERSE_UNARY_FUNCS.get(identifier) { if let Some(fn_inv) = INVERSE_UNARY_FUNCS.get(identifier) {
return Ok(( return invert(
Expr::FnCall(fn_inv.to_string(), vec![target_expr]), Expr::FnCall(fn_inv.to_string(), vec![target_expr]),
arguments[0].clone(), symbol_table,
)); &arguments[0],
unknown_var,
);
} else { } else {
match identifier { match identifier {
"sqrt" => { "sqrt" => {
return Ok(( return invert(
Expr::Binary( Expr::Binary(
Box::new(target_expr), Box::new(target_expr),
TokenKind::Power, TokenKind::Power,
Box::new(Expr::Literal(2f64)), Box::new(Expr::Literal(2f64)),
), ),
arguments[0].clone(), symbol_table,
)); &arguments[0],
unknown_var,
);
} }
_ => { _ => {
return Err(CalcError::UnableToInvert(format!( return Err(CalcError::UnableToInvert(format!(
@ -284,29 +340,30 @@ fn invert_fn_call(
} }
// Invert everything in the function body. // Invert everything in the function body.
invert(target_expr, symbol_table, &body) invert(target_expr, symbol_table, &body, unknown_var)
} }
fn contains_the_unit(symbol_table: &SymbolTable, expr: &Expr) -> bool { pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) -> bool {
// Recursively scan the expression for the unit. // Recursively scan the expression for the variable.
match expr { match expr {
Expr::Binary(left, _, right) => { Expr::Binary(left, _, right) => {
contains_the_unit(symbol_table, left) || contains_the_unit(symbol_table, right) contains_var(symbol_table, left, var_name)
|| contains_var(symbol_table, right, var_name)
} }
Expr::Unary(_, expr) => contains_the_unit(symbol_table, expr), Expr::Unary(_, expr) => contains_var(symbol_table, expr, var_name),
Expr::Unit(_, expr) => contains_the_unit(symbol_table, expr), Expr::Unit(_, expr) => contains_var(symbol_table, expr, var_name),
Expr::Var(identifier) => { Expr::Var(identifier) => {
identifier == DECL_UNIT identifier == var_name
|| if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) { || if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) {
contains_the_unit(symbol_table, var_expr) contains_var(symbol_table, var_expr, var_name)
} else { } else {
false false
} }
} }
Expr::Group(expr) => contains_the_unit(symbol_table, expr), Expr::Group(expr) => contains_var(symbol_table, expr, var_name),
Expr::FnCall(_, args) => { Expr::FnCall(_, args) => {
for arg in args { for arg in args {
if contains_the_unit(symbol_table, arg) { if contains_var(symbol_table, arg, var_name) {
return true; return true;
} }
} }
@ -333,7 +390,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
op.clone(), op.clone(),
right.clone(), right.clone(),
)), )),
_ => unimplemented!(), _ => return Err(CalcError::UnableToInvert(String::new())),
}, },
// If it's a literal, just multiply them together. // If it's a literal, just multiply them together.
Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary( Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary(
@ -344,7 +401,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
Expr::Group(_) => Err(CalcError::UnableToInvert(String::from( Expr::Group(_) => Err(CalcError::UnableToInvert(String::from(
"Parenthesis multiplied with parenthesis (this should be possible in the future).", "Parenthesis multiplied with parenthesis (this should be possible in the future).",
))), ))),
_ => unimplemented!(), _ => return Err(CalcError::UnableToInvert(String::new())),
} }
} }
@ -352,6 +409,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result<Expr, CalcError> {
mod tests { mod tests {
use crate::ast::Expr; use crate::ast::Expr;
use crate::lexer::TokenKind::*; use crate::lexer::TokenKind::*;
use crate::parser::DECL_UNIT;
use crate::symbol_table::SymbolTable; use crate::symbol_table::SymbolTable;
use crate::test_helpers::*; use crate::test_helpers::*;
@ -373,36 +431,36 @@ mod tests {
let mut symbol_table = SymbolTable::new(); let mut symbol_table = SymbolTable::new();
assert_eq!( assert_eq!(
ladd.invert(&mut symbol_table).unwrap(), ladd.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Minus, literal(1f64)) *binary(decl_unit(), Minus, literal(1f64))
); );
assert_eq!( assert_eq!(
lsub.invert(&mut symbol_table).unwrap(), lsub.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Plus, literal(1f64)) *binary(decl_unit(), Plus, literal(1f64))
); );
assert_eq!( assert_eq!(
lmul.invert(&mut symbol_table).unwrap(), lmul.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Slash, literal(1f64)) *binary(decl_unit(), Slash, literal(1f64))
); );
assert_eq!( assert_eq!(
ldiv.invert(&mut symbol_table).unwrap(), ldiv.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Star, literal(1f64)) *binary(decl_unit(), Star, literal(1f64))
); );
assert_eq!( assert_eq!(
radd.invert(&mut symbol_table).unwrap(), radd.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Minus, literal(1f64)) *binary(decl_unit(), Minus, literal(1f64))
); );
assert_eq!( assert_eq!(
rsub.invert(&mut symbol_table).unwrap(), rsub.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*unary(Minus, binary(decl_unit(), Plus, literal(1f64))) *unary(Minus, binary(decl_unit(), Plus, literal(1f64)))
); );
assert_eq!( assert_eq!(
rmul.invert(&mut symbol_table).unwrap(), rmul.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Slash, literal(1f64)) *binary(decl_unit(), Slash, literal(1f64))
); );
assert_eq!( assert_eq!(
rdiv.invert(&mut symbol_table).unwrap(), rdiv.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary(decl_unit(), Star, literal(1f64)) *binary(decl_unit(), Star, literal(1f64))
); );
} }
@ -412,7 +470,7 @@ mod tests {
let neg = unary(Minus, decl_unit()); let neg = unary(Minus, decl_unit());
let mut symbol_table = SymbolTable::new(); let mut symbol_table = SymbolTable::new();
assert_eq!(neg.invert(&mut symbol_table).unwrap(), *neg); assert_eq!(neg.invert(&mut symbol_table, DECL_UNIT).unwrap(), *neg);
} }
#[test] #[test]
@ -430,16 +488,20 @@ mod tests {
let mut symbol_table = SymbolTable::new(); let mut symbol_table = SymbolTable::new();
symbol_table.insert(decl); symbol_table.insert(decl);
assert_eq!( assert_eq!(
call_with_literal.invert(&mut symbol_table).unwrap(), call_with_literal
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary(decl_unit(), Minus, fn_call("f", vec![*literal(2f64)])), *binary(decl_unit(), Minus, fn_call("f", vec![*literal(2f64)])),
); );
assert_eq!( assert_eq!(
call_with_decl_unit.invert(&mut symbol_table).unwrap(), call_with_decl_unit
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary(decl_unit(), Minus, literal(1f64)) *binary(decl_unit(), Minus, literal(1f64))
); );
assert_eq!( assert_eq!(
call_with_decl_unit_and_literal call_with_decl_unit_and_literal
.invert(&mut symbol_table) .invert(&mut symbol_table, DECL_UNIT)
.unwrap(), .unwrap(),
*binary( *binary(
binary(decl_unit(), Minus, literal(1f64)), binary(decl_unit(), Minus, literal(1f64)),
@ -484,7 +546,7 @@ mod tests {
let mut symbol_table = SymbolTable::new(); let mut symbol_table = SymbolTable::new();
assert_eq!( assert_eq!(
group_x.invert(&mut symbol_table).unwrap(), group_x.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary( *binary(
binary( binary(
decl_unit(), decl_unit(),
@ -496,7 +558,9 @@ mod tests {
) )
); );
assert_eq!( assert_eq!(
group_unary_minus.invert(&mut symbol_table).unwrap(), group_unary_minus
.invert(&mut symbol_table, DECL_UNIT)
.unwrap(),
*binary( *binary(
binary( binary(
binary(decl_unit(), Minus, literal(2f64)), binary(decl_unit(), Minus, literal(2f64)),
@ -508,7 +572,7 @@ mod tests {
) )
); );
assert_eq!( assert_eq!(
x_group_add.invert(&mut symbol_table).unwrap(), x_group_add.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary( *binary(
binary( binary(
decl_unit(), decl_unit(),
@ -520,7 +584,7 @@ mod tests {
) )
); );
assert_eq!( assert_eq!(
x_group_sub.invert(&mut symbol_table).unwrap(), x_group_sub.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary( *binary(
binary( binary(
decl_unit(), decl_unit(),
@ -532,7 +596,7 @@ mod tests {
) )
); );
assert_eq!( assert_eq!(
x_group_mul.invert(&mut symbol_table).unwrap(), x_group_mul.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary( *binary(
binary(decl_unit(), Slash, literal(3f64)), binary(decl_unit(), Slash, literal(3f64)),
Slash, Slash,
@ -540,7 +604,7 @@ mod tests {
) )
); );
assert_eq!( assert_eq!(
x_group_div.invert(&mut symbol_table).unwrap(), x_group_div.invert(&mut symbol_table, DECL_UNIT).unwrap(),
*binary( *binary(
binary(decl_unit(), Star, literal(3f64)), binary(decl_unit(), Star, literal(3f64)),
Slash, Slash,

View File

@ -1,8 +1,9 @@
use crate::kalk_num::KalkNum; use crate::kalk_num::KalkNum;
use crate::{ use crate::{
ast::{Expr, Stmt}, ast::{Expr, Stmt},
interpreter, interpreter, inverter,
lexer::{Lexer, Token, TokenKind}, lexer::{Lexer, Token, TokenKind},
prelude,
symbol_table::SymbolTable, symbol_table::SymbolTable,
}; };
@ -31,6 +32,8 @@ pub struct Context {
/// whenever a unit in the expression is found. Eg. unit a = 3b, it will be set to Some("b") /// whenever a unit in the expression is found. Eg. unit a = 3b, it will be set to Some("b")
unit_decl_base_unit: Option<String>, unit_decl_base_unit: Option<String>,
parsing_identifier_stmt: bool, parsing_identifier_stmt: bool,
equation_variable: Option<String>,
contains_equal_sign: bool,
} }
impl Context { impl Context {
@ -43,6 +46,8 @@ impl Context {
parsing_unit_decl: false, parsing_unit_decl: false,
unit_decl_base_unit: None, unit_decl_base_unit: None,
parsing_identifier_stmt: false, parsing_identifier_stmt: false,
equation_variable: None,
contains_equal_sign: false,
}; };
parse(&mut context, crate::prelude::INIT).unwrap(); parse(&mut context, crate::prelude::INIT).unwrap();
@ -74,6 +79,7 @@ pub enum CalcError {
UndefinedFn(String), UndefinedFn(String),
UndefinedVar(String), UndefinedVar(String),
UnableToInvert(String), UnableToInvert(String),
UnableToSolveEquation,
UnableToParseExpression, UnableToParseExpression,
Unknown, Unknown,
} }
@ -86,6 +92,7 @@ pub fn eval(
input: &str, input: &str,
precision: u32, precision: u32,
) -> Result<Option<KalkNum>, CalcError> { ) -> Result<Option<KalkNum>, CalcError> {
context.contains_equal_sign = input.contains("=");
let statements = parse(context, input)?; let statements = parse(context, input)?;
let mut interpreter = let mut interpreter =
@ -134,40 +141,41 @@ fn parse_identifier_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
let primary = parse_primary(context)?; // Since function declarations and function calls look the same at first, simply parse a "function call", and re-use the data. let primary = parse_primary(context)?; // Since function declarations and function calls look the same at first, simply parse a "function call", and re-use the data.
context.parsing_identifier_stmt = false; context.parsing_identifier_stmt = false;
// If `primary` is followed by an equal sign, it is a function declaration. // If `primary` is followed by an equal sign and is not a prelude function,
// treat it as a function declaration
if let TokenKind::Equals = peek(context).kind { if let TokenKind::Equals = peek(context).kind {
advance(context);
let expr = parse_expr(context)?;
// Use the "function call" expression that was parsed, and put its values into a function declaration statement instead. // Use the "function call" expression that was parsed, and put its values into a function declaration statement instead.
if let Expr::FnCall(identifier, parameters) = primary { if let Expr::FnCall(identifier, parameters) = primary {
let mut parameter_identifiers = Vec::new(); if !prelude::is_prelude_func(&identifier) {
let expr = parse_expr(context)?;
advance(context);
let mut parameter_identifiers = Vec::new();
// All the "arguments" are expected to be parsed as variables, // All the "arguments" are expected to be parsed as variables,
// since parameter definitions look the same as variable references. // since parameter definitions look the same as variable references.
// Extract these. // Extract these.
for parameter in parameters { for parameter in parameters {
if let Expr::Var(parameter_identifier) = parameter { if let Expr::Var(parameter_identifier) = parameter {
parameter_identifiers.push(parameter_identifier); parameter_identifiers.push(parameter_identifier);
}
} }
let fn_decl =
Stmt::FnDecl(identifier.clone(), parameter_identifiers, Box::new(expr));
// Insert the function declaration into the symbol table during parsing
// so that the parser can find out if particular functions exist.
context.symbol_table.insert(fn_decl.clone());
return Ok(fn_decl);
} }
let fn_decl = Stmt::FnDecl(identifier.clone(), parameter_identifiers, Box::new(expr));
// Insert the function declaration into the symbol table during parsing
// so that the parser can find out if particular functions exist.
context.symbol_table.insert(fn_decl.clone());
return Ok(fn_decl);
} }
Err(CalcError::Unknown)
} else {
// It is a function call or eg. x(x + 3), not a function declaration.
// Redo the parsing for this specific part.
context.pos = began_at;
Ok(Stmt::Expr(Box::new(parse_expr(context)?)))
} }
// It is a function call or eg. x(x + 3), not a function declaration.
// Redo the parsing for this specific part.
context.pos = began_at;
Ok(Stmt::Expr(Box::new(parse_expr(context)?)))
} }
fn parse_var_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> { fn parse_var_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
@ -201,7 +209,7 @@ fn parse_unit_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
let stmt_inv = Stmt::UnitDecl( let stmt_inv = Stmt::UnitDecl(
base_unit.clone(), base_unit.clone(),
identifier.value.clone(), identifier.value.clone(),
Box::new(def.invert(&mut context.symbol_table)?), Box::new(def.invert(&mut context.symbol_table, DECL_UNIT)?),
); );
let stmt = Stmt::UnitDecl(identifier.value, base_unit, Box::new(def)); let stmt = Stmt::UnitDecl(identifier.value, base_unit, Box::new(def));
@ -212,17 +220,54 @@ fn parse_unit_decl_stmt(context: &mut Context) -> Result<Stmt, CalcError> {
} }
fn parse_expr(context: &mut Context) -> Result<Expr, CalcError> { fn parse_expr(context: &mut Context) -> Result<Expr, CalcError> {
Ok(parse_to(context)?) Ok(parse_equation(context)?)
}
fn parse_equation(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_to(context)?;
if match_token(context, TokenKind::Equals) {
advance(context);
let right = parse_to(context)?;
let var_name = if let Some(var_name) = &context.equation_variable {
var_name
} else {
return Err(CalcError::UnableToSolveEquation);
};
let inverted = if inverter::contains_var(&mut context.symbol_table, &left, var_name) {
left.invert_to_target(&mut context.symbol_table, right, var_name)?
} else {
right.invert_to_target(&mut context.symbol_table, left, var_name)?
};
// If the inverted expression still contains the variable,
// the equation solving failed.
if inverter::contains_var(&mut context.symbol_table, &inverted, var_name) {
return Err(CalcError::UnableToSolveEquation);
}
context
.symbol_table
.insert(Stmt::VarDecl(var_name.into(), Box::new(inverted.clone())));
return Ok(inverted);
}
Ok(left)
} }
fn parse_to(context: &mut Context) -> Result<Expr, CalcError> { fn parse_to(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_sum(context)?; let left = parse_sum(context)?;
if match_token(context, TokenKind::ToKeyword) { if match_token(context, TokenKind::ToKeyword) {
let op = advance(context).kind; advance(context);
let right = Expr::Var(advance(context).value.clone()); // Parse this as a variable for now. let right = Expr::Var(advance(context).value.clone()); // Parse this as a variable for now.
return Ok(Expr::Binary(Box::new(left), op, Box::new(right))); return Ok(Expr::Binary(
Box::new(left),
TokenKind::ToKeyword,
Box::new(right),
));
} }
Ok(left) Ok(left)
@ -367,7 +412,7 @@ fn parse_group_fn(context: &mut Context) -> Result<Expr, CalcError> {
TokenKind::Pipe => "abs", TokenKind::Pipe => "abs",
TokenKind::OpenCeil => "ceil", TokenKind::OpenCeil => "ceil",
TokenKind::OpenFloor => "floor", TokenKind::OpenFloor => "floor",
_ => panic!("Unexpected parsing error."), _ => unreachable!(),
}; };
let expr = parse_expr(context)?; let expr = parse_expr(context)?;
@ -419,6 +464,15 @@ fn parse_identifier(context: &mut Context) -> Result<Expr, CalcError> {
context.unit_decl_base_unit = Some(identifier.value); context.unit_decl_base_unit = Some(identifier.value);
Ok(Expr::Var(DECL_UNIT.into())) Ok(Expr::Var(DECL_UNIT.into()))
} else { } else {
if let Some(equation_var) = &context.equation_variable {
if &identifier.value == equation_var {
return Ok(Expr::Var(identifier.value));
}
} else if context.contains_equal_sign {
context.equation_variable = Some(identifier.value.clone());
return Ok(Expr::Var(identifier.value));
}
let mut chars = identifier.value.chars(); let mut chars = identifier.value.chars();
let mut left = Expr::Var(chars.next().unwrap().to_string()); let mut left = Expr::Var(chars.next().unwrap().to_string());

View File

@ -117,6 +117,13 @@ impl BinaryFuncInfo {
} }
} }
pub fn is_prelude_func(identifier: &str) -> bool {
identifier == "sum"
|| identifier == "Σ"
|| UNARY_FUNCS.contains_key(identifier)
|| BINARY_FUNCS.contains_key(identifier)
}
pub fn call_unary_func( pub fn call_unary_func(
context: &mut interpreter::Context, context: &mut interpreter::Context,
name: &str, name: &str,

View File

@ -74,10 +74,7 @@ impl SymbolTable {
} }
pub fn contains_fn(&self, identifier: &str) -> bool { pub fn contains_fn(&self, identifier: &str) -> bool {
identifier == "sum" prelude::is_prelude_func(identifier)
|| identifier == "Σ"
|| prelude::UNARY_FUNCS.contains_key(identifier)
|| prelude::BINARY_FUNCS.contains_key(identifier)
|| self.hashmap.contains_key(&format!("fn.{}", identifier)) || self.hashmap.contains_key(&format!("fn.{}", identifier))
} }
} }

View File

@ -42,6 +42,7 @@ fn print_calc_err(err: CalcError) {
UndefinedFn(name) => format!("Undefined function: '{}'.", name), UndefinedFn(name) => format!("Undefined function: '{}'.", name),
UndefinedVar(name) => format!("Undefined variable: '{}'.", name), UndefinedVar(name) => format!("Undefined variable: '{}'.", name),
UnableToParseExpression => format!("Unable to parse expression."), UnableToParseExpression => format!("Unable to parse expression."),
UnableToSolveEquation => format!("Unable to solve equation."),
Unknown => format!("Unknown error."), Unknown => format!("Unknown error."),
}); });
} }