diff --git a/src/interpreter.rs b/src/interpreter.rs index 324dc77..7528d70 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -90,7 +90,7 @@ fn eval_unary_expr(context: &mut Context, op: &TokenKind, expr: &Expr) -> Result match op { TokenKind::Minus => Ok(-expr_value), - TokenKind::Exclamation => Ok(prelude::funcs::factorial(expr_value)), + TokenKind::Exclamation => Ok(prelude::special_funcs::factorial(expr_value as i32) as f64), _ => Err(String::from("Invalid operator for unary expression.")), } } @@ -144,6 +144,7 @@ fn eval_fn_call_expr( identifier: &str, expressions: &Vec, ) -> Result { + // Prelude let prelude_func = match expressions.len() { 1 => { let x = eval_expr(context, &expressions[0])?; @@ -161,6 +162,38 @@ fn eval_fn_call_expr( return Ok(result); } + // Special functions + match identifier { + "sum" => { + // Make sure exactly 3 arguments were supplied. + if expressions.len() != 3 { + return Err(format!( + "Expected 3 arguments but got {}.", + expressions.len() + )); + } + + let start = eval_expr(context, &expressions[0])? as i32; + let end = eval_expr(context, &expressions[1])? as i32; + let mut sum = 0f64; + + for n in start..=end { + let n_expr = Expr::Literal(String::from(n.to_string())); + + // Update the variable "n" in the symbol table on every iteration, + // then calculate the expression and add it to the total sum. + context + .symbol_table + .set("n", Stmt::VarDecl(String::from("n"), Box::new(n_expr))); + sum += eval_expr(context, &expressions[2])?; + } + + return Ok(sum); + } + _ => (), + } + + // Symbol Table let stmt_definition = context .symbol_table .get(&format!("{}()", identifier)) diff --git a/src/prelude.rs b/src/prelude.rs index fd3b6fa..3885fdc 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -125,7 +125,18 @@ fn from_angle_unit(x: f64, angle_unit: &Unit) -> f64 { } } -pub mod funcs { +pub mod special_funcs { + pub fn factorial(x: i32) -> i32 { + let mut value = 1; + for i in 1..=x { + value *= i; + } + + value + } +} + +mod funcs { pub fn abs(x: f64) -> f64 { x.abs() } @@ -213,15 +224,6 @@ pub mod funcs { x.exp() } - pub fn factorial(x: f64) -> f64 { - let mut value = 1; - for i in 1..=x as i32 { - value *= i; - } - - value as f64 - } - pub fn floor(x: f64) -> f64 { x.floor() } diff --git a/src/symbol_table.rs b/src/symbol_table.rs index 213683e..094da65 100644 --- a/src/symbol_table.rs +++ b/src/symbol_table.rs @@ -20,6 +20,14 @@ impl SymbolTable { self.hashmap.get(key) } + pub fn set(&mut self, key: &str, value: Stmt) { + if let Some(stmt) = self.hashmap.get_mut(key) { + *stmt = value; + } else { + self.insert(key, value); + } + } + pub fn contains_var(&self, identifier: &str) -> bool { prelude::CONSTANTS.contains_key(identifier) || self.hashmap.contains_key(identifier) }