From 4d6ef5e8d7b7a8493819907d16b96411dd750827 Mon Sep 17 00:00:00 2001 From: bakk Date: Sun, 24 Apr 2022 21:23:29 +0200 Subject: [PATCH] Numerical equation solving using Newton's method --- kalk/src/analysis.rs | 158 +++++++++---------------- kalk/src/ast.rs | 1 + kalk/src/calculation_result.rs | 2 +- kalk/src/integration_testing.rs | 1 + kalk/src/interpreter.rs | 25 +++- kalk/src/inverter.rs | 2 + kalk/src/lib.rs | 2 +- kalk/src/{calculus.rs => numerical.rs} | 83 ++++++++++++- kalk/src/parser.rs | 66 ++--------- tests/equations.kalker | 1 + tests/functions.kalker | 2 +- 11 files changed, 181 insertions(+), 162 deletions(-) rename kalk/src/{calculus.rs => numerical.rs} (78%) create mode 100644 tests/equations.kalker diff --git a/kalk/src/analysis.rs b/kalk/src/analysis.rs index 242cdff..1be271d 100644 --- a/kalk/src/analysis.rs +++ b/kalk/src/analysis.rs @@ -80,11 +80,14 @@ pub(crate) fn analyse_stmt( fn analyse_stmt_expr(context: &mut Context, value: Expr) -> Result { Ok( if let Expr::Binary(left, TokenKind::Equals, right) = value { + if let Some((identifier, parameters)) = is_fn_decl(&*left) { + return build_fn_decl_from_scratch(context, identifier, parameters, *right); + } + match *left { - Expr::Binary(identifier_expr, TokenKind::Star, parameter_expr) => { - build_fn_decl_from_scratch(context, *identifier_expr, *parameter_expr, *right)? - } - Expr::FnCall(identifier, arguments) => { + Expr::FnCall(identifier, arguments) + if !prelude::is_prelude_func(&identifier.full_name) => + { // First loop through with a reference // to arguments, to be able to back-track if // one of the arguments can't be made into a parameter. @@ -138,11 +141,10 @@ fn analyse_stmt_expr(context: &mut Context, value: Expr) -> Result Stmt::Expr(Box::new(Expr::Binary( - Box::new(analyse_expr(context, *left)?), - TokenKind::Equals, - right, - ))), + _ => Stmt::Expr(Box::new(analyse_expr( + context, + Expr::Binary(left, TokenKind::Equals, right), + )?)), } } else { Stmt::Expr(Box::new(analyse_expr(context, value)?)) @@ -150,83 +152,56 @@ fn analyse_stmt_expr(context: &mut Context, value: Expr) -> Result Option<(Identifier, Vec)> { + if let Expr::Binary(left, TokenKind::Star, right) = &*expr { + let identifier = if let Expr::Var(identifier) = &**left { + identifier + } else { + return None; + }; + + let exprs = match &**right { + Expr::Vector(exprs) => exprs.iter().collect(), + Expr::Group(expr) => vec![&**expr], + _ => return None, + }; + + let mut parameters = Vec::new(); + for expr in exprs { + if let Expr::Var(argument_identifier) = expr { + parameters.push(format!( + "{}-{}", + identifier.pure_name, argument_identifier.pure_name + )); + } + } + + if !prelude::is_prelude_func(&identifier.full_name) { + return Some((identifier.clone(), parameters)); + } + } + + None +} + fn build_fn_decl_from_scratch( context: &mut Context, - identifier_expr: Expr, - parameter_expr: Expr, + identifier: Identifier, + parameters: Vec, right: Expr, ) -> Result { - Ok(match identifier_expr { - Expr::Var(identifier) if !prelude::is_prelude_func(&identifier.full_name) => { - // Check if all the expressions in the parameter_expr are - // variables. If not, it can't be turned into a function declaration. - let all_are_vars = match ¶meter_expr { - Expr::Vector(exprs) => exprs.iter().any(|x| matches!(x, Expr::Var(_))), - Expr::Group(expr) => { - matches!(&**expr, Expr::Var(_)) - } - _ => false, - }; + context.current_function_name = Some(identifier.pure_name.clone()); + context.current_function_parameters = Some(parameters.clone()); + let fn_decl = Stmt::FnDecl( + identifier, + parameters, + Box::new(analyse_expr(context, right)?), + ); + context.symbol_table.insert(fn_decl.clone()); + context.current_function_name = None; + context.current_function_parameters = None; - if !all_are_vars { - // Analyse it as a function call instead - return Ok(Stmt::Expr(Box::new(analyse_expr( - context, - Expr::Binary( - Box::new(Expr::Binary( - Box::new(Expr::Var(identifier)), - TokenKind::Star, - Box::new(parameter_expr), - )), - TokenKind::Equals, - Box::new(right), - ), - )?))); - } - - let exprs = match parameter_expr { - Expr::Vector(exprs) => exprs, - Expr::Group(expr) => vec![*expr], - _ => unreachable!(), - }; - - let mut parameters = Vec::new(); - for expr in exprs { - if let Expr::Var(argument_identifier) = expr { - parameters.push(format!( - "{}-{}", - identifier.pure_name, argument_identifier.pure_name - )); - } - } - - context.current_function_name = Some(identifier.pure_name.clone()); - context.current_function_parameters = Some(parameters.clone()); - let fn_decl = Stmt::FnDecl( - identifier, - parameters, - Box::new(analyse_expr(context, right)?), - ); - context.symbol_table.insert(fn_decl.clone()); - context.current_function_name = None; - context.current_function_parameters = None; - - fn_decl - } - _ => { - let new_binary = Expr::Binary( - Box::new(Expr::Binary( - Box::new(identifier_expr), - TokenKind::Star, - Box::new(parameter_expr), - )), - TokenKind::Equals, - Box::new(right), - ); - - Stmt::Expr(Box::new(analyse_expr(context, new_binary)?)) - } - }) + Ok(fn_decl) } fn analyse_expr(context: &mut Context, expr: Expr) -> Result { @@ -283,6 +258,7 @@ fn analyse_expr(context: &mut Context, expr: Expr) -> Result { Expr::Indexer(Box::new(analyse_expr(context, *value)?), analysed_indexes) } Expr::Comprehension(left, right, vars) => Expr::Comprehension(left, right, vars), + Expr::Equation(left, right, identifier) => Expr::Equation(left, right, identifier), }) } @@ -325,26 +301,10 @@ fn analyse_binary( return result; }; - - let inverted = if inverter::contains_var(context.symbol_table, &left, var_name) { - left.invert_to_target(context.symbol_table, right, var_name)? - } else { - right.invert_to_target(context.symbol_table, left, var_name)? - }; - - // If the inverted expression still contains the variable, - // the equation solving failed. - if inverter::contains_var(context.symbol_table, &inverted, var_name) { - return Err(KalkError::UnableToSolveEquation); - } - - context.symbol_table.insert(Stmt::VarDecl( - Identifier::from_full_name(var_name), - Box::new(inverted.clone()), - )); + let identifier = Identifier::from_full_name(var_name); context.equation_variable = None; - Ok(inverted) + Ok(Expr::Equation(Box::new(left), Box::new(right), identifier)) } (Expr::Var(_), TokenKind::Star, _) => { if let Expr::Var(identifier) = left { diff --git a/kalk/src/ast.rs b/kalk/src/ast.rs index 926664a..2e13a42 100644 --- a/kalk/src/ast.rs +++ b/kalk/src/ast.rs @@ -25,6 +25,7 @@ pub enum Expr { Matrix(Vec>), Indexer(Box, Vec), Comprehension(Box, Vec, Vec), + Equation(Box, Box, Identifier), } #[derive(Debug, Clone, PartialEq)] diff --git a/kalk/src/calculation_result.rs b/kalk/src/calculation_result.rs index e5d1ce9..81650f2 100644 --- a/kalk/src/calculation_result.rs +++ b/kalk/src/calculation_result.rs @@ -76,4 +76,4 @@ impl std::fmt::Display for CalculationResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.value) } -} \ No newline at end of file +} diff --git a/kalk/src/integration_testing.rs b/kalk/src/integration_testing.rs index 4cc2fdb..57ca9c9 100644 --- a/kalk/src/integration_testing.rs +++ b/kalk/src/integration_testing.rs @@ -44,6 +44,7 @@ mod tests { #[test_case("basics")] #[test_case("comparisons")] #[test_case("comprehensions")] + #[test_case("equations")] #[test_case("derivation")] #[test_case("functions")] #[test_case("groups")] diff --git a/kalk/src/interpreter.rs b/kalk/src/interpreter.rs index 276aa51..ad739fa 100644 --- a/kalk/src/interpreter.rs +++ b/kalk/src/interpreter.rs @@ -6,7 +6,7 @@ use crate::kalk_value::KalkValue; use crate::lexer::TokenKind; use crate::parser::DECL_UNIT; use crate::symbol_table::SymbolTable; -use crate::{as_number_or_zero, calculus}; +use crate::{as_number_or_zero, numerical}; use crate::{float, prelude}; pub struct Context<'a> { @@ -135,6 +135,7 @@ pub(crate) fn eval_expr( Expr::Comprehension(left, conditions, vars) => Ok(KalkValue::Vector(eval_comprehension( context, left, conditions, vars, )?)), + Expr::Equation(left, right, identifier) => eval_equation(context, left, right, identifier), } } @@ -355,13 +356,13 @@ pub(crate) fn eval_fn_call_expr( } "integrate" => { return match expressions.len() { - 3 => calculus::integrate_with_unknown_variable( + 3 => numerical::integrate_with_unknown_variable( context, &expressions[0], &expressions[1], &expressions[2], ), - 4 => calculus::integrate( + 4 => numerical::integrate( context, &expressions[0], &expressions[1], @@ -406,7 +407,7 @@ pub(crate) fn eval_fn_call_expr( 1 => { let x = eval_expr(context, &expressions[0], None)?; if identifier.prime_count > 0 { - return calculus::derive_func(context, identifier, x); + return numerical::derive_func(context, identifier, x); } else { prelude::call_unary_func( context, @@ -518,7 +519,7 @@ pub(crate) fn eval_fn_call_expr( )?)), ); - // Don't set these values just yet, since + // Don't set these values just yet, // to avoid affecting the value of arguments // during recursion. new_argument_values.push((argument, var_decl)); @@ -771,6 +772,20 @@ fn eval_comprehension( Ok(values) } +fn eval_equation( + context: &mut Context, + left: &Expr, + right: &Expr, + unknown_var: &Identifier, +) -> Result { + let expr = Expr::Binary( + Box::new(left.clone()), + TokenKind::Minus, + Box::new(right.clone()), + ); + numerical::find_root(context, &expr, &unknown_var.full_name) +} + #[cfg(test)] mod tests { use super::*; diff --git a/kalk/src/inverter.rs b/kalk/src/inverter.rs index 8d7a391..26d92e9 100644 --- a/kalk/src/inverter.rs +++ b/kalk/src/inverter.rs @@ -93,6 +93,7 @@ fn invert( Expr::Comprehension(_, _, _) => { Err(KalkError::UnableToInvert(String::from("Comprehension"))) } + Expr::Equation(_, _, _) => Err(KalkError::UnableToInvert(String::from("Equation"))), } } @@ -400,6 +401,7 @@ pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) -> .any(|row| row.iter().any(|x| contains_var(symbol_table, x, var_name))), Expr::Indexer(_, _) => false, Expr::Comprehension(_, _, _) => false, + Expr::Equation(_, _, _) => false, } } diff --git a/kalk/src/lib.rs b/kalk/src/lib.rs index 0f41702..e1bd49e 100644 --- a/kalk/src/lib.rs +++ b/kalk/src/lib.rs @@ -4,13 +4,13 @@ mod analysis; pub mod ast; pub mod calculation_result; -mod calculus; mod errors; mod integration_testing; mod interpreter; mod inverter; pub mod kalk_value; mod lexer; +mod numerical; pub mod parser; mod prelude; mod radix; diff --git a/kalk/src/calculus.rs b/kalk/src/numerical.rs similarity index 78% rename from kalk/src/calculus.rs rename to kalk/src/numerical.rs index 3f5dc81..f75b464 100644 --- a/kalk/src/calculus.rs +++ b/kalk/src/numerical.rs @@ -140,15 +140,83 @@ fn simpsons_rule( )) } +pub fn find_root( + context: &mut interpreter::Context, + expr: &Expr, + var_name: &str, +) -> Result { + const FN_NAME: &str = "tmp."; + let f = Stmt::FnDecl( + Identifier::from_full_name(FN_NAME), + vec![var_name.into()], + Box::new(expr.clone()), + ); + context.symbol_table.set(f); + let mut approx = KalkValue::from(1f64); + for _ in 0..100 { + let (new_approx, done) = newton_method(context, approx, &Identifier::from_full_name(FN_NAME))?; + approx = new_approx; + if done { + break; + } + } + + // Confirm that the approximation is correct + let (test_real, test_imaginary) = interpreter::eval_fn_call_expr( + context, + &Identifier::from_full_name(FN_NAME), + &[crate::ast::build_literal_ast(&approx)], + None, + )? + .values(); + + context.symbol_table.get_and_remove_var(var_name); + + if test_real.is_nan() || test_real.abs() > 0.0001f64 || test_imaginary.abs() > 0.0001f64 { + return Err(KalkError::UnableToSolveEquation); + } + + Ok(approx) +} + +fn newton_method( + context: &mut interpreter::Context, + initial: KalkValue, + fn_name: &Identifier, +) -> Result<(KalkValue, bool), KalkError> { + let f = interpreter::eval_fn_call_expr( + context, + fn_name, + &[crate::ast::build_literal_ast(&initial)], + None, + )?; + + // If it ends up solving the equation early, abort + const PRECISION: f64 = 0.0000001f64; + match f { + KalkValue::Number(x, y, _) + if x < PRECISION && x > -PRECISION && y < PRECISION && y > -PRECISION => + { + return Ok((initial, true)); + } + _ => (), + } + + let f_prime_name = Identifier::from_name_and_primes(&fn_name.pure_name, 1); + let f_prime = derive_func(context, &f_prime_name, initial.clone())?; + + Ok((initial.sub_without_unit(&f.div_without_unit(&f_prime)?)?, false)) +} + #[cfg(test)] mod tests { use crate::ast; - use crate::calculus::Identifier; - use crate::calculus::Stmt; use crate::float; use crate::interpreter; use crate::kalk_value::KalkValue; use crate::lexer::TokenKind::*; + use crate::numerical::Identifier; + use crate::numerical::Stmt; use crate::symbol_table::SymbolTable; use crate::test_helpers::*; @@ -282,4 +350,15 @@ mod tests { assert!(cmp(result.to_f64(), -12f64)); assert!(cmp(result.imaginary_to_f64(), -5.5f64)); } + + #[test] + fn test_find_root() { + let mut symbol_table = SymbolTable::new(); + let mut context = get_context(&mut symbol_table); + let ast = &*binary(binary(var("x"), Power, literal(3f64)), Plus, literal(3f64)); + let result = super::find_root(&mut context, ast, "x").unwrap(); + + assert!(cmp(result.to_f64(), -1.4422495709)); + assert!(!result.has_imaginary()); + } } diff --git a/kalk/src/parser.rs b/kalk/src/parser.rs index dfbfb3c..0ce8438 100644 --- a/kalk/src/parser.rs +++ b/kalk/src/parser.rs @@ -303,55 +303,17 @@ fn parse_comparison(context: &mut Context) -> Result { let op = peek(context).kind; advance(context); - // If it's potentially a function declaration, run it through - // the analysis phase to ensure it gets added to the symbol - // table before parsing the right side. This is necessary for - // recursion to work. - if let (TokenKind::Equals, Expr::Binary(_, TokenKind::Star, _)) = (TokenKind::Equals, &left) - { - let analysed = analysis::analyse_stmt( - context.symbol_table.get_mut(), - Stmt::Expr(Box::new(Expr::Binary( - Box::new(left), - op, - Box::new(Expr::Literal(0f64)), - ))), - )?; + let is_fn_decl = if let Some((identifier, parameters)) = analysis::is_fn_decl(&left) { + context.symbol_table.get_mut().set(Stmt::FnDecl( + identifier, + parameters, + Box::new(Expr::Literal(0f64)), + )); - left = match analysed { - // Reconstruct function declarations into what they were originally parsed as - Stmt::FnDecl(identifier, parameters, _) => { - let mut parameter_vars: Vec = parameters - .into_iter() - .map(|x| { - Expr::Var(Identifier::from_full_name( - // Parameters will come back as eg. f-x, - // therefore the function name needs to be removed - &x[identifier.full_name.len() + 1..], - )) - }) - .collect(); - - Expr::Binary( - Box::new(Expr::Var(identifier)), - TokenKind::Star, - Box::new(if parameter_vars.len() > 1 { - Expr::Vector(parameter_vars) - } else { - Expr::Group(Box::new(parameter_vars.pop().unwrap())) - }), - ) - } - Stmt::Expr(analysed_expr) => { - if let Expr::Binary(analysed_left, TokenKind::Equals, _) = *analysed_expr { - *analysed_left - } else { - unreachable!() - } - } - _ => unreachable!(), - }; - } + true + } else { + false + }; let right = if op == TokenKind::Equals && match_token(context, TokenKind::OpenBrace) { parse_piecewise(context)? @@ -362,16 +324,14 @@ fn parse_comparison(context: &mut Context) -> Result { left = match right { Expr::Binary( inner_left, - inner_op - @ - (TokenKind::Equals + inner_op @ (TokenKind::Equals | TokenKind::NotEquals | TokenKind::GreaterThan | TokenKind::LessThan | TokenKind::GreaterOrEquals | TokenKind::LessOrEquals), inner_right, - ) => Expr::Binary( + ) if !is_fn_decl => Expr::Binary( Box::new(Expr::Binary( Box::new(left), op, @@ -657,7 +617,7 @@ fn parse_identifier(context: &mut Context) -> Result { .contains_fn(&identifier.pure_name) { // Function call - let mut arguments = match parse_vector(context)? { + let mut arguments = match parse_primary(context)? { Expr::Vector(arguments) => arguments, Expr::Group(argument) => vec![*argument], argument => vec![argument], diff --git a/tests/equations.kalker b/tests/equations.kalker new file mode 100644 index 0000000..9714d61 --- /dev/null +++ b/tests/equations.kalker @@ -0,0 +1 @@ +(3x^3 - 2x = x^2 + 2) = 1.270776326 \ No newline at end of file diff --git a/tests/functions.kalker b/tests/functions.kalker index 48676a1..6628546 100644 --- a/tests/functions.kalker +++ b/tests/functions.kalker @@ -2,4 +2,4 @@ x = 3 f(x) = 2*x g(x, y) = 2*x*y -f(x) = 6 and fx = 6 and x = 3 and g(x, x + 1) = 24 \ No newline at end of file +f(x) = 6 and fx = 6 and x = 3 and g(x, x + 1) = 24 and sqrt4 = 2 \ No newline at end of file