diff --git a/kalk/src/analysis.rs b/kalk/src/analysis.rs index 57d51be..edfd7d2 100644 --- a/kalk/src/analysis.rs +++ b/kalk/src/analysis.rs @@ -14,6 +14,7 @@ pub(crate) struct Context<'a> { equation_variable: Option, in_integral: bool, in_sum_prod: bool, + sum_variable_name: Option, in_unit_decl: bool, in_conditional: bool, in_equation: bool, @@ -32,6 +33,7 @@ pub(crate) fn analyse_stmt( equation_variable: None, in_integral: false, in_sum_prod: false, + sum_variable_name: None, in_unit_decl: false, in_conditional: false, in_equation: false, @@ -612,8 +614,22 @@ fn build_fn_call( let arguments = match adjacent_expr { Expr::Vector(arguments) => { let mut new_arguments = Vec::new(); - for argument in arguments { - new_arguments.push(analyse_expr(context, argument)?); + for (i, argument) in arguments.iter().enumerate() { + if i == 0 && context.in_sum_prod { + context.in_conditional = true; + if let Expr::Binary(left, TokenKind::Equals, _) = argument { + if let Expr::Var(var_identifier) = &**left { + context.sum_variable_name = Some(var_identifier.pure_name.clone()); + } else { + context.sum_variable_name = Some(String::from("n")); + } + } else { + context.sum_variable_name = Some(String::from("n")); + } + } + + new_arguments.push(analyse_expr(context, argument.to_owned())?); + context.in_conditional = false; } new_arguments @@ -751,7 +767,7 @@ fn build_var(context: &mut Context, name: &str) -> Expr { } } - if context.in_sum_prod && name == "n" { + if context.in_sum_prod && name == context.sum_variable_name.as_ref().unwrap() { return Expr::Var(Identifier::from_full_name(name)); } diff --git a/kalk/src/interpreter.rs b/kalk/src/interpreter.rs index 72b0ae4..ef512e4 100644 --- a/kalk/src/interpreter.rs +++ b/kalk/src/interpreter.rs @@ -14,7 +14,8 @@ pub struct Context<'a> { angle_unit: String, #[cfg(feature = "rug")] precision: u32, - sum_n_value: Option, + sum_variable_name: Option, + sum_variable_value: Option, #[cfg(not(target_arch = "wasm32"))] timeout: Option, #[cfg(not(target_arch = "wasm32"))] @@ -33,7 +34,8 @@ impl<'a> Context<'a> { symbol_table, #[cfg(feature = "rug")] precision, - sum_n_value: None, + sum_variable_name: None, + sum_variable_value: None, #[cfg(not(target_arch = "wasm32"))] timeout, #[cfg(not(target_arch = "wasm32"))] @@ -246,9 +248,9 @@ fn eval_var_expr( return eval_expr(context, &Expr::Literal(*value), unit); } - if identifier.full_name == "n" { - if let Some(value) = context.sum_n_value { - return Ok(KalkValue::from(value)); + if let Some(sum_variable_name) = &context.sum_variable_name { + if &identifier.full_name == sum_variable_name { + return Ok(KalkValue::from(context.sum_variable_value.unwrap())); } } @@ -422,7 +424,19 @@ pub(crate) fn eval_fn_call_expr( )); } - let start = eval_expr(context, &expressions[0], "")?.to_f64() as i128; + let (var_name, start_expr) = if let Expr::Binary(left, TokenKind::Equals, right) = &expressions[0] { + if let Expr::Var(var_identifier) = &**left { + (var_identifier.pure_name.as_ref(), &**right) + } else { + ("n", &**right) + } + } else { + ("n", &expressions[0]) + }; + + context.sum_variable_name = Some(var_name.into()); + + let start = eval_expr(context, start_expr, "")?.to_f64() as i128; let end = eval_expr(context, &expressions[1], "")?.to_f64() as i128; let sum_else_prod = match identifier.full_name.as_ref() { "sum" => true, @@ -436,7 +450,7 @@ pub(crate) fn eval_fn_call_expr( }; for n in start..=end { - context.sum_n_value = Some(n); + context.sum_variable_value = Some(n); let eval = eval_expr(context, &expressions[2], "")?; if sum_else_prod { sum = sum.add(context, eval); @@ -445,7 +459,8 @@ pub(crate) fn eval_fn_call_expr( } } - context.sum_n_value = None; + context.sum_variable_name = None; + context.sum_variable_value = None; let (sum_real, sum_imaginary, _) = as_number_or_zero!(sum); // Set the unit as well diff --git a/tests/sum.kalker b/tests/sum.kalker index b0d1c8b..dd9c7b1 100644 --- a/tests/sum.kalker +++ b/tests/sum.kalker @@ -1,2 +1,3 @@ n = 10 -sum(1, 5, 2n) = 30 and n = 10 \ No newline at end of file +sum(1, 5, 2n) = 30 and n = 10 and +sum(k=1, 5, 2k) = 30 \ No newline at end of file