diff --git a/kalk/src/analysis.rs b/kalk/src/analysis.rs index fd3e465..7b73bac 100644 --- a/kalk/src/analysis.rs +++ b/kalk/src/analysis.rs @@ -14,7 +14,7 @@ pub(crate) struct Context<'a> { equation_variable: Option, in_integral: bool, in_sum_prod: bool, - sum_variable_name: Option, + sum_variable_names: Option>, in_unit_decl: bool, in_conditional: bool, in_equation: bool, @@ -33,7 +33,7 @@ pub(crate) fn analyse_stmt( equation_variable: None, in_integral: false, in_sum_prod: false, - sum_variable_name: None, + sum_variable_names: None, in_unit_decl: false, in_conditional: false, in_equation: false, @@ -606,8 +606,11 @@ fn build_fn_call( let is_sum_prod = identifier.pure_name == "sum" || identifier.pure_name == "prod"; if is_sum_prod { context.in_sum_prod = true; + if context.sum_variable_names.is_none() { + context.sum_variable_names = Some(Vec::new()); + } } - + // Don't perform equation solving on special functions if is_integral || is_sum_prod { context.in_equation = false; @@ -619,14 +622,15 @@ fn build_fn_call( for (i, argument) in arguments.iter().enumerate() { if i == 0 && context.in_sum_prod { context.in_conditional = true; + let vars = context.sum_variable_names.as_mut().unwrap(); if let Expr::Binary(left, TokenKind::Equals, _) = argument { if let Expr::Var(var_identifier) = &**left { - context.sum_variable_name = Some(var_identifier.pure_name.clone()); + vars.push(var_identifier.pure_name.clone()); } else { - context.sum_variable_name = Some(String::from("n")); + vars.push(String::from("n")); } } else { - context.sum_variable_name = Some(String::from("n")); + vars.push(String::from("n")); } } @@ -659,6 +663,8 @@ fn build_fn_call( if is_sum_prod { context.in_sum_prod = prev_in_sum_prod; + let vars = context.sum_variable_names.as_mut().unwrap(); + vars.pop(); } Ok(Expr::FnCall(identifier, arguments)) @@ -769,7 +775,7 @@ fn build_var(context: &mut Context, name: &str) -> Expr { } } - if context.in_sum_prod && name == context.sum_variable_name.as_ref().unwrap() { + if context.in_sum_prod && context.sum_variable_names.as_ref().unwrap().contains(&name.to_string()) { return Expr::Var(Identifier::from_full_name(name)); } diff --git a/kalk/src/interpreter.rs b/kalk/src/interpreter.rs index ef512e4..ce76ae5 100644 --- a/kalk/src/interpreter.rs +++ b/kalk/src/interpreter.rs @@ -14,8 +14,7 @@ pub struct Context<'a> { angle_unit: String, #[cfg(feature = "rug")] precision: u32, - sum_variable_name: Option, - sum_variable_value: Option, + sum_variables: Option>, #[cfg(not(target_arch = "wasm32"))] timeout: Option, #[cfg(not(target_arch = "wasm32"))] @@ -34,8 +33,7 @@ impl<'a> Context<'a> { symbol_table, #[cfg(feature = "rug")] precision, - sum_variable_name: None, - sum_variable_value: None, + sum_variables: None, #[cfg(not(target_arch = "wasm32"))] timeout, #[cfg(not(target_arch = "wasm32"))] @@ -77,6 +75,11 @@ impl<'a> Context<'a> { } } +struct SumVar { + name: String, + value: i128, +} + fn eval_stmt(context: &mut Context, stmt: &Stmt) -> Result { match stmt { Stmt::VarDecl(_, _) => eval_var_decl_stmt(context, stmt), @@ -248,9 +251,10 @@ fn eval_var_expr( return eval_expr(context, &Expr::Literal(*value), unit); } - if let Some(sum_variable_name) = &context.sum_variable_name { - if &identifier.full_name == sum_variable_name { - return Ok(KalkValue::from(context.sum_variable_value.unwrap())); + if let Some(sum_variables) = &context.sum_variables { + let sum_variable = sum_variables.iter().find(|x| x.name == identifier.full_name); + if let Some(sum_variable) = sum_variable { + return Ok(KalkValue::from(sum_variable.value)); } } @@ -434,7 +438,14 @@ pub(crate) fn eval_fn_call_expr( ("n", &expressions[0]) }; - context.sum_variable_name = Some(var_name.into()); + if context.sum_variables.is_none() { + context.sum_variables = Some(Vec::new()); + } + + { + let sum_variables = context.sum_variables.as_mut().unwrap(); + sum_variables.push(SumVar { name: var_name.into(), value: 0 }); + } let start = eval_expr(context, start_expr, "")?.to_f64() as i128; let end = eval_expr(context, &expressions[1], "")?.to_f64() as i128; @@ -450,7 +461,9 @@ pub(crate) fn eval_fn_call_expr( }; for n in start..=end { - context.sum_variable_value = Some(n); + let sum_variables = context.sum_variables.as_mut().unwrap(); + sum_variables.last_mut().unwrap().value = n; + let eval = eval_expr(context, &expressions[2], "")?; if sum_else_prod { sum = sum.add(context, eval); @@ -459,8 +472,9 @@ pub(crate) fn eval_fn_call_expr( } } - context.sum_variable_name = None; - context.sum_variable_value = None; + let sum_variables = context.sum_variables.as_mut().unwrap(); + sum_variables.pop(); + let (sum_real, sum_imaginary, _) = as_number_or_zero!(sum); // Set the unit as well diff --git a/tests/sum.kalker b/tests/sum.kalker index dd9c7b1..0975587 100644 --- a/tests/sum.kalker +++ b/tests/sum.kalker @@ -1,3 +1,4 @@ n = 10 sum(1, 5, 2n) = 30 and n = 10 and -sum(k=1, 5, 2k) = 30 \ No newline at end of file +sum(k=1, 5, 2k) = 30 and +sum(a=1, 3, Σ(b=1, 3, a + b)) = 36 \ No newline at end of file