diff --git a/kalk/src/inverter.rs b/kalk/src/inverter.rs index fc56ae5..4c2dd37 100644 --- a/kalk/src/inverter.rs +++ b/kalk/src/inverter.rs @@ -1,7 +1,6 @@ use crate::ast::{Expr, Stmt}; use crate::lexer::TokenKind; use crate::parser::CalcError; -use crate::parser::DECL_UNIT; use crate::prelude; use crate::symbol_table::SymbolTable; use lazy_static::lazy_static; @@ -40,30 +39,51 @@ lazy_static! { } impl Expr { - pub fn invert(&self, symbol_table: &mut SymbolTable) -> Result { - let target_expr = Expr::Var(DECL_UNIT.into()); - let result = invert(target_expr, symbol_table, self); + pub fn invert( + &self, + symbol_table: &mut SymbolTable, + unknown_var: &str, + ) -> Result { + let target_expr = Expr::Var(unknown_var.into()); + let result = invert(target_expr, symbol_table, self, unknown_var); Ok(result?.0) } + + pub fn invert_to_target( + &self, + symbol_table: &mut SymbolTable, + target_expr: Expr, + unknown_var: &str, + ) -> Result { + let x = invert(target_expr, symbol_table, self, unknown_var)?; + Ok(x.0) + } } fn invert( target_expr: Expr, symbol_table: &mut SymbolTable, expr: &Expr, + unknown_var: &str, ) -> Result<(Expr, Expr), CalcError> { match expr { Expr::Binary(left, op, right) => { - invert_binary(target_expr, symbol_table, &left, op, &right) + invert_binary(target_expr, symbol_table, &left, op, &right, unknown_var) } Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr), - Expr::Unit(identifier, expr) => invert_unit(target_expr, &identifier, &expr), - Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier), - Expr::Group(expr) => Ok((target_expr, *expr.clone())), - Expr::FnCall(identifier, arguments) => { - invert_fn_call(target_expr, symbol_table, &identifier, arguments) + Expr::Unit(identifier, expr) => { + invert_unit(target_expr, symbol_table, &identifier, &expr, unknown_var) } + Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier, unknown_var), + Expr::Group(expr) => Ok((target_expr, *expr.clone())), + Expr::FnCall(identifier, arguments) => invert_fn_call( + target_expr, + symbol_table, + &identifier, + arguments, + unknown_var, + ), Expr::Literal(_) => Ok((target_expr, expr.clone())), } } @@ -74,6 +94,7 @@ fn invert_binary( left: &Expr, op: &TokenKind, right: &Expr, + unknown_var: &str, ) -> Result<(Expr, Expr), CalcError> { let op_inv = match op { TokenKind::Plus => TokenKind::Minus, @@ -87,6 +108,7 @@ fn invert_binary( left, &TokenKind::Plus, &multiply_into(&Expr::Literal(-1f64), inside_group)?, + unknown_var, ); } @@ -100,6 +122,7 @@ fn invert_binary( target_expr, symbol_table, &multiply_into(right, inside_group)?, + unknown_var, ); } @@ -109,6 +132,7 @@ fn invert_binary( target_expr, symbol_table, &multiply_into(left, inside_group)?, + unknown_var, ); } @@ -122,6 +146,7 @@ fn invert_binary( target_expr, symbol_table, &Expr::Binary(inside_group.clone(), op.clone(), Box::new(right.clone())), + unknown_var, ); } @@ -132,20 +157,38 @@ fn invert_binary( target_expr, symbol_table, &Expr::Binary(Box::new(left.clone()), op.clone(), inside_group.clone()), + unknown_var, ); } TokenKind::Star } - _ => unreachable!(), + TokenKind::Power => { + return if contains_var(symbol_table, left, unknown_var) { + invert( + Expr::FnCall("root".into(), vec![target_expr, right.clone()]), + symbol_table, + right, + unknown_var, + ) + } else { + invert( + Expr::FnCall("log".into(), vec![target_expr, left.clone()]), + symbol_table, + right, + unknown_var, + ) + }; + } + _ => return Err(CalcError::UnableToInvert(String::new())), }; // If the left expression contains the unit, invert the right one instead, // since the unit should not be moved. - if contains_the_unit(symbol_table, left) { + if contains_var(symbol_table, left, unknown_var) { // But if the right expression *also* contains the unit, // throw an error, since it can't handle this yet. - if contains_the_unit(symbol_table, right) { + if contains_var(symbol_table, right, unknown_var) { return Err(CalcError::UnableToInvert(String::from( "Expressions with several instances of an unknown variable (this might be supported in the future). Try simplifying the expression.", ))); @@ -155,6 +198,7 @@ fn invert_binary( Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())), symbol_table, left, + unknown_var, )?); } @@ -171,6 +215,7 @@ fn invert_binary( }, symbol_table, right, // Then invert the right expression. + unknown_var, )?) } @@ -181,29 +226,35 @@ fn invert_unary(target_expr: Expr, op: &TokenKind, expr: &Expr) -> Result<(Expr, Expr::Unary(TokenKind::Minus, Box::new(target_expr)), expr.clone(), // And then continue inverting the inner-expression. )), - _ => unimplemented!(), + _ => return Err(CalcError::UnableToInvert(String::new())), } } fn invert_unit( - _target_expr: Expr, - _identifier: &str, - _expr: &Expr, + target_expr: Expr, + symbol_table: &mut SymbolTable, + identifier: &str, + expr: &Expr, + unknown_var: &str, ) -> Result<(Expr, Expr), CalcError> { - Err(CalcError::UnableToInvert(String::from( - "Expressions containing other units (this should be supported in the future).", - ))) + let x = Expr::Binary( + Box::new(target_expr), + TokenKind::ToKeyword, + Box::new(Expr::Var(identifier.into())), + ); + invert(x, symbol_table, expr, unknown_var) } fn invert_var( target_expr: Expr, symbol_table: &mut SymbolTable, identifier: &str, + unknown_var: &str, ) -> Result<(Expr, Expr), CalcError> { - if identifier == DECL_UNIT { + if identifier == unknown_var { Ok((target_expr, Expr::Var(identifier.into()))) } else if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() { - invert(target_expr, symbol_table, &var_expr) + invert(target_expr, symbol_table, &var_expr, unknown_var) } else { Ok((target_expr, Expr::Var(identifier.into()))) } @@ -214,27 +265,32 @@ fn invert_fn_call( symbol_table: &mut SymbolTable, identifier: &str, arguments: &Vec, + unknown_var: &str, ) -> Result<(Expr, Expr), CalcError> { // If prelude function match arguments.len() { 1 => { if prelude::UNARY_FUNCS.contains_key(identifier) { if let Some(fn_inv) = INVERSE_UNARY_FUNCS.get(identifier) { - return Ok(( + return invert( Expr::FnCall(fn_inv.to_string(), vec![target_expr]), - arguments[0].clone(), - )); + symbol_table, + &arguments[0], + unknown_var, + ); } else { match identifier { "sqrt" => { - return Ok(( + return invert( Expr::Binary( Box::new(target_expr), TokenKind::Power, Box::new(Expr::Literal(2f64)), ), - arguments[0].clone(), - )); + symbol_table, + &arguments[0], + unknown_var, + ); } _ => { return Err(CalcError::UnableToInvert(format!( @@ -284,29 +340,30 @@ fn invert_fn_call( } // Invert everything in the function body. - invert(target_expr, symbol_table, &body) + invert(target_expr, symbol_table, &body, unknown_var) } -fn contains_the_unit(symbol_table: &SymbolTable, expr: &Expr) -> bool { - // Recursively scan the expression for the unit. +pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) -> bool { + // Recursively scan the expression for the variable. match expr { Expr::Binary(left, _, right) => { - contains_the_unit(symbol_table, left) || contains_the_unit(symbol_table, right) + contains_var(symbol_table, left, var_name) + || contains_var(symbol_table, right, var_name) } - Expr::Unary(_, expr) => contains_the_unit(symbol_table, expr), - Expr::Unit(_, expr) => contains_the_unit(symbol_table, expr), + Expr::Unary(_, expr) => contains_var(symbol_table, expr, var_name), + Expr::Unit(_, expr) => contains_var(symbol_table, expr, var_name), Expr::Var(identifier) => { - identifier == DECL_UNIT + identifier == var_name || if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) { - contains_the_unit(symbol_table, var_expr) + contains_var(symbol_table, var_expr, var_name) } else { false } } - Expr::Group(expr) => contains_the_unit(symbol_table, expr), + Expr::Group(expr) => contains_var(symbol_table, expr, var_name), Expr::FnCall(_, args) => { for arg in args { - if contains_the_unit(symbol_table, arg) { + if contains_var(symbol_table, arg, var_name) { return true; } } @@ -333,7 +390,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result { op.clone(), right.clone(), )), - _ => unimplemented!(), + _ => return Err(CalcError::UnableToInvert(String::new())), }, // If it's a literal, just multiply them together. Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary( @@ -344,7 +401,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result { Expr::Group(_) => Err(CalcError::UnableToInvert(String::from( "Parenthesis multiplied with parenthesis (this should be possible in the future).", ))), - _ => unimplemented!(), + _ => return Err(CalcError::UnableToInvert(String::new())), } } @@ -352,6 +409,7 @@ fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result { mod tests { use crate::ast::Expr; use crate::lexer::TokenKind::*; + use crate::parser::DECL_UNIT; use crate::symbol_table::SymbolTable; use crate::test_helpers::*; @@ -373,36 +431,36 @@ mod tests { let mut symbol_table = SymbolTable::new(); assert_eq!( - ladd.invert(&mut symbol_table).unwrap(), + ladd.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Minus, literal(1f64)) ); assert_eq!( - lsub.invert(&mut symbol_table).unwrap(), + lsub.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Plus, literal(1f64)) ); assert_eq!( - lmul.invert(&mut symbol_table).unwrap(), + lmul.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Slash, literal(1f64)) ); assert_eq!( - ldiv.invert(&mut symbol_table).unwrap(), + ldiv.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Star, literal(1f64)) ); assert_eq!( - radd.invert(&mut symbol_table).unwrap(), + radd.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Minus, literal(1f64)) ); assert_eq!( - rsub.invert(&mut symbol_table).unwrap(), + rsub.invert(&mut symbol_table, DECL_UNIT).unwrap(), *unary(Minus, binary(decl_unit(), Plus, literal(1f64))) ); assert_eq!( - rmul.invert(&mut symbol_table).unwrap(), + rmul.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Slash, literal(1f64)) ); assert_eq!( - rdiv.invert(&mut symbol_table).unwrap(), + rdiv.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary(decl_unit(), Star, literal(1f64)) ); } @@ -412,7 +470,7 @@ mod tests { let neg = unary(Minus, decl_unit()); let mut symbol_table = SymbolTable::new(); - assert_eq!(neg.invert(&mut symbol_table).unwrap(), *neg); + assert_eq!(neg.invert(&mut symbol_table, DECL_UNIT).unwrap(), *neg); } #[test] @@ -430,16 +488,20 @@ mod tests { let mut symbol_table = SymbolTable::new(); symbol_table.insert(decl); assert_eq!( - call_with_literal.invert(&mut symbol_table).unwrap(), + call_with_literal + .invert(&mut symbol_table, DECL_UNIT) + .unwrap(), *binary(decl_unit(), Minus, fn_call("f", vec![*literal(2f64)])), ); assert_eq!( - call_with_decl_unit.invert(&mut symbol_table).unwrap(), + call_with_decl_unit + .invert(&mut symbol_table, DECL_UNIT) + .unwrap(), *binary(decl_unit(), Minus, literal(1f64)) ); assert_eq!( call_with_decl_unit_and_literal - .invert(&mut symbol_table) + .invert(&mut symbol_table, DECL_UNIT) .unwrap(), *binary( binary(decl_unit(), Minus, literal(1f64)), @@ -484,7 +546,7 @@ mod tests { let mut symbol_table = SymbolTable::new(); assert_eq!( - group_x.invert(&mut symbol_table).unwrap(), + group_x.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary( binary( decl_unit(), @@ -496,7 +558,9 @@ mod tests { ) ); assert_eq!( - group_unary_minus.invert(&mut symbol_table).unwrap(), + group_unary_minus + .invert(&mut symbol_table, DECL_UNIT) + .unwrap(), *binary( binary( binary(decl_unit(), Minus, literal(2f64)), @@ -508,7 +572,7 @@ mod tests { ) ); assert_eq!( - x_group_add.invert(&mut symbol_table).unwrap(), + x_group_add.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary( binary( decl_unit(), @@ -520,7 +584,7 @@ mod tests { ) ); assert_eq!( - x_group_sub.invert(&mut symbol_table).unwrap(), + x_group_sub.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary( binary( decl_unit(), @@ -532,7 +596,7 @@ mod tests { ) ); assert_eq!( - x_group_mul.invert(&mut symbol_table).unwrap(), + x_group_mul.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary( binary(decl_unit(), Slash, literal(3f64)), Slash, @@ -540,7 +604,7 @@ mod tests { ) ); assert_eq!( - x_group_div.invert(&mut symbol_table).unwrap(), + x_group_div.invert(&mut symbol_table, DECL_UNIT).unwrap(), *binary( binary(decl_unit(), Star, literal(3f64)), Slash, diff --git a/kalk/src/parser.rs b/kalk/src/parser.rs index 7f474b4..cf38dfd 100644 --- a/kalk/src/parser.rs +++ b/kalk/src/parser.rs @@ -1,8 +1,9 @@ use crate::kalk_num::KalkNum; use crate::{ ast::{Expr, Stmt}, - interpreter, + interpreter, inverter, lexer::{Lexer, Token, TokenKind}, + prelude, symbol_table::SymbolTable, }; @@ -31,6 +32,8 @@ pub struct Context { /// whenever a unit in the expression is found. Eg. unit a = 3b, it will be set to Some("b") unit_decl_base_unit: Option, parsing_identifier_stmt: bool, + equation_variable: Option, + contains_equal_sign: bool, } impl Context { @@ -43,6 +46,8 @@ impl Context { parsing_unit_decl: false, unit_decl_base_unit: None, parsing_identifier_stmt: false, + equation_variable: None, + contains_equal_sign: false, }; parse(&mut context, crate::prelude::INIT).unwrap(); @@ -74,6 +79,7 @@ pub enum CalcError { UndefinedFn(String), UndefinedVar(String), UnableToInvert(String), + UnableToSolveEquation, UnableToParseExpression, Unknown, } @@ -86,6 +92,7 @@ pub fn eval( input: &str, precision: u32, ) -> Result, CalcError> { + context.contains_equal_sign = input.contains("="); let statements = parse(context, input)?; let mut interpreter = @@ -134,40 +141,41 @@ fn parse_identifier_stmt(context: &mut Context) -> Result { let primary = parse_primary(context)?; // Since function declarations and function calls look the same at first, simply parse a "function call", and re-use the data. context.parsing_identifier_stmt = false; - // If `primary` is followed by an equal sign, it is a function declaration. + // If `primary` is followed by an equal sign and is not a prelude function, + // treat it as a function declaration if let TokenKind::Equals = peek(context).kind { - advance(context); - let expr = parse_expr(context)?; - // Use the "function call" expression that was parsed, and put its values into a function declaration statement instead. if let Expr::FnCall(identifier, parameters) = primary { - let mut parameter_identifiers = Vec::new(); + if !prelude::is_prelude_func(&identifier) { + let expr = parse_expr(context)?; + advance(context); + let mut parameter_identifiers = Vec::new(); - // All the "arguments" are expected to be parsed as variables, - // since parameter definitions look the same as variable references. - // Extract these. - for parameter in parameters { - if let Expr::Var(parameter_identifier) = parameter { - parameter_identifiers.push(parameter_identifier); + // All the "arguments" are expected to be parsed as variables, + // since parameter definitions look the same as variable references. + // Extract these. + for parameter in parameters { + if let Expr::Var(parameter_identifier) = parameter { + parameter_identifiers.push(parameter_identifier); + } } + + let fn_decl = + Stmt::FnDecl(identifier.clone(), parameter_identifiers, Box::new(expr)); + + // Insert the function declaration into the symbol table during parsing + // so that the parser can find out if particular functions exist. + context.symbol_table.insert(fn_decl.clone()); + + return Ok(fn_decl); } - - let fn_decl = Stmt::FnDecl(identifier.clone(), parameter_identifiers, Box::new(expr)); - - // Insert the function declaration into the symbol table during parsing - // so that the parser can find out if particular functions exist. - context.symbol_table.insert(fn_decl.clone()); - - return Ok(fn_decl); } - - Err(CalcError::Unknown) - } else { - // It is a function call or eg. x(x + 3), not a function declaration. - // Redo the parsing for this specific part. - context.pos = began_at; - Ok(Stmt::Expr(Box::new(parse_expr(context)?))) } + + // It is a function call or eg. x(x + 3), not a function declaration. + // Redo the parsing for this specific part. + context.pos = began_at; + Ok(Stmt::Expr(Box::new(parse_expr(context)?))) } fn parse_var_decl_stmt(context: &mut Context) -> Result { @@ -201,7 +209,7 @@ fn parse_unit_decl_stmt(context: &mut Context) -> Result { let stmt_inv = Stmt::UnitDecl( base_unit.clone(), identifier.value.clone(), - Box::new(def.invert(&mut context.symbol_table)?), + Box::new(def.invert(&mut context.symbol_table, DECL_UNIT)?), ); let stmt = Stmt::UnitDecl(identifier.value, base_unit, Box::new(def)); @@ -212,17 +220,54 @@ fn parse_unit_decl_stmt(context: &mut Context) -> Result { } fn parse_expr(context: &mut Context) -> Result { - Ok(parse_to(context)?) + Ok(parse_equation(context)?) +} + +fn parse_equation(context: &mut Context) -> Result { + let left = parse_to(context)?; + + if match_token(context, TokenKind::Equals) { + advance(context); + let right = parse_to(context)?; + let var_name = if let Some(var_name) = &context.equation_variable { + var_name + } else { + return Err(CalcError::UnableToSolveEquation); + }; + + let inverted = if inverter::contains_var(&mut context.symbol_table, &left, var_name) { + left.invert_to_target(&mut context.symbol_table, right, var_name)? + } else { + right.invert_to_target(&mut context.symbol_table, left, var_name)? + }; + + // If the inverted expression still contains the variable, + // the equation solving failed. + if inverter::contains_var(&mut context.symbol_table, &inverted, var_name) { + return Err(CalcError::UnableToSolveEquation); + } + + context + .symbol_table + .insert(Stmt::VarDecl(var_name.into(), Box::new(inverted.clone()))); + return Ok(inverted); + } + + Ok(left) } fn parse_to(context: &mut Context) -> Result { let left = parse_sum(context)?; if match_token(context, TokenKind::ToKeyword) { - let op = advance(context).kind; + advance(context); let right = Expr::Var(advance(context).value.clone()); // Parse this as a variable for now. - return Ok(Expr::Binary(Box::new(left), op, Box::new(right))); + return Ok(Expr::Binary( + Box::new(left), + TokenKind::ToKeyword, + Box::new(right), + )); } Ok(left) @@ -367,7 +412,7 @@ fn parse_group_fn(context: &mut Context) -> Result { TokenKind::Pipe => "abs", TokenKind::OpenCeil => "ceil", TokenKind::OpenFloor => "floor", - _ => panic!("Unexpected parsing error."), + _ => unreachable!(), }; let expr = parse_expr(context)?; @@ -419,6 +464,15 @@ fn parse_identifier(context: &mut Context) -> Result { context.unit_decl_base_unit = Some(identifier.value); Ok(Expr::Var(DECL_UNIT.into())) } else { + if let Some(equation_var) = &context.equation_variable { + if &identifier.value == equation_var { + return Ok(Expr::Var(identifier.value)); + } + } else if context.contains_equal_sign { + context.equation_variable = Some(identifier.value.clone()); + return Ok(Expr::Var(identifier.value)); + } + let mut chars = identifier.value.chars(); let mut left = Expr::Var(chars.next().unwrap().to_string()); diff --git a/kalk/src/prelude.rs b/kalk/src/prelude.rs index e84e060..fd18eef 100644 --- a/kalk/src/prelude.rs +++ b/kalk/src/prelude.rs @@ -117,6 +117,13 @@ impl BinaryFuncInfo { } } +pub fn is_prelude_func(identifier: &str) -> bool { + identifier == "sum" + || identifier == "Σ" + || UNARY_FUNCS.contains_key(identifier) + || BINARY_FUNCS.contains_key(identifier) +} + pub fn call_unary_func( context: &mut interpreter::Context, name: &str, diff --git a/kalk/src/symbol_table.rs b/kalk/src/symbol_table.rs index abd27ab..7bb1e7e 100644 --- a/kalk/src/symbol_table.rs +++ b/kalk/src/symbol_table.rs @@ -74,10 +74,7 @@ impl SymbolTable { } pub fn contains_fn(&self, identifier: &str) -> bool { - identifier == "sum" - || identifier == "Σ" - || prelude::UNARY_FUNCS.contains_key(identifier) - || prelude::BINARY_FUNCS.contains_key(identifier) + prelude::is_prelude_func(identifier) || self.hashmap.contains_key(&format!("fn.{}", identifier)) } } diff --git a/kalk_cli/src/output.rs b/kalk_cli/src/output.rs index dc7509d..dd685ea 100644 --- a/kalk_cli/src/output.rs +++ b/kalk_cli/src/output.rs @@ -42,6 +42,7 @@ fn print_calc_err(err: CalcError) { UndefinedFn(name) => format!("Undefined function: '{}'.", name), UndefinedVar(name) => format!("Undefined variable: '{}'.", name), UnableToParseExpression => format!("Unable to parse expression."), + UnableToSolveEquation => format!("Unable to solve equation."), Unknown => format!("Unknown error."), }); }