Add max recursion depth check

This commit is contained in:
Kevin Xiao 2023-07-18 20:21:17 -04:00
parent 9b55f89442
commit 3e14a747fd
No known key found for this signature in database
2 changed files with 13 additions and 0 deletions

View File

@ -19,6 +19,7 @@ pub enum KalkError {
InvalidNumberLiteral(String),
InvalidOperator,
InvalidUnit,
StackOverflow,
TimedOut,
VariableReferencesItself,
PiecewiseConditionsAreFalse,
@ -61,6 +62,7 @@ impl ToString for KalkError {
KalkError::InvalidNumberLiteral(x) => format!("Invalid number literal: '{}'.", x),
KalkError::InvalidOperator => String::from("Invalid operator."),
KalkError::InvalidUnit => String::from("Invalid unit."),
KalkError::StackOverflow => String::from("Operation recursed too deeply."),
KalkError::TimedOut => String::from("Operation took too long."),
KalkError::VariableReferencesItself => String::from("Variable references itself."),
KalkError::PiecewiseConditionsAreFalse => String::from("All the conditions in the piecewise are false."),

View File

@ -9,6 +9,8 @@ use crate::symbol_table::SymbolTable;
use crate::{as_number_or_zero, numerical};
use crate::{float, prelude};
const DEFAULT_MAX_RECURSION_DEPTH: u32 = 128;
pub struct Context<'a> {
pub symbol_table: &'a mut SymbolTable,
angle_unit: String,
@ -20,6 +22,8 @@ pub struct Context<'a> {
#[cfg(not(target_arch = "wasm32"))]
start_time: std::time::SystemTime,
is_approximation: bool,
recursion_depth: u32,
max_recursion_depth: u32,
}
impl<'a> Context<'a> {
@ -40,6 +44,8 @@ impl<'a> Context<'a> {
#[cfg(not(target_arch = "wasm32"))]
start_time: std::time::SystemTime::now(),
is_approximation: false,
recursion_depth: 0,
max_recursion_depth: DEFAULT_MAX_RECURSION_DEPTH,
}
}
@ -83,6 +89,7 @@ struct SumVar {
}
fn eval_stmt(context: &mut Context, stmt: &Stmt) -> Result<KalkValue, KalkError> {
context.recursion_depth += 1;
match stmt {
Stmt::VarDecl(_, _) => eval_var_decl_stmt(context, stmt),
Stmt::FnDecl(_, _, _) => eval_fn_decl_stmt(),
@ -506,6 +513,10 @@ pub(crate) fn eval_fn_call_expr(
match stmt_definition {
Some(Stmt::FnDecl(_, arguments, fn_body)) => {
if context.recursion_depth > context.max_recursion_depth {
return Err(KalkError::StackOverflow);
}
if arguments.len() != expressions.len() {
return Err(KalkError::IncorrectAmountOfArguments(
arguments.len(),