diff --git a/kalk/src/analysis.rs b/kalk/src/analysis.rs index 46d128b..1697fc4 100644 --- a/kalk/src/analysis.rs +++ b/kalk/src/analysis.rs @@ -207,12 +207,21 @@ fn analyse_binary<'a>( // Equation context.in_equation = true; let left = analyse_expr(context, left)?; + + // If it has already been set to false manually somewhere else, + // abort and analyse as a comparison instead. + if context.in_equation == false { + context.in_conditional = true; + return analyse_binary(context, left, op, right); + } + context.in_equation = false; let var_name = if let Some(var_name) = &context.equation_variable { var_name } else { - return Err(CalcError::UnableToSolveEquation); + context.in_conditional = true; + return analyse_binary(context, left, op, right); }; let inverted = if inverter::contains_var(&mut context.symbol_table, &left, var_name) { @@ -231,6 +240,7 @@ fn analyse_binary<'a>( Identifier::from_full_name(var_name), Box::new(inverted.clone()), )); + context.equation_variable = None; return Ok(inverted); } @@ -378,6 +388,11 @@ fn build_fn_call( context.in_integral = true; } + // Don't perform equation solving on special functions + if is_integral || identifier.full_name == "sum" || identifier.full_name == "prod" { + context.in_equation = false; + } + let arguments = match adjacent_expr { Expr::Vector(arguments) => { let mut new_arguments = Vec::new(); @@ -484,7 +499,7 @@ fn build_split_up_vars( Ok(left) } -fn build_var(context: &Context, name: &str) -> Expr { +fn build_var(context: &mut Context, name: &str) -> Expr { if let (Some(function_name), Some(params)) = ( context.current_function_name.as_ref(), context.current_function_parameters.as_ref(), @@ -494,5 +509,10 @@ fn build_var(context: &Context, name: &str) -> Expr { return Expr::Var(identifier); } } - return Expr::Var(Identifier::from_full_name(name)); + + if context.in_equation && !context.symbol_table.contains_var(name) { + context.equation_variable = Some(name.to_string()); + } + + Expr::Var(Identifier::from_full_name(name)) }