Numerical equation solving using Newton's method

This commit is contained in:
bakk 2022-04-24 21:23:29 +02:00
parent 8264f84ba7
commit 4d6ef5e8d7
11 changed files with 181 additions and 162 deletions

View File

@ -80,11 +80,14 @@ pub(crate) fn analyse_stmt(
fn analyse_stmt_expr(context: &mut Context, value: Expr) -> Result<Stmt, KalkError> {
Ok(
if let Expr::Binary(left, TokenKind::Equals, right) = value {
match *left {
Expr::Binary(identifier_expr, TokenKind::Star, parameter_expr) => {
build_fn_decl_from_scratch(context, *identifier_expr, *parameter_expr, *right)?
if let Some((identifier, parameters)) = is_fn_decl(&*left) {
return build_fn_decl_from_scratch(context, identifier, parameters, *right);
}
Expr::FnCall(identifier, arguments) => {
match *left {
Expr::FnCall(identifier, arguments)
if !prelude::is_prelude_func(&identifier.full_name) =>
{
// First loop through with a reference
// to arguments, to be able to back-track if
// one of the arguments can't be made into a parameter.
@ -138,11 +141,10 @@ fn analyse_stmt_expr(context: &mut Context, value: Expr) -> Result<Stmt, KalkErr
result
}
_ => Stmt::Expr(Box::new(Expr::Binary(
Box::new(analyse_expr(context, *left)?),
TokenKind::Equals,
right,
))),
_ => Stmt::Expr(Box::new(analyse_expr(
context,
Expr::Binary(left, TokenKind::Equals, right),
)?)),
}
} else {
Stmt::Expr(Box::new(analyse_expr(context, value)?))
@ -150,44 +152,18 @@ fn analyse_stmt_expr(context: &mut Context, value: Expr) -> Result<Stmt, KalkErr
)
}
fn build_fn_decl_from_scratch(
context: &mut Context,
identifier_expr: Expr,
parameter_expr: Expr,
right: Expr,
) -> Result<Stmt, KalkError> {
Ok(match identifier_expr {
Expr::Var(identifier) if !prelude::is_prelude_func(&identifier.full_name) => {
// Check if all the expressions in the parameter_expr are
// variables. If not, it can't be turned into a function declaration.
let all_are_vars = match &parameter_expr {
Expr::Vector(exprs) => exprs.iter().any(|x| matches!(x, Expr::Var(_))),
Expr::Group(expr) => {
matches!(&**expr, Expr::Var(_))
}
_ => false,
pub fn is_fn_decl(expr: &Expr) -> Option<(Identifier, Vec<String>)> {
if let Expr::Binary(left, TokenKind::Star, right) = &*expr {
let identifier = if let Expr::Var(identifier) = &**left {
identifier
} else {
return None;
};
if !all_are_vars {
// Analyse it as a function call instead
return Ok(Stmt::Expr(Box::new(analyse_expr(
context,
Expr::Binary(
Box::new(Expr::Binary(
Box::new(Expr::Var(identifier)),
TokenKind::Star,
Box::new(parameter_expr),
)),
TokenKind::Equals,
Box::new(right),
),
)?)));
}
let exprs = match parameter_expr {
Expr::Vector(exprs) => exprs,
Expr::Group(expr) => vec![*expr],
_ => unreachable!(),
let exprs = match &**right {
Expr::Vector(exprs) => exprs.iter().collect(),
Expr::Group(expr) => vec![&**expr],
_ => return None,
};
let mut parameters = Vec::new();
@ -200,6 +176,20 @@ fn build_fn_decl_from_scratch(
}
}
if !prelude::is_prelude_func(&identifier.full_name) {
return Some((identifier.clone(), parameters));
}
}
None
}
fn build_fn_decl_from_scratch(
context: &mut Context,
identifier: Identifier,
parameters: Vec<String>,
right: Expr,
) -> Result<Stmt, KalkError> {
context.current_function_name = Some(identifier.pure_name.clone());
context.current_function_parameters = Some(parameters.clone());
let fn_decl = Stmt::FnDecl(
@ -211,22 +201,7 @@ fn build_fn_decl_from_scratch(
context.current_function_name = None;
context.current_function_parameters = None;
fn_decl
}
_ => {
let new_binary = Expr::Binary(
Box::new(Expr::Binary(
Box::new(identifier_expr),
TokenKind::Star,
Box::new(parameter_expr),
)),
TokenKind::Equals,
Box::new(right),
);
Stmt::Expr(Box::new(analyse_expr(context, new_binary)?))
}
})
Ok(fn_decl)
}
fn analyse_expr(context: &mut Context, expr: Expr) -> Result<Expr, KalkError> {
@ -283,6 +258,7 @@ fn analyse_expr(context: &mut Context, expr: Expr) -> Result<Expr, KalkError> {
Expr::Indexer(Box::new(analyse_expr(context, *value)?), analysed_indexes)
}
Expr::Comprehension(left, right, vars) => Expr::Comprehension(left, right, vars),
Expr::Equation(left, right, identifier) => Expr::Equation(left, right, identifier),
})
}
@ -325,26 +301,10 @@ fn analyse_binary(
return result;
};
let inverted = if inverter::contains_var(context.symbol_table, &left, var_name) {
left.invert_to_target(context.symbol_table, right, var_name)?
} else {
right.invert_to_target(context.symbol_table, left, var_name)?
};
// If the inverted expression still contains the variable,
// the equation solving failed.
if inverter::contains_var(context.symbol_table, &inverted, var_name) {
return Err(KalkError::UnableToSolveEquation);
}
context.symbol_table.insert(Stmt::VarDecl(
Identifier::from_full_name(var_name),
Box::new(inverted.clone()),
));
let identifier = Identifier::from_full_name(var_name);
context.equation_variable = None;
Ok(inverted)
Ok(Expr::Equation(Box::new(left), Box::new(right), identifier))
}
(Expr::Var(_), TokenKind::Star, _) => {
if let Expr::Var(identifier) = left {

View File

@ -25,6 +25,7 @@ pub enum Expr {
Matrix(Vec<Vec<Expr>>),
Indexer(Box<Expr>, Vec<Expr>),
Comprehension(Box<Expr>, Vec<Expr>, Vec<RangedVar>),
Equation(Box<Expr>, Box<Expr>, Identifier),
}
#[derive(Debug, Clone, PartialEq)]

View File

@ -44,6 +44,7 @@ mod tests {
#[test_case("basics")]
#[test_case("comparisons")]
#[test_case("comprehensions")]
#[test_case("equations")]
#[test_case("derivation")]
#[test_case("functions")]
#[test_case("groups")]

View File

@ -6,7 +6,7 @@ use crate::kalk_value::KalkValue;
use crate::lexer::TokenKind;
use crate::parser::DECL_UNIT;
use crate::symbol_table::SymbolTable;
use crate::{as_number_or_zero, calculus};
use crate::{as_number_or_zero, numerical};
use crate::{float, prelude};
pub struct Context<'a> {
@ -135,6 +135,7 @@ pub(crate) fn eval_expr(
Expr::Comprehension(left, conditions, vars) => Ok(KalkValue::Vector(eval_comprehension(
context, left, conditions, vars,
)?)),
Expr::Equation(left, right, identifier) => eval_equation(context, left, right, identifier),
}
}
@ -355,13 +356,13 @@ pub(crate) fn eval_fn_call_expr(
}
"integrate" => {
return match expressions.len() {
3 => calculus::integrate_with_unknown_variable(
3 => numerical::integrate_with_unknown_variable(
context,
&expressions[0],
&expressions[1],
&expressions[2],
),
4 => calculus::integrate(
4 => numerical::integrate(
context,
&expressions[0],
&expressions[1],
@ -406,7 +407,7 @@ pub(crate) fn eval_fn_call_expr(
1 => {
let x = eval_expr(context, &expressions[0], None)?;
if identifier.prime_count > 0 {
return calculus::derive_func(context, identifier, x);
return numerical::derive_func(context, identifier, x);
} else {
prelude::call_unary_func(
context,
@ -518,7 +519,7 @@ pub(crate) fn eval_fn_call_expr(
)?)),
);
// Don't set these values just yet, since
// Don't set these values just yet,
// to avoid affecting the value of arguments
// during recursion.
new_argument_values.push((argument, var_decl));
@ -771,6 +772,20 @@ fn eval_comprehension(
Ok(values)
}
fn eval_equation(
context: &mut Context,
left: &Expr,
right: &Expr,
unknown_var: &Identifier,
) -> Result<KalkValue, KalkError> {
let expr = Expr::Binary(
Box::new(left.clone()),
TokenKind::Minus,
Box::new(right.clone()),
);
numerical::find_root(context, &expr, &unknown_var.full_name)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -93,6 +93,7 @@ fn invert(
Expr::Comprehension(_, _, _) => {
Err(KalkError::UnableToInvert(String::from("Comprehension")))
}
Expr::Equation(_, _, _) => Err(KalkError::UnableToInvert(String::from("Equation"))),
}
}
@ -400,6 +401,7 @@ pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) ->
.any(|row| row.iter().any(|x| contains_var(symbol_table, x, var_name))),
Expr::Indexer(_, _) => false,
Expr::Comprehension(_, _, _) => false,
Expr::Equation(_, _, _) => false,
}
}

View File

@ -4,13 +4,13 @@
mod analysis;
pub mod ast;
pub mod calculation_result;
mod calculus;
mod errors;
mod integration_testing;
mod interpreter;
mod inverter;
pub mod kalk_value;
mod lexer;
mod numerical;
pub mod parser;
mod prelude;
mod radix;

View File

@ -140,15 +140,83 @@ fn simpsons_rule(
))
}
pub fn find_root(
context: &mut interpreter::Context,
expr: &Expr,
var_name: &str,
) -> Result<KalkValue, KalkError> {
const FN_NAME: &str = "tmp.";
let f = Stmt::FnDecl(
Identifier::from_full_name(FN_NAME),
vec![var_name.into()],
Box::new(expr.clone()),
);
context.symbol_table.set(f);
let mut approx = KalkValue::from(1f64);
for _ in 0..100 {
let (new_approx, done) = newton_method(context, approx, &Identifier::from_full_name(FN_NAME))?;
approx = new_approx;
if done {
break;
}
}
// Confirm that the approximation is correct
let (test_real, test_imaginary) = interpreter::eval_fn_call_expr(
context,
&Identifier::from_full_name(FN_NAME),
&[crate::ast::build_literal_ast(&approx)],
None,
)?
.values();
context.symbol_table.get_and_remove_var(var_name);
if test_real.is_nan() || test_real.abs() > 0.0001f64 || test_imaginary.abs() > 0.0001f64 {
return Err(KalkError::UnableToSolveEquation);
}
Ok(approx)
}
fn newton_method(
context: &mut interpreter::Context,
initial: KalkValue,
fn_name: &Identifier,
) -> Result<(KalkValue, bool), KalkError> {
let f = interpreter::eval_fn_call_expr(
context,
fn_name,
&[crate::ast::build_literal_ast(&initial)],
None,
)?;
// If it ends up solving the equation early, abort
const PRECISION: f64 = 0.0000001f64;
match f {
KalkValue::Number(x, y, _)
if x < PRECISION && x > -PRECISION && y < PRECISION && y > -PRECISION =>
{
return Ok((initial, true));
}
_ => (),
}
let f_prime_name = Identifier::from_name_and_primes(&fn_name.pure_name, 1);
let f_prime = derive_func(context, &f_prime_name, initial.clone())?;
Ok((initial.sub_without_unit(&f.div_without_unit(&f_prime)?)?, false))
}
#[cfg(test)]
mod tests {
use crate::ast;
use crate::calculus::Identifier;
use crate::calculus::Stmt;
use crate::float;
use crate::interpreter;
use crate::kalk_value::KalkValue;
use crate::lexer::TokenKind::*;
use crate::numerical::Identifier;
use crate::numerical::Stmt;
use crate::symbol_table::SymbolTable;
use crate::test_helpers::*;
@ -282,4 +350,15 @@ mod tests {
assert!(cmp(result.to_f64(), -12f64));
assert!(cmp(result.imaginary_to_f64(), -5.5f64));
}
#[test]
fn test_find_root() {
let mut symbol_table = SymbolTable::new();
let mut context = get_context(&mut symbol_table);
let ast = &*binary(binary(var("x"), Power, literal(3f64)), Plus, literal(3f64));
let result = super::find_root(&mut context, ast, "x").unwrap();
assert!(cmp(result.to_f64(), -1.4422495709));
assert!(!result.has_imaginary());
}
}

View File

@ -303,55 +303,17 @@ fn parse_comparison(context: &mut Context) -> Result<Expr, KalkError> {
let op = peek(context).kind;
advance(context);
// If it's potentially a function declaration, run it through
// the analysis phase to ensure it gets added to the symbol
// table before parsing the right side. This is necessary for
// recursion to work.
if let (TokenKind::Equals, Expr::Binary(_, TokenKind::Star, _)) = (TokenKind::Equals, &left)
{
let analysed = analysis::analyse_stmt(
context.symbol_table.get_mut(),
Stmt::Expr(Box::new(Expr::Binary(
Box::new(left),
op,
let is_fn_decl = if let Some((identifier, parameters)) = analysis::is_fn_decl(&left) {
context.symbol_table.get_mut().set(Stmt::FnDecl(
identifier,
parameters,
Box::new(Expr::Literal(0f64)),
))),
)?;
));
left = match analysed {
// Reconstruct function declarations into what they were originally parsed as
Stmt::FnDecl(identifier, parameters, _) => {
let mut parameter_vars: Vec<Expr> = parameters
.into_iter()
.map(|x| {
Expr::Var(Identifier::from_full_name(
// Parameters will come back as eg. f-x,
// therefore the function name needs to be removed
&x[identifier.full_name.len() + 1..],
))
})
.collect();
Expr::Binary(
Box::new(Expr::Var(identifier)),
TokenKind::Star,
Box::new(if parameter_vars.len() > 1 {
Expr::Vector(parameter_vars)
true
} else {
Expr::Group(Box::new(parameter_vars.pop().unwrap()))
}),
)
}
Stmt::Expr(analysed_expr) => {
if let Expr::Binary(analysed_left, TokenKind::Equals, _) = *analysed_expr {
*analysed_left
} else {
unreachable!()
}
}
_ => unreachable!(),
false
};
}
let right = if op == TokenKind::Equals && match_token(context, TokenKind::OpenBrace) {
parse_piecewise(context)?
@ -362,16 +324,14 @@ fn parse_comparison(context: &mut Context) -> Result<Expr, KalkError> {
left = match right {
Expr::Binary(
inner_left,
inner_op
@
(TokenKind::Equals
inner_op @ (TokenKind::Equals
| TokenKind::NotEquals
| TokenKind::GreaterThan
| TokenKind::LessThan
| TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals),
inner_right,
) => Expr::Binary(
) if !is_fn_decl => Expr::Binary(
Box::new(Expr::Binary(
Box::new(left),
op,
@ -657,7 +617,7 @@ fn parse_identifier(context: &mut Context) -> Result<Expr, KalkError> {
.contains_fn(&identifier.pure_name)
{
// Function call
let mut arguments = match parse_vector(context)? {
let mut arguments = match parse_primary(context)? {
Expr::Vector(arguments) => arguments,
Expr::Group(argument) => vec![*argument],
argument => vec![argument],

1
tests/equations.kalker Normal file
View File

@ -0,0 +1 @@
(3x^3 - 2x = x^2 + 2) = 1.270776326

View File

@ -2,4 +2,4 @@ x = 3
f(x) = 2*x
g(x, y) = 2*x*y
f(x) = 6 and fx = 6 and x = 3 and g(x, x + 1) = 24
f(x) = 6 and fx = 6 and x = 3 and g(x, x + 1) = 24 and sqrt4 = 2