diff --git a/.gitignore b/.gitignore index d26f037..c7f0aeb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target */target kalk/Cargo.lock kalk_cli/test +.vscode/ diff --git a/Cargo.lock b/Cargo.lock index 0e14ce6..0a3da2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,6 +140,7 @@ dependencies = [ name = "kalk" version = "0.1.11" dependencies = [ + "lazy_static", "phf", "regex", "rug", diff --git a/kalk/Cargo.toml b/kalk/Cargo.toml index fb39286..71577e8 100644 --- a/kalk/Cargo.toml +++ b/kalk/Cargo.toml @@ -15,3 +15,4 @@ phf = { version = "0.8", features = ["macros"] } rug = "1.9.0" test-case = "1.0.0" regex = "1" +lazy_static = "1.4.0" diff --git a/kalk/src/ast.rs b/kalk/src/ast.rs index b3bb174..5aad8b4 100644 --- a/kalk/src/ast.rs +++ b/kalk/src/ast.rs @@ -1,12 +1,11 @@ use crate::lexer::TokenKind; -use crate::parser::CalcError; -use crate::parser::Unit; /// A tree structure of a statement. #[derive(Debug, Clone, PartialEq)] pub enum Stmt { VarDecl(String, Box), FnDecl(String, Vec, Box), + UnitDecl(String, String, Box), /// For simplicity, expressions can be put into statements. This is the form in which expressions are passed to the interpreter. Expr(Box), } @@ -16,26 +15,9 @@ pub enum Stmt { pub enum Expr { Binary(Box, TokenKind, Box), Unary(TokenKind, Box), - Unit(Box, TokenKind), + Unit(String, Box), Var(String), Group(Box), FnCall(String, Vec), Literal(String), } - -impl TokenKind { - pub fn is_unit(&self) -> bool { - match self { - TokenKind::Deg | TokenKind::Rad => true, - _ => false, - } - } - - pub fn to_unit(&self) -> Result { - match self { - TokenKind::Deg => Ok(Unit::Degrees), - TokenKind::Rad => Ok(Unit::Radians), - _ => Err(CalcError::InvalidUnit), - } - } -} diff --git a/kalk/src/interpreter.rs b/kalk/src/interpreter.rs index 931dd52..f08b050 100644 --- a/kalk/src/interpreter.rs +++ b/kalk/src/interpreter.rs @@ -1,7 +1,7 @@ use crate::ast::{Expr, Stmt}; use crate::lexer::TokenKind; use crate::parser::CalcError; -use crate::parser::Unit; +use crate::parser::DECL_UNIT; use crate::prelude; use crate::symbol_table::SymbolTable; use rug::ops::Pow; @@ -9,20 +9,23 @@ use rug::Float; pub struct Context<'a> { symbol_table: &'a mut SymbolTable, - angle_unit: Unit, + angle_unit: String, precision: u32, } impl<'a> Context<'a> { - pub fn new(symbol_table: &'a mut SymbolTable, angle_unit: &Unit, precision: u32) -> Self { + pub fn new(symbol_table: &'a mut SymbolTable, angle_unit: &str, precision: u32) -> Self { Context { - angle_unit: angle_unit.clone(), + angle_unit: angle_unit.into(), symbol_table, precision, } } - pub fn interpret(&mut self, statements: Vec) -> Result, CalcError> { + pub fn interpret( + &mut self, + statements: Vec, + ) -> Result, CalcError> { for (i, stmt) in statements.iter().enumerate() { let value = eval_stmt(self, stmt); @@ -46,72 +49,110 @@ impl<'a> Context<'a> { } } -fn eval_stmt(context: &mut Context, stmt: &Stmt) -> Result { +fn eval_stmt(context: &mut Context, stmt: &Stmt) -> Result<(Float, String), CalcError> { match stmt { - Stmt::VarDecl(identifier, _) => eval_var_decl_stmt(context, stmt, identifier), + Stmt::VarDecl(_, _) => eval_var_decl_stmt(context, stmt), Stmt::FnDecl(_, _, _) => eval_fn_decl_stmt(context), + Stmt::UnitDecl(_, _, _) => eval_unit_decl_stmt(context), Stmt::Expr(expr) => eval_expr_stmt(context, &expr), } } -fn eval_var_decl_stmt( - context: &mut Context, - stmt: &Stmt, - identifier: &str, -) -> Result { - context.symbol_table.insert(&identifier, stmt.clone()); - Ok(Float::with_val(context.precision, 1)) +fn eval_var_decl_stmt(context: &mut Context, stmt: &Stmt) -> Result<(Float, String), CalcError> { + context.symbol_table.insert(stmt.clone()); + Ok((Float::with_val(context.precision, 1), String::new())) } -fn eval_fn_decl_stmt(context: &mut Context) -> Result { - Ok(Float::with_val(context.precision, 1)) // Nothing needs to happen here, since the parser will already have added the FnDecl's to the symbol table. +fn eval_fn_decl_stmt(context: &mut Context) -> Result<(Float, String), CalcError> { + Ok((Float::with_val(context.precision, 1), String::new())) // Nothing needs to happen here, since the parser will already have added the FnDecl's to the symbol table. } -fn eval_expr_stmt(context: &mut Context, expr: &Expr) -> Result { - eval_expr(context, &expr) +fn eval_unit_decl_stmt(context: &mut Context) -> Result<(Float, String), CalcError> { + Ok((Float::with_val(context.precision, 1), String::new())) } -fn eval_expr(context: &mut Context, expr: &Expr) -> Result { +fn eval_expr_stmt(context: &mut Context, expr: &Expr) -> Result<(Float, String), CalcError> { + eval_expr(context, &expr, "") +} + +fn eval_expr(context: &mut Context, expr: &Expr, unit: &str) -> Result<(Float, String), CalcError> { match expr { - Expr::Binary(left, op, right) => eval_binary_expr(context, &left, op, &right), - Expr::Unary(op, expr) => eval_unary_expr(context, op, expr), - Expr::Unit(expr, kind) => eval_unit_expr(context, expr, kind), - Expr::Var(identifier) => eval_var_expr(context, identifier), - Expr::Literal(value) => eval_literal_expr(context, value), - Expr::Group(expr) => eval_group_expr(context, &expr), + Expr::Binary(left, op, right) => eval_binary_expr(context, &left, op, &right, unit), + Expr::Unary(op, expr) => eval_unary_expr(context, op, expr, unit), + Expr::Unit(identifier, expr) => eval_unit_expr(context, identifier, expr), + Expr::Var(identifier) => eval_var_expr(context, identifier, unit), + Expr::Literal(value) => eval_literal_expr(context, value, unit), + Expr::Group(expr) => eval_group_expr(context, &expr, unit), Expr::FnCall(identifier, expressions) => { - eval_fn_call_expr(context, identifier, expressions) + eval_fn_call_expr(context, identifier, expressions, unit) } } } fn eval_binary_expr( context: &mut Context, - left: &Expr, + left_expr: &Expr, op: &TokenKind, - right: &Expr, -) -> Result { - let left = eval_expr(context, &left)?; - let right = eval_expr(context, &right)?; + right_expr: &Expr, + unit: &str, +) -> Result<(Float, String), CalcError> { + if let TokenKind::ToKeyword = op { + // TODO: When the unit conversion function takes a Float instead of Expr, + // move this to the match statement further down. + if let Expr::Var(right_unit) = right_expr { + let (_, left_unit) = eval_expr(context, left_expr, "")?; + return convert_unit(context, left_expr, &left_unit, &right_unit); // TODO: Avoid evaluating this twice. + } + } - Ok(match op { - TokenKind::Plus => left + right, - TokenKind::Minus => left - right, - TokenKind::Star => left * right, - TokenKind::Slash => left / right, - TokenKind::Power => left.pow(right), - _ => Float::with_val(1, 1), - }) + let (left, left_unit) = eval_expr(context, left_expr, "")?; + let (right, _) = if left_unit.len() > 0 { + let (_, right_unit) = eval_expr(context, right_expr, "")?; // TODO: Avoid evaluating this twice. + + if right_unit.len() > 0 { + convert_unit(context, right_expr, &right_unit, &left_unit)? + } else { + eval_expr(context, right_expr, unit)? + } + } else { + eval_expr(context, right_expr, unit)? + }; + + let final_unit = if unit.len() == 0 { + left_unit + } else { + unit.into() + }; + + Ok(( + match op { + TokenKind::Plus => left + right, + TokenKind::Minus => left - right, + TokenKind::Star => left * right, + TokenKind::Slash => left / right, + TokenKind::Power => left.pow(right), + _ => Float::with_val(1, 1), + }, + final_unit, + )) } -fn eval_unary_expr(context: &mut Context, op: &TokenKind, expr: &Expr) -> Result { - let expr_value = eval_expr(context, &expr)?; +fn eval_unary_expr( + context: &mut Context, + op: &TokenKind, + expr: &Expr, + unit: &str, +) -> Result<(Float, String), CalcError> { + let (expr_value, unit) = eval_expr(context, &expr, unit)?; match op { - TokenKind::Minus => Ok(-expr_value), - TokenKind::Exclamation => Ok(Float::with_val( - context.precision, - prelude::special_funcs::factorial(expr_value), + TokenKind::Minus => Ok((-expr_value, unit)), + TokenKind::Exclamation => Ok(( + Float::with_val( + context.precision, + prelude::special_funcs::factorial(expr_value), + ), + unit, )), _ => Err(CalcError::InvalidOperator), } @@ -119,73 +160,105 @@ fn eval_unary_expr(context: &mut Context, op: &TokenKind, expr: &Expr) -> Result fn eval_unit_expr( context: &mut Context, + identifier: &str, expr: &Expr, - kind: &TokenKind, -) -> Result { - let x = eval_expr(context, &expr); - let unit = kind.to_unit()?; - - // Don't do any angle conversions if the defauly angle unit is the same as the unit kind - match unit { - Unit::Degrees | Unit::Radians => { - if context.angle_unit == unit { - return x; - } - } +) -> Result<(Float, String), CalcError> { + let angle_unit = &context.angle_unit.clone(); + if (identifier == "rad" || identifier == "deg") && angle_unit != identifier { + return convert_unit(context, expr, identifier, angle_unit); } - match unit { - Unit::Degrees => Ok(prelude::special_funcs::to_radians(x?)), - Unit::Radians => Ok(prelude::special_funcs::to_degrees(x?)), + eval_expr(context, expr, identifier) +} + +pub fn convert_unit( + context: &mut Context, + expr: &Expr, + from_unit: &str, + to_unit: &str, +) -> Result<(Float, String), CalcError> { + if let Some(Stmt::UnitDecl(_, _, unit_def)) = + context.symbol_table.get_unit(to_unit, from_unit).cloned() + { + context + .symbol_table + .insert(Stmt::VarDecl(DECL_UNIT.into(), Box::new(expr.clone()))); + + Ok((eval_expr(context, &unit_def, "")?.0, to_unit.into())) + } else { + Err(CalcError::InvalidUnit) } } -fn eval_var_expr(context: &mut Context, identifier: &str) -> Result { +fn eval_var_expr( + context: &mut Context, + identifier: &str, + unit: &str, +) -> Result<(Float, String), CalcError> { // If there is a constant with this name, return a literal expression with its value if let Some(value) = prelude::CONSTANTS.get(identifier) { - return eval_expr(context, &Expr::Literal((*value).to_string())); + return eval_expr(context, &Expr::Literal((*value).to_string()), unit); } // Look for the variable in the symbol table - let var_decl = context.symbol_table.get(identifier).cloned(); + let var_decl = context.symbol_table.get_var(identifier).cloned(); match var_decl { - Some(Stmt::VarDecl(_, expr)) => eval_expr(context, &expr), + Some(Stmt::VarDecl(_, expr)) => eval_expr(context, &expr, unit), _ => Err(CalcError::UndefinedVar(identifier.into())), } } -fn eval_literal_expr(context: &mut Context, value: &str) -> Result { +fn eval_literal_expr( + context: &mut Context, + value: &str, + unit: &str, +) -> Result<(Float, String), CalcError> { match Float::parse(value) { - Ok(parsed_value) => Ok(Float::with_val(context.precision, parsed_value)), + Ok(parsed_value) => Ok(( + Float::with_val(context.precision, parsed_value), + unit.into(), + )), Err(_) => Err(CalcError::InvalidNumberLiteral(value.into())), } } -fn eval_group_expr(context: &mut Context, expr: &Expr) -> Result { - eval_expr(context, expr) +fn eval_group_expr( + context: &mut Context, + expr: &Expr, + unit: &str, +) -> Result<(Float, String), CalcError> { + eval_expr(context, expr, unit) } fn eval_fn_call_expr( context: &mut Context, identifier: &str, expressions: &[Expr], -) -> Result { + unit: &str, +) -> Result<(Float, String), CalcError> { // Prelude let prelude_func = match expressions.len() { 1 => { - let x = eval_expr(context, &expressions[0])?; - prelude::call_unary_func(identifier, x, &context.angle_unit) + let x = eval_expr(context, &expressions[0], "")?.0; + prelude::call_unary_func(context, identifier, x, &context.angle_unit.clone()) } 2 => { - let x = eval_expr(context, &expressions[0])?; - let y = eval_expr(context, &expressions[1])?; - prelude::call_binary_func(identifier, x, y, &context.angle_unit) + let x = eval_expr(context, &expressions[0], "")?.0; + let y = eval_expr(context, &expressions[1], "")?.0; + prelude::call_binary_func(context, identifier, x, y, &context.angle_unit.clone()) } _ => None, }; - if let Some(result) = prelude_func { - return Ok(result); + if let Some((result, func_unit)) = prelude_func { + return Ok(( + result, + if unit.len() > 0 { + unit.into() + } else { + func_unit.into() + }, + )); } // Special functions @@ -200,8 +273,8 @@ fn eval_fn_call_expr( )); } - let start = eval_expr(context, &expressions[0])?.to_f64() as i128; - let end = eval_expr(context, &expressions[1])?.to_f64() as i128; + let start = eval_expr(context, &expressions[0], "")?.0.to_f64() as i128; + let end = eval_expr(context, &expressions[1], "")?.0.to_f64() as i128; let mut sum = Float::with_val(context.precision, 0); for n in start..=end { @@ -211,20 +284,17 @@ fn eval_fn_call_expr( // then calculate the expression and add it to the total sum. context .symbol_table - .set("n", Stmt::VarDecl(String::from("n"), Box::new(n_expr))); - sum += eval_expr(context, &expressions[2])?; + .set(Stmt::VarDecl(String::from("n"), Box::new(n_expr))); + sum += eval_expr(context, &expressions[2], "")?.0; } - return Ok(sum); + return Ok((sum, unit.into())); } _ => (), } // Symbol Table - let stmt_definition = context - .symbol_table - .get(&format!("{}()", identifier)) - .cloned(); + let stmt_definition = context.symbol_table.get_fn(identifier).cloned(); match stmt_definition { Some(Stmt::FnDecl(_, arguments, fn_body)) => { @@ -244,7 +314,7 @@ fn eval_fn_call_expr( )?; } - eval_expr(context, &*fn_body) + eval_expr(context, &fn_body, unit) } _ => Err(CalcError::UndefinedFn(identifier.into())), } @@ -259,12 +329,54 @@ mod tests { const PRECISION: u32 = 53; - fn interpret(stmt: Stmt) -> Result, CalcError> { + lazy_static::lazy_static! { + static ref DEG_RAD_UNIT: Stmt = unit_decl( + "deg", + "rad", + binary( + binary( + var(crate::parser::DECL_UNIT), + TokenKind::Star, + literal("180"), + ), + TokenKind::Slash, + var("pi"), + ), + ); + static ref RAD_DEG_UNIT: Stmt = unit_decl( + "rad", + "deg", + binary( + binary(var(crate::parser::DECL_UNIT), TokenKind::Star, var("pi")), + TokenKind::Slash, + literal("180"), + ), + ); + } + + fn interpret_with_unit(stmt: Stmt) -> Result, CalcError> { let mut symbol_table = SymbolTable::new(); - let mut context = Context::new(&mut symbol_table, &Unit::Radians, PRECISION); + symbol_table + .insert(DEG_RAD_UNIT.clone()) + .insert(RAD_DEG_UNIT.clone()); + + let mut context = Context::new(&mut symbol_table, "rad", PRECISION); context.interpret(vec![stmt]) } + fn interpret(stmt: Stmt) -> Result, CalcError> { + if let Some((result, _)) = interpret_with_unit(stmt)? { + Ok(Some(result)) + } else { + Ok(None) + } + } + + fn cmp(x: Float, y: f64) -> bool { + println!("{} = {}", x.to_f64(), y); + (x.to_f64() - y).abs() < 0.0001 + } + #[test] fn test_literal() { let stmt = Stmt::Expr(literal("1")); @@ -294,25 +406,43 @@ mod tests { fn test_unary() { let neg = Stmt::Expr(unary(Minus, literal("1"))); let fact = Stmt::Expr(unary(Exclamation, literal("5"))); - let fact_dec = Stmt::Expr(unary(Exclamation, literal("5.2"))); assert_eq!(interpret(neg).unwrap().unwrap(), -1); assert_eq!(interpret(fact).unwrap().unwrap(), 120); - - let fact_dec_result = interpret(fact_dec).unwrap().unwrap(); - assert!(fact_dec_result > 169.406 && fact_dec_result < 169.407); } #[test] - fn test_unit() { - let rad = Stmt::Expr(Box::new(Expr::Unit(literal("1"), Rad))); - let deg = Stmt::Expr(Box::new(Expr::Unit(literal("1"), Deg))); + fn test_angle_units() { + let rad_explicit = Stmt::Expr(fn_call("sin", vec![*unit("rad", literal("1"))])); + let deg_explicit = Stmt::Expr(fn_call("sin", vec![*unit("deg", literal("1"))])); + let implicit = Stmt::Expr(fn_call("sin", vec![*literal("1")])); - assert_eq!(interpret(rad).unwrap().unwrap(), 1); - assert!( - (interpret(deg).unwrap().unwrap() - Float::with_val(PRECISION, 0.017456)).abs() - < Float::with_val(PRECISION, 0.0001) - ); + assert!(cmp(interpret(rad_explicit).unwrap().unwrap(), 0.84147098)); + assert!(cmp(interpret(deg_explicit).unwrap().unwrap(), 0.01745240)); + + let mut rad_symbol_table = SymbolTable::new(); + rad_symbol_table + .insert(DEG_RAD_UNIT.clone()) + .insert(RAD_DEG_UNIT.clone()); + let mut deg_symbol_table = SymbolTable::new(); + deg_symbol_table + .insert(DEG_RAD_UNIT.clone()) + .insert(RAD_DEG_UNIT.clone()); + let mut rad_context = Context::new(&mut rad_symbol_table, "rad", PRECISION); + let mut deg_context = Context::new(&mut deg_symbol_table, "deg", PRECISION); + + assert!(cmp( + rad_context + .interpret(vec![implicit.clone()]) + .unwrap() + .unwrap() + .0, + 0.84147098 + )); + assert!(cmp( + deg_context.interpret(vec![implicit]).unwrap().unwrap().0, + 0.01745240 + )); } #[test] @@ -321,10 +451,10 @@ mod tests { // Prepare by inserting a variable declaration in the symbol table. let mut symbol_table = SymbolTable::new(); - symbol_table.insert("x", var_decl("x", literal("1"))); + symbol_table.insert(var_decl("x", literal("1"))); - let mut context = Context::new(&mut symbol_table, &Unit::Radians, PRECISION); - assert_eq!(context.interpret(vec![stmt]).unwrap().unwrap(), 1); + let mut context = Context::new(&mut symbol_table, "rad", PRECISION); + assert_eq!(context.interpret(vec![stmt]).unwrap().unwrap().0, 1); } #[test] @@ -341,7 +471,7 @@ mod tests { fn test_var_decl() { let stmt = var_decl("x", literal("1")); let mut symbol_table = SymbolTable::new(); - Context::new(&mut symbol_table, &Unit::Radians, PRECISION) + Context::new(&mut symbol_table, "rad", PRECISION) .interpret(vec![stmt]) .unwrap(); @@ -354,17 +484,14 @@ mod tests { // Prepare by inserting a variable declaration in the symbol table. let mut symbol_table = SymbolTable::new(); - symbol_table.insert( - "f()", - fn_decl( - "f", - vec![String::from("x")], - binary(var("x"), TokenKind::Plus, literal("2")), - ), - ); + symbol_table.insert(fn_decl( + "f", + vec![String::from("x")], + binary(var("x"), TokenKind::Plus, literal("2")), + )); - let mut context = Context::new(&mut symbol_table, &Unit::Radians, PRECISION); - assert_eq!(context.interpret(vec![stmt]).unwrap().unwrap(), 3); + let mut context = Context::new(&mut symbol_table, "rad", PRECISION); + assert_eq!(context.interpret(vec![stmt]).unwrap().unwrap().0, 3); } #[test] diff --git a/kalk/src/inverter.rs b/kalk/src/inverter.rs new file mode 100644 index 0000000..e62491b --- /dev/null +++ b/kalk/src/inverter.rs @@ -0,0 +1,540 @@ +use crate::ast::{Expr, Stmt}; +use crate::lexer::TokenKind; +use crate::parser::CalcError; +use crate::parser::DECL_UNIT; +use crate::prelude; +use crate::symbol_table::SymbolTable; + +pub const INVERSE_UNARY_FUNCS: phf::Map<&'static str, &'static str> = phf::phf_map! { + "cos" => "acos", + "cosec" => "acosec", + "cosech" => "cosech", + "cosh" => "acosh", + "cot" => "acot", + "coth" => "acoth", + "sec" => "asec", + "sech" => "asech", + "sin" => "asin", + "sinh" => "asinh", + "tan" => "atan", + "tanh" => "atanh", + + "acos" => "cos", + "acosec" => "cosec", + "acosech" => "cosech", + "acosh" => "cosh", + "acot" => "cot", + "acoth" => "coth", + "asec" => "sec", + "asech" => "sech", + "asin" => "sin", + "asinh" => "sinh", + "atan" => "tan", + "atanh" => "tanh", +}; + +impl Expr { + pub fn invert(&self, symbol_table: &mut SymbolTable) -> Result { + let target_expr = Expr::Var(DECL_UNIT.into()); + let result = invert(target_expr, symbol_table, self); + + Ok(result?.0) + } +} + +fn invert( + target_expr: Expr, + symbol_table: &mut SymbolTable, + expr: &Expr, +) -> Result<(Expr, Expr), CalcError> { + match expr { + Expr::Binary(left, op, right) => { + invert_binary(target_expr, symbol_table, &left, op, &right) + } + Expr::Unary(op, expr) => invert_unary(target_expr, op, &expr), + Expr::Unit(identifier, expr) => invert_unit(target_expr, &identifier, &expr), + Expr::Var(identifier) => invert_var(target_expr, symbol_table, identifier), + Expr::Group(expr) => Ok((target_expr, *expr.clone())), + Expr::FnCall(identifier, arguments) => { + invert_fn_call(target_expr, symbol_table, &identifier, arguments) + } + Expr::Literal(_) => Ok((target_expr, expr.clone())), + } +} + +fn invert_binary( + target_expr: Expr, + symbol_table: &mut SymbolTable, + left: &Expr, + op: &TokenKind, + right: &Expr, +) -> Result<(Expr, Expr), CalcError> { + let op_inv = match op { + TokenKind::Plus => TokenKind::Minus, + TokenKind::Minus => { + // Eg. a-(b+c) + // Multiply "-1" into the group, resulting in it becoming a normal expression. Then invert it normally. + if let Expr::Group(inside_group) = right { + return invert_binary( + target_expr, + symbol_table, + left, + &TokenKind::Plus, + &multiply_into(&Expr::Literal(String::from("-1")), inside_group)?, + ); + } + + TokenKind::Plus + } + TokenKind::Star => { + // If the left expression is a group, multiply the right expression into it, dissolving the group. + // It can then be inverted normally. + if let Expr::Group(inside_group) = left { + return invert( + target_expr, + symbol_table, + &multiply_into(right, inside_group)?, + ); + } + + // Same as above but left/right switched. + if let Expr::Group(inside_group) = right { + return invert( + target_expr, + symbol_table, + &multiply_into(left, inside_group)?, + ); + } + + TokenKind::Slash + } + TokenKind::Slash => { + // Eg. (a+b)/c + // Just dissolve the group. Nothing more needs to be done mathematically. + if let Expr::Group(inside_group) = left { + return invert( + target_expr, + symbol_table, + &Expr::Binary(inside_group.clone(), op.clone(), Box::new(right.clone())), + ); + } + + // Eg. a/(b+c) + // Same as above. + if let Expr::Group(inside_group) = right { + return invert( + target_expr, + symbol_table, + &Expr::Binary(Box::new(left.clone()), op.clone(), inside_group.clone()), + ); + } + + TokenKind::Star + } + _ => unreachable!(), + }; + + // If the left expression contains the unit, invert the right one instead, + // since the unit should not be moved. + if contains_the_unit(symbol_table, left) { + // But if the right expression *also* contains the unit, + // throw an error, since it can't handle this yet. + if contains_the_unit(symbol_table, right) { + return Err(CalcError::UnableToInvert(String::from( + "Expressions with several instances of an unknown variable (this might be supported in the future). Try simplifying the expression.", + ))); + } + + return Ok(invert( + Expr::Binary(Box::new(target_expr), op_inv, Box::new(right.clone())), + symbol_table, + left, + )?); + } + + // Otherwise, invert the left side. + let final_target_expr = Expr::Binary(Box::new(target_expr), op_inv, Box::new(left.clone())); + Ok(invert( + // Eg. 2-a + // If the operator is minus (and the left expression is being inverted), + // make the target expression negative to keep balance. + if let TokenKind::Minus = op { + Expr::Unary(TokenKind::Minus, Box::new(final_target_expr)) + } else { + final_target_expr + }, + symbol_table, + right, // Then invert the right expression. + )?) +} + +fn invert_unary(target_expr: Expr, op: &TokenKind, expr: &Expr) -> Result<(Expr, Expr), CalcError> { + match op { + TokenKind::Minus => Ok(( + // Make the target expression negative + Expr::Unary(TokenKind::Minus, Box::new(target_expr)), + expr.clone(), // And then continue inverting the inner-expression. + )), + _ => unimplemented!(), + } +} + +fn invert_unit( + _target_expr: Expr, + _identifier: &str, + _expr: &Expr, +) -> Result<(Expr, Expr), CalcError> { + Err(CalcError::UnableToInvert(String::from( + "Expressions containing other units (this should be supported in the future).", + ))) +} + +fn invert_var( + target_expr: Expr, + symbol_table: &mut SymbolTable, + identifier: &str, +) -> Result<(Expr, Expr), CalcError> { + if identifier == DECL_UNIT { + Ok((target_expr, Expr::Var(identifier.into()))) + } else if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier).cloned() { + invert(target_expr, symbol_table, &var_expr) + } else { + Ok((target_expr, Expr::Var(identifier.into()))) + } +} + +fn invert_fn_call( + target_expr: Expr, + symbol_table: &mut SymbolTable, + identifier: &str, + arguments: &Vec, +) -> Result<(Expr, Expr), CalcError> { + // If prelude function + match arguments.len() { + 1 => { + if prelude::UNARY_FUNCS.contains_key(identifier) { + if let Some(fn_inv) = INVERSE_UNARY_FUNCS.get(identifier) { + return Ok(( + Expr::FnCall(fn_inv.to_string(), vec![target_expr]), + arguments[0].clone(), + )); + } else { + match identifier { + "sqrt" => { + return Ok(( + Expr::Binary( + Box::new(target_expr), + TokenKind::Power, + Box::new(Expr::Literal(String::from("2"))), + ), + arguments[0].clone(), + )); + } + _ => { + return Err(CalcError::UnableToInvert(format!( + "Function '{}'", + identifier + ))); + } + } + } + } + } + 2 => { + if prelude::BINARY_FUNCS.contains_key(identifier) { + return Err(CalcError::UnableToInvert(format!( + "Function '{}'", + identifier + ))); + } + } + _ => (), + } + + // Get the function definition from the symbol table. + let (parameters, body) = + if let Some(Stmt::FnDecl(_, parameters, body)) = symbol_table.get_fn(identifier).cloned() { + (parameters, body) + } else { + return Err(CalcError::UndefinedFn(identifier.into())); + }; + + // Make sure the input is valid. + if parameters.len() != arguments.len() { + return Err(CalcError::IncorrectAmountOfArguments( + parameters.len(), + identifier.into(), + arguments.len(), + )); + } + + // Make the parameters usable as variables inside the function. + let mut parameters_iter = parameters.iter(); + for argument in arguments { + symbol_table.insert(Stmt::VarDecl( + parameters_iter.next().unwrap().to_string(), + Box::new(argument.clone()), + )); + } + + // Invert everything in the function body. + invert(target_expr, symbol_table, &body) +} + +fn contains_the_unit(symbol_table: &SymbolTable, expr: &Expr) -> bool { + // Recursively scan the expression for the unit. + match expr { + Expr::Binary(left, _, right) => { + contains_the_unit(symbol_table, left) || contains_the_unit(symbol_table, right) + } + Expr::Unary(_, expr) => contains_the_unit(symbol_table, expr), + Expr::Unit(_, expr) => contains_the_unit(symbol_table, expr), + Expr::Var(identifier) => { + identifier == DECL_UNIT + || if let Some(Stmt::VarDecl(_, var_expr)) = symbol_table.get_var(identifier) { + contains_the_unit(symbol_table, var_expr) + } else { + false + } + } + Expr::Group(expr) => contains_the_unit(symbol_table, expr), + Expr::FnCall(_, args) => { + for arg in args { + if contains_the_unit(symbol_table, arg) { + return true; + } + } + + false + } + Expr::Literal(_) => false, + } +} + +/// Multiply an expression into a group. +fn multiply_into(expr: &Expr, base_expr: &Expr) -> Result { + match base_expr { + Expr::Binary(left, op, right) => match op { + // If + or -, multiply the expression with each term. + TokenKind::Plus | TokenKind::Minus => Ok(Expr::Binary( + Box::new(multiply_into(expr, &left)?), + op.clone(), + Box::new(multiply_into(expr, &right)?), + )), + // If * or /, only multiply with the first factor. + TokenKind::Star | TokenKind::Slash => Ok(Expr::Binary( + Box::new(multiply_into(expr, &left)?), + op.clone(), + right.clone(), + )), + _ => unimplemented!(), + }, + // If it's a literal, just multiply them together. + Expr::Literal(_) | Expr::Var(_) => Ok(Expr::Binary( + Box::new(expr.clone()), + TokenKind::Star, + Box::new(base_expr.clone()), + )), + Expr::Group(_) => Err(CalcError::UnableToInvert(String::from( + "Parenthesis multiplied with parenthesis (this should be possible in the future).", + ))), + _ => unimplemented!(), + } +} + +#[allow(unused_imports, dead_code)] // Getting warnings for some reason +mod tests { + use crate::ast::Expr; + use crate::lexer::TokenKind::*; + use crate::symbol_table::SymbolTable; + use crate::test_helpers::*; + + fn decl_unit() -> Box { + Box::new(Expr::Var(crate::parser::DECL_UNIT.into())) + } + + #[test] + fn test_binary() { + let ladd = binary(decl_unit(), Plus, literal("1")); + let lsub = binary(decl_unit(), Minus, literal("1")); + let lmul = binary(decl_unit(), Star, literal("1")); + let ldiv = binary(decl_unit(), Slash, literal("1")); + + let radd = binary(literal("1"), Plus, decl_unit()); + let rsub = binary(literal("1"), Minus, decl_unit()); + let rmul = binary(literal("1"), Star, decl_unit()); + let rdiv = binary(literal("1"), Slash, decl_unit()); + + let mut symbol_table = SymbolTable::new(); + assert_eq!( + ladd.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, literal("1")) + ); + assert_eq!( + lsub.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Plus, literal("1")) + ); + assert_eq!( + lmul.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Slash, literal("1")) + ); + assert_eq!( + ldiv.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Star, literal("1")) + ); + + assert_eq!( + radd.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, literal("1")) + ); + assert_eq!( + rsub.invert(&mut symbol_table).unwrap(), + *unary(Minus, binary(decl_unit(), Plus, literal("1"))) + ); + assert_eq!( + rmul.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Slash, literal("1")) + ); + assert_eq!( + rdiv.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Star, literal("1")) + ); + } + + #[test] + fn test_unary() { + let neg = unary(Minus, decl_unit()); + + let mut symbol_table = SymbolTable::new(); + assert_eq!(neg.invert(&mut symbol_table).unwrap(), *neg); + } + + #[test] + fn test_fn_call() { + let call_with_literal = binary(fn_call("f", vec![*literal("2")]), Plus, decl_unit()); + let call_with_decl_unit = fn_call("f", vec![*decl_unit()]); + let call_with_decl_unit_and_literal = + fn_call("f", vec![*binary(decl_unit(), Plus, literal("2"))]); + let decl = fn_decl( + "f", + vec![String::from("x")], + binary(var("x"), Plus, literal("1")), + ); + + let mut symbol_table = SymbolTable::new(); + symbol_table.insert(decl); + assert_eq!( + call_with_literal.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, fn_call("f", vec![*literal("2")])), + ); + assert_eq!( + call_with_decl_unit.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Minus, literal("1")) + ); + assert_eq!( + call_with_decl_unit_and_literal + .invert(&mut symbol_table) + .unwrap(), + *binary( + binary(decl_unit(), Minus, literal("1")), + Minus, + literal("2") + ) + ); + } + + #[test] + fn test_group() { + let group_x = binary( + group(binary(decl_unit(), Plus, literal("3"))), + Star, + literal("2"), + ); + let group_unary_minus = binary( + literal("2"), + Minus, + group(binary(decl_unit(), Plus, literal("3"))), + ); + let x_group_add = binary( + literal("2"), + Star, + group(binary(decl_unit(), Plus, literal("3"))), + ); + let x_group_sub = binary( + literal("2"), + Star, + group(binary(decl_unit(), Minus, literal("3"))), + ); + let x_group_mul = binary( + literal("2"), + Star, + group(binary(decl_unit(), Star, literal("3"))), + ); + let x_group_div = binary( + literal("2"), + Star, + group(binary(decl_unit(), Slash, literal("3"))), + ); + + let mut symbol_table = SymbolTable::new(); + assert_eq!( + group_x.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Minus, binary(literal("2"), Star, literal("3"))), + Slash, + literal("2") + ) + ); + assert_eq!( + group_unary_minus.invert(&mut symbol_table).unwrap(), + *binary( + binary( + binary(decl_unit(), Minus, literal("2")), + Minus, + binary(literal("-1"), Star, literal("3")) + ), + Slash, + literal("-1") + ) + ); + assert_eq!( + x_group_add.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Minus, binary(literal("2"), Star, literal("3"))), + Slash, + literal("2") + ) + ); + assert_eq!( + x_group_sub.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Plus, binary(literal("2"), Star, literal("3"))), + Slash, + literal("2") + ) + ); + assert_eq!( + x_group_mul.invert(&mut symbol_table).unwrap(), + *binary( + binary(decl_unit(), Slash, literal("3")), + Slash, + literal("2") + ) + ); + assert_eq!( + x_group_div.invert(&mut symbol_table).unwrap(), + *binary(binary(decl_unit(), Star, literal("3")), Slash, literal("2")) + ); + } + + #[test] + fn test_multiple_decl_units() { + /*let add_two = binary(decl_unit(), Plus, decl_unit()); + + let mut symbol_table = SymbolTable::new(); + assert_eq!( + add_two.invert(&mut symbol_table).unwrap(), + *binary(decl_unit(), Slash, literal("2")) + );*/ + } +} diff --git a/kalk/src/lexer.rs b/kalk/src/lexer.rs index 40d6d19..2904bda 100644 --- a/kalk/src/lexer.rs +++ b/kalk/src/lexer.rs @@ -16,8 +16,8 @@ pub enum TokenKind { Equals, Exclamation, - Deg, - Rad, + UnitKeyword, + ToKeyword, Pipe, OpenCeil, @@ -170,8 +170,8 @@ impl<'a> Lexer<'a> { } let kind = match value.as_ref() { - "deg" | "°" => TokenKind::Deg, - "rad" => TokenKind::Rad, + "unit" => TokenKind::UnitKeyword, + "to" => TokenKind::ToKeyword, _ => TokenKind::Identifier, }; diff --git a/kalk/src/lib.rs b/kalk/src/lib.rs index 0481264..79e3497 100644 --- a/kalk/src/lib.rs +++ b/kalk/src/lib.rs @@ -1,5 +1,6 @@ pub mod ast; mod interpreter; +mod inverter; mod lexer; pub mod parser; mod prelude; diff --git a/kalk/src/parser.rs b/kalk/src/parser.rs index 64c276b..6c94e0f 100644 --- a/kalk/src/parser.rs +++ b/kalk/src/parser.rs @@ -6,33 +6,50 @@ use crate::{ }; use rug::Float; +pub const DECL_UNIT: &'static str = ".u"; +pub const DEFAULT_ANGLE_UNIT: &'static str = "rad"; + /// Struct containing the current state of the parser. It stores user-defined functions and variables. /// # Examples /// ``` /// use kalk::parser; /// let mut parser_context = parser::Context::new(); /// let precision = 53; -/// assert_eq!(parser::eval(&mut parser_context, "5*3", precision).unwrap().unwrap(), 15); +/// let (result, unit) = parser::eval(&mut parser_context, "5*3", precision).unwrap().unwrap(); +/// assert_eq!(result, 15); /// ``` pub struct Context { tokens: Vec, pos: usize, symbol_table: SymbolTable, - angle_unit: Unit, + angle_unit: String, + /// This is true whenever the parser is currently parsing a unit declaration. + /// It is necessary to keep track of this in order to know when to find (figure out) units that haven't been defined yet. + /// Unit names are instead treated as variables. + parsing_unit_decl: bool, + /// When a unit declaration is being parsed, this value will be set + /// whenever a unit in the expression is found. Eg. unit a = 3b, it will be set to Some("b") + unit_decl_base_unit: Option, } impl Context { pub fn new() -> Self { - Context { + let mut context = Self { tokens: Vec::new(), pos: 0, symbol_table: SymbolTable::new(), - angle_unit: Unit::Radians, - } + angle_unit: DEFAULT_ANGLE_UNIT.into(), + parsing_unit_decl: false, + unit_decl_base_unit: None, + }; + + parse(&mut context, crate::prelude::INIT).unwrap(); + + context } - pub fn set_angle_unit(mut self, unit: Unit) -> Self { - self.angle_unit = unit; + pub fn set_angle_unit(mut self, unit: &str) -> Self { + self.angle_unit = unit.into(); self } @@ -44,13 +61,6 @@ impl Default for Context { } } -/// Mathematical unit used in calculations. -#[derive(Debug, Clone, PartialEq)] -pub enum Unit { - Radians, - Degrees, -} - /// Error that occured during parsing or evaluation. #[derive(Debug, Clone, PartialEq)] pub enum CalcError { @@ -61,6 +71,7 @@ pub enum CalcError { UnexpectedToken(TokenKind), UndefinedFn(String), UndefinedVar(String), + UnableToInvert(String), Unknown, } @@ -71,7 +82,7 @@ pub fn eval( context: &mut Context, input: &str, precision: u32, -) -> Result, CalcError> { +) -> Result, CalcError> { let statements = parse(context, input)?; let mut interpreter = @@ -85,6 +96,8 @@ pub fn eval( pub fn parse(context: &mut Context, input: &str) -> Result, CalcError> { context.tokens = Lexer::lex(input); context.pos = 0; + context.parsing_unit_decl = false; + context.unit_decl_base_unit = None; let mut statements: Vec = Vec::new(); while !is_at_end(context) { @@ -105,6 +118,8 @@ fn parse_stmt(context: &mut Context) -> Result { TokenKind::OpenParenthesis => parse_identifier_stmt(context)?, _ => Stmt::Expr(Box::new(parse_expr(context)?)), }); + } else if match_token(context, TokenKind::UnitKeyword) { + return parse_unit_decl_stmt(context); } Ok(Stmt::Expr(Box::new(parse_expr(context)?))) @@ -136,9 +151,7 @@ fn parse_identifier_stmt(context: &mut Context) -> Result { // Insert the function declaration into the symbol table during parsing // so that the parser can find out if particular functions exist. - context - .symbol_table - .insert(&format!("{}()", identifier), fn_decl.clone()); + context.symbol_table.insert(fn_decl.clone()); return Ok(fn_decl); } @@ -160,8 +173,54 @@ fn parse_var_decl_stmt(context: &mut Context) -> Result { Ok(Stmt::VarDecl(identifier.value, Box::new(expr))) } +fn parse_unit_decl_stmt(context: &mut Context) -> Result { + advance(context); // Unit keyword + let identifier = advance(context).clone(); + consume(context, TokenKind::Equals)?; + + // Parse the mut definition + context.unit_decl_base_unit = None; + context.parsing_unit_decl = true; + let def = parse_expr(context)?; + context.parsing_unit_decl = false; + + let base_unit = if let Some(base_unit) = &context.unit_decl_base_unit { + base_unit.clone() + } else { + return Err(CalcError::InvalidUnit); + }; + + // Automatically create a second unit decl with the expression inverted. + // This will turn eg. unit a = 3b, into unit b = a/3 + // This is so that you only have to define `a`, and it will figure out the formula for `b` since it is used in the formula for `a`. + let stmt_inv = Stmt::UnitDecl( + base_unit.clone(), + identifier.value.clone(), + Box::new(def.invert(&mut context.symbol_table)?), + ); + let stmt = Stmt::UnitDecl(identifier.value, base_unit, Box::new(def)); + + context.symbol_table.insert(stmt.clone()); + context.symbol_table.insert(stmt_inv); + + Ok(stmt) +} + fn parse_expr(context: &mut Context) -> Result { - Ok(parse_sum(context)?) + Ok(parse_to(context)?) +} + +fn parse_to(context: &mut Context) -> Result { + let left = parse_sum(context)?; + + if match_token(context, TokenKind::ToKeyword) { + let op = advance(context).kind.clone(); + let right = Expr::Var(advance(context).value.clone()); // Parse this as a variable for now. + + return Ok(Expr::Binary(Box::new(left), op, Box::new(right))); + } + + Ok(left) } fn parse_sum(context: &mut Context) -> Result { @@ -179,7 +238,7 @@ fn parse_sum(context: &mut Context) -> Result { } fn parse_factor(context: &mut Context) -> Result { - let mut left = parse_unary(context)?; + let mut left = parse_unit(context)?; while match_token(context, TokenKind::Star) || match_token(context, TokenKind::Slash) @@ -192,13 +251,27 @@ fn parse_factor(context: &mut Context) -> Result { _ => advance(context).kind.clone(), }; - let right = parse_unary(context)?; + let right = parse_unit(context)?; left = Expr::Binary(Box::new(left), op, Box::new(right)); } Ok(left) } +fn parse_unit(context: &mut Context) -> Result { + let expr = parse_unary(context)?; + let peek = &peek(&context).value; + + if match_token(context, TokenKind::Identifier) && context.symbol_table.contains_unit(&peek) { + return Ok(Expr::Unit( + advance(context).value.to_string(), + Box::new(expr), + )); + } + + Ok(expr) +} + fn parse_unary(context: &mut Context) -> Result { if match_token(context, TokenKind::Minus) { let op = advance(context).kind.clone(); @@ -240,11 +313,7 @@ fn parse_primary(context: &mut Context) -> Result { _ => Expr::Literal(advance(context).value.clone()), }; - if !is_at_end(context) && peek(context).kind.is_unit() { - Ok(Expr::Unit(Box::new(expr), advance(context).kind.clone())) - } else { - Ok(expr) - } + Ok(expr) } fn parse_group(context: &mut Context) -> Result { @@ -301,6 +370,9 @@ fn parse_identifier(context: &mut Context) -> Result { // Eg. x if context.symbol_table.contains_var(&identifier.value) { Ok(Expr::Var(identifier.value)) + } else if context.parsing_unit_decl { + context.unit_decl_base_unit = Some(identifier.value); + Ok(Expr::Var(DECL_UNIT.into())) } else { let mut chars = identifier.value.chars(); let mut left = Expr::Var(chars.next().unwrap().to_string()); @@ -319,19 +391,19 @@ fn parse_identifier(context: &mut Context) -> Result { } } -fn peek(context: &mut Context) -> &Token { +fn peek(context: &Context) -> &Token { &context.tokens[context.pos] } -fn peek_next(context: &mut Context) -> &Token { +fn peek_next(context: &Context) -> &Token { &context.tokens[context.pos + 1] } -fn previous(context: &mut Context) -> &Token { +fn previous(context: &Context) -> &Token { &context.tokens[context.pos - 1] } -fn match_token(context: &mut Context, kind: TokenKind) -> bool { +fn match_token(context: &Context, kind: TokenKind) -> bool { if is_at_end(context) { return false; } @@ -352,7 +424,7 @@ fn consume(context: &mut Context, kind: TokenKind) -> Result<&Token, CalcError> Err(CalcError::UnexpectedToken(kind)) } -fn is_at_end(context: &mut Context) -> bool { +fn is_at_end(context: &Context) -> bool { context.pos >= context.tokens.len() || peek(context).kind == TokenKind::EOF } @@ -361,10 +433,10 @@ mod tests { use super::*; use crate::lexer::{Token, TokenKind::*}; use crate::test_helpers::*; - use test_case::test_case; fn parse_with_context(context: &mut Context, tokens: Vec) -> Result { context.tokens = tokens; + context.pos = 0; parse_stmt(context) } @@ -372,6 +444,7 @@ mod tests { fn parse(tokens: Vec) -> Result { let mut context = Context::new(); context.tokens = tokens; + context.pos = 0; parse_stmt(&mut context) } @@ -399,6 +472,7 @@ mod tests { token(Slash, ""), token(Literal, "5"), token(ClosedParenthesis, ""), + token(EOF, ""), ]; assert_eq!( @@ -431,6 +505,7 @@ mod tests { token(Literal, "4"), token(Plus, ""), token(Literal, "5"), + token(EOF, ""), ]; assert_eq!( @@ -451,18 +526,18 @@ mod tests { ); } - #[test_case(Deg)] - #[test_case(Rad)] - fn test_unary(angle_unit: TokenKind) { - let tokens = vec![ - token(Minus, ""), - token(Literal, "1"), - token(angle_unit.clone(), ""), - ]; + #[test] + fn test_unit() { + let tokens = vec![token(Literal, "1"), token(Identifier, "a")]; + + let mut context = Context::new(); + context + .symbol_table + .insert(unit_decl("a", "b", var(super::DECL_UNIT))); assert_eq!( - parse(tokens).unwrap(), - Stmt::Expr(unary(Minus, Box::new(Expr::Unit(literal("1"), angle_unit)))) + parse_with_context(&mut context, tokens).unwrap(), + Stmt::Expr(unit("a", literal("1"))) ); } @@ -474,6 +549,7 @@ mod tests { token(Literal, "1"), token(Plus, ""), token(Literal, "2"), + token(EOF, ""), ]; assert_eq!( @@ -493,6 +569,7 @@ mod tests { token(Literal, "1"), token(Plus, ""), token(Literal, "2"), + token(EOF, ""), ]; assert_eq!( @@ -516,15 +593,17 @@ mod tests { token(ClosedParenthesis, ""), token(Plus, ""), token(Literal, "3"), + token(EOF, ""), ]; let mut context = Context::new(); // Add the function to the symbol table first, in order to prevent errors. - context.symbol_table.set( - "f()", - Stmt::FnDecl(String::from("f"), vec![String::from("x")], literal("1")), - ); + context.symbol_table.set(Stmt::FnDecl( + String::from("f"), + vec![String::from("x")], + literal("1"), + )); assert_eq!( parse_with_context(&mut context, tokens).unwrap(), diff --git a/kalk/src/prelude.rs b/kalk/src/prelude.rs index 2b6d8a6..7b159b1 100644 --- a/kalk/src/prelude.rs +++ b/kalk/src/prelude.rs @@ -1,6 +1,10 @@ +use crate::ast::Expr; +use crate::interpreter; use rug::Float; use FuncType::*; +pub const INIT: &'static str = "unit deg = (rad*180)/pi"; + pub const CONSTANTS: phf::Map<&'static str, &'static str> = phf::phf_map! { "pi" => "3.14159265", "π" => "3.14159265", @@ -11,56 +15,55 @@ pub const CONSTANTS: phf::Map<&'static str, &'static str> = phf::phf_map! { "ϕ" => "1.61803398", }; -use crate::parser::Unit; use funcs::*; -pub const UNARY_FUNCS: phf::Map<&'static str, UnaryFuncInfo> = phf::phf_map! { - "cos" => UnaryFuncInfo(cos, Trig), - "cosec" => UnaryFuncInfo(cosec, Trig), - "cosech" => UnaryFuncInfo(cosech, Trig), - "cosh" => UnaryFuncInfo(cosh, Trig), - "cot" => UnaryFuncInfo(cot, Trig), - "coth" => UnaryFuncInfo(coth, Trig), - "sec" => UnaryFuncInfo(sec, Trig), - "sech" => UnaryFuncInfo(sech, Trig), - "sin" => UnaryFuncInfo(sin, Trig), - "sinh" => UnaryFuncInfo(sinh, Trig), - "tan" => UnaryFuncInfo(tan, Trig), - "tanh" => UnaryFuncInfo(tanh, Trig), +pub const UNARY_FUNCS: phf::Map<&'static str, (UnaryFuncInfo, &'static str)> = phf::phf_map! { + "cos" => (UnaryFuncInfo(cos, Trig), ""), + "cosec" => (UnaryFuncInfo(cosec, Trig), ""), + "cosech" => (UnaryFuncInfo(cosech, Trig), ""), + "cosh" => (UnaryFuncInfo(cosh, Trig), ""), + "cot" => (UnaryFuncInfo(cot, Trig), ""), + "coth" => (UnaryFuncInfo(coth, Trig), ""), + "sec" => (UnaryFuncInfo(sec, Trig), ""), + "sech" => (UnaryFuncInfo(sech, Trig), ""), + "sin" => (UnaryFuncInfo(sin, Trig), ""), + "sinh" => (UnaryFuncInfo(sinh, Trig), ""), + "tan" => (UnaryFuncInfo(tan, Trig), ""), + "tanh" => (UnaryFuncInfo(tanh, Trig), ""), - "acos" => UnaryFuncInfo(acos, InverseTrig), - "acosec" => UnaryFuncInfo(acosec, InverseTrig), - "acosech" => UnaryFuncInfo(acosech, InverseTrig), - "acosh" => UnaryFuncInfo(acosh, InverseTrig), - "acot" => UnaryFuncInfo(acot, InverseTrig), - "acoth" => UnaryFuncInfo(acoth, InverseTrig), - "asec" => UnaryFuncInfo(asec, InverseTrig), - "asech" => UnaryFuncInfo(asech, InverseTrig), - "asin" => UnaryFuncInfo(asin, InverseTrig), - "asinh" => UnaryFuncInfo(asinh, InverseTrig), - "atan" => UnaryFuncInfo(atan, InverseTrig), - "atanh" => UnaryFuncInfo(atanh, InverseTrig), + "acos" => (UnaryFuncInfo(acos, InverseTrig), "rad"), + "acosec" => (UnaryFuncInfo(acosec, InverseTrig), "rad"), + "acosech" => (UnaryFuncInfo(acosech, InverseTrig), "rad"), + "acosh" => (UnaryFuncInfo(acosh, InverseTrig), "rad"), + "acot" => (UnaryFuncInfo(acot, InverseTrig), "rad"), + "acoth" => (UnaryFuncInfo(acoth, InverseTrig), "rad"), + "asec" => (UnaryFuncInfo(asec, InverseTrig), "rad"), + "asech" => (UnaryFuncInfo(asech, InverseTrig), "rad"), + "asin" => (UnaryFuncInfo(asin, InverseTrig), "rad"), + "asinh" => (UnaryFuncInfo(asinh, InverseTrig), "rad"), + "atan" => (UnaryFuncInfo(atan, InverseTrig), "rad"), + "atanh" => (UnaryFuncInfo(atanh, InverseTrig), "rad"), - "abs" => UnaryFuncInfo(abs, Other), - "cbrt" => UnaryFuncInfo(cbrt, Other), - "ceil" => UnaryFuncInfo(ceil, Other), - "exp" => UnaryFuncInfo(exp, Other), - "floor" => UnaryFuncInfo(floor, Other), - "frac" => UnaryFuncInfo(frac, Other), - "gamma" => UnaryFuncInfo(gamma, Other), - "Γ" => UnaryFuncInfo(gamma, Other), - "log" => UnaryFuncInfo(log, Other), - "ln" => UnaryFuncInfo(ln, Other), - "round" => UnaryFuncInfo(round, Other), - "sqrt" => UnaryFuncInfo(sqrt, Other), - "√" => UnaryFuncInfo(sqrt, Other), - "trunc" => UnaryFuncInfo(trunc, Other), + "abs" => (UnaryFuncInfo(abs, Other), ""), + "cbrt" => (UnaryFuncInfo(cbrt, Other), ""), + "ceil" => (UnaryFuncInfo(ceil, Other), ""), + "exp" => (UnaryFuncInfo(exp, Other), ""), + "floor" => (UnaryFuncInfo(floor, Other), ""), + "frac" => (UnaryFuncInfo(frac, Other), ""), + "gamma" => (UnaryFuncInfo(gamma, Other), ""), + "Γ" => (UnaryFuncInfo(gamma, Other), ""), + "log" => (UnaryFuncInfo(log, Other), ""), + "ln" => (UnaryFuncInfo(ln, Other), ""), + "round" => (UnaryFuncInfo(round, Other), ""), + "sqrt" => (UnaryFuncInfo(sqrt, Other), ""), + "√" => (UnaryFuncInfo(sqrt, Other), ""), + "trunc" => (UnaryFuncInfo(trunc, Other), ""), }; -pub const BINARY_FUNCS: phf::Map<&'static str, BinaryFuncInfo> = phf::phf_map! { - "max" => BinaryFuncInfo(max, Other), - "min" => BinaryFuncInfo(min, Other), - "hyp" => BinaryFuncInfo(hyp, Other), - "log" => BinaryFuncInfo(logx, Other), - "sqrt" => BinaryFuncInfo(nth_sqrt, Other), +pub const BINARY_FUNCS: phf::Map<&'static str, (BinaryFuncInfo, &'static str)> = phf::phf_map! { + "max" => (BinaryFuncInfo(max, Other), ""), + "min" => (BinaryFuncInfo(min, Other), ""), + "hyp" => (BinaryFuncInfo(hyp, Other), ""), + "log" => (BinaryFuncInfo(logx, Other), ""), + "root" => (BinaryFuncInfo(nth_root, Other), ""), }; enum FuncType { @@ -75,57 +78,88 @@ pub struct UnaryFuncInfo(fn(Float) -> Float, FuncType); pub struct BinaryFuncInfo(fn(Float, Float) -> Float, FuncType); impl UnaryFuncInfo { - fn call(&self, x: Float, angle_unit: &Unit) -> Float { + fn call(&self, context: &mut interpreter::Context, x: Float, angle_unit: &str) -> Float { let func = self.0; match self.1 { - FuncType::Trig => func(from_angle_unit(x, angle_unit)), - FuncType::InverseTrig => to_angle_unit(func(x), angle_unit), + FuncType::Trig => func(from_angle_unit(context, x, angle_unit)), + FuncType::InverseTrig => to_angle_unit(context, func(x), angle_unit), FuncType::Other => func(x), } } } impl BinaryFuncInfo { - fn call(&self, x: Float, y: Float, angle_unit: &Unit) -> Float { + fn call( + &self, + context: &mut interpreter::Context, + x: Float, + y: Float, + angle_unit: &str, + ) -> Float { let func = self.0; match self.1 { FuncType::Trig => func( - from_angle_unit(x, angle_unit), - from_angle_unit(y, angle_unit), + from_angle_unit(context, x, angle_unit), + from_angle_unit(context, y, angle_unit), ), - FuncType::InverseTrig => to_angle_unit(func(x, y), angle_unit), + FuncType::InverseTrig => to_angle_unit(context, func(x, y), angle_unit), FuncType::Other => func(x, y), } } } -pub fn call_unary_func(name: &str, x: Float, angle_unit: &Unit) -> Option { - if let Some(func_info) = UNARY_FUNCS.get(name) { - Some(func_info.call(x, &angle_unit)) +pub fn call_unary_func( + context: &mut interpreter::Context, + name: &str, + x: Float, + angle_unit: &str, +) -> Option<(Float, String)> { + if let Some((func_info, func_unit)) = UNARY_FUNCS.get(name) { + Some(( + func_info.call(context, x, &angle_unit), + func_unit.to_string(), + )) } else { None } } -pub fn call_binary_func(name: &str, x: Float, y: Float, angle_unit: &Unit) -> Option { - if let Some(func_info) = BINARY_FUNCS.get(name) { - Some(func_info.call(x, y, angle_unit)) +pub fn call_binary_func( + context: &mut interpreter::Context, + name: &str, + x: Float, + y: Float, + angle_unit: &str, +) -> Option<(Float, String)> { + if let Some((func_info, func_unit)) = BINARY_FUNCS.get(name) { + Some(( + func_info.call(context, x, y, angle_unit), + func_unit.to_string(), + )) } else { None } } -fn to_angle_unit(x: Float, angle_unit: &Unit) -> Float { +fn to_angle_unit(context: &mut interpreter::Context, x: Float, angle_unit: &str) -> Float { match angle_unit { - Unit::Radians => x, - Unit::Degrees => special_funcs::to_degrees(x), + "rad" => x, + _ => { + interpreter::convert_unit(context, &Expr::Literal(x.to_string()), "rad", angle_unit) + .unwrap() + .0 + } } } -fn from_angle_unit(x: Float, angle_unit: &Unit) -> Float { +fn from_angle_unit(context: &mut interpreter::Context, x: Float, angle_unit: &str) -> Float { match angle_unit { - Unit::Radians => x, - Unit::Degrees => special_funcs::to_radians(x), + "rad" => x, + _ => { + interpreter::convert_unit(context, &Expr::Literal(x.to_string()), angle_unit, "rad") + .unwrap() + .0 + } } } @@ -135,14 +169,6 @@ pub mod special_funcs { pub fn factorial(x: Float) -> Float { ((x + 1) as Float).gamma() } - - pub fn to_degrees(x: Float) -> Float { - Float::with_val(53, x.to_f64().to_degrees()) - } - - pub fn to_radians(x: Float) -> Float { - Float::with_val(53, x.to_f64().to_radians()) - } } mod funcs { @@ -297,7 +323,7 @@ mod funcs { x.sqrt() } - pub fn nth_sqrt(x: Float, n: Float) -> Float { + pub fn nth_root(x: Float, n: Float) -> Float { x.pow(Float::with_val(1, 1) / n) } diff --git a/kalk/src/symbol_table.rs b/kalk/src/symbol_table.rs index a8df17e..aead965 100644 --- a/kalk/src/symbol_table.rs +++ b/kalk/src/symbol_table.rs @@ -1,40 +1,81 @@ use crate::{ast::Stmt, prelude}; use std::collections::HashMap; +#[derive(Debug)] pub struct SymbolTable { - hashmap: HashMap, + pub(crate) hashmap: HashMap, + pub(crate) unit_types: HashMap, } impl SymbolTable { pub fn new() -> Self { SymbolTable { hashmap: HashMap::new(), + unit_types: HashMap::new(), } } - pub fn insert(&mut self, key: &str, value: Stmt) { - self.hashmap.insert(key.into(), value); + pub fn insert(&mut self, value: Stmt) -> &mut Self { + match &value { + Stmt::VarDecl(identifier, _) => { + self.hashmap.insert(format!("var.{}", identifier), value); + } + Stmt::UnitDecl(identifier, to_unit, _) => { + self.unit_types.insert(identifier.to_string(), ()); + self.unit_types.insert(to_unit.to_string(), ()); + self.hashmap + .insert(format!("unit.{}.{}", identifier, to_unit), value); + } + Stmt::FnDecl(identifier, _, _) => { + self.hashmap.insert(format!("fn.{}", identifier), value); + } + _ => panic!("Can only insert VarDecl, UnitDecl and FnDecl into symbol table."), + } + + self } - pub fn get(&self, key: &str) -> Option<&Stmt> { - self.hashmap.get(key) + pub fn get_var(&self, key: &str) -> Option<&Stmt> { + self.hashmap.get(&format!("var.{}", key)) } - pub fn set(&mut self, key: &str, value: Stmt) { - if let Some(stmt) = self.hashmap.get_mut(key) { + pub fn get_unit(&self, key: &str, to_unit: &str) -> Option<&Stmt> { + self.hashmap.get(&format!("unit.{}.{}", key, to_unit)) + } + + pub fn get_fn(&self, key: &str) -> Option<&Stmt> { + self.hashmap.get(&format!("fn.{}", key)) + } + + pub fn set(&mut self, value: Stmt) { + let existing_item = match &value { + Stmt::VarDecl(identifier, _) => self.hashmap.get_mut(&format!("var.{}", identifier)), + Stmt::UnitDecl(identifier, to_unit, _) => self + .hashmap + .get_mut(&format!("unit.{}.{}", identifier, to_unit)), + Stmt::FnDecl(identifier, _, _) => self.hashmap.get_mut(&format!("fn.{}", identifier)), + _ => panic!("Can only set VarDecl, UnitDecl and FnDecl in symbol table."), + }; + + if let Some(stmt) = existing_item { *stmt = value; } else { - self.insert(key, value); + self.insert(value); } } pub fn contains_var(&self, identifier: &str) -> bool { - prelude::CONSTANTS.contains_key(identifier) || self.hashmap.contains_key(identifier) + prelude::CONSTANTS.contains_key(identifier) + || self.hashmap.contains_key(&format!("var.{}", identifier)) + } + + pub fn contains_unit(&self, identifier: &str) -> bool { + self.unit_types.contains_key(identifier) } pub fn contains_fn(&self, identifier: &str) -> bool { prelude::UNARY_FUNCS.contains_key(identifier) || prelude::UNARY_FUNCS.contains_key(identifier) - || self.hashmap.contains_key(&format!("{}()", identifier)) + || self.hashmap.contains_key(&format!("fn.{}", identifier)) } } diff --git a/kalk/src/test_helpers.rs b/kalk/src/test_helpers.rs index 89d8d9f..5aa202d 100644 --- a/kalk/src/test_helpers.rs +++ b/kalk/src/test_helpers.rs @@ -36,6 +36,10 @@ pub fn group(expr: Box) -> Box { Box::new(Expr::Group(expr)) } +pub fn unit(identifier: &str, expr: Box) -> Box { + Box::new(Expr::Unit(identifier.into(), expr)) +} + pub fn var_decl(identifier: &str, value: Box) -> Stmt { Stmt::VarDecl(identifier.into(), value) } @@ -43,3 +47,7 @@ pub fn var_decl(identifier: &str, value: Box) -> Stmt { pub fn fn_decl(identifier: &str, parameters: Vec, value: Box) -> Stmt { Stmt::FnDecl(identifier.into(), parameters, value) } + +pub fn unit_decl(unit: &str, base_unit: &str, expr: Box) -> Stmt { + Stmt::UnitDecl(unit.into(), base_unit.into(), expr) +} diff --git a/kalk_cli/src/main.rs b/kalk_cli/src/main.rs index 5440718..0721602 100644 --- a/kalk_cli/src/main.rs +++ b/kalk_cli/src/main.rs @@ -2,13 +2,12 @@ mod output; mod repl; use kalk::parser; -use kalk::parser::Unit; use std::env; use std::fs::File; use std::io::Read; fn main() { - let mut parser_context = parser::Context::new().set_angle_unit(get_angle_unit()); + let mut parser_context = parser::Context::new().set_angle_unit(&get_angle_unit()); // Command line argument input, execute it and exit. let mut args = env::args().skip(1); @@ -26,10 +25,14 @@ fn main() { // The indentation... Will have to do something more scalable in the future. println!( " --= kalk help =-\n +[kalk help] + kalk [OPTIONS] [INPUT] -h, --help : show this -i : load a file with predefined functions/variables + +[Environment variables] +ANGLE_UNIT=(deg/rad) : Sets the default unit used for trigonometric functions. " ); return; @@ -64,16 +67,10 @@ kalk [OPTIONS] [INPUT] } } -fn get_angle_unit() -> Unit { +fn get_angle_unit() -> String { if let Ok(angle_unit_var) = env::var("ANGLE_UNIT") { - match angle_unit_var.as_ref() { - "radians" => Unit::Radians, - "degrees" => Unit::Degrees, - _ => { - panic!("Unexpected angle unit: {}.", angle_unit_var); - } - } + angle_unit_var } else { - Unit::Radians + String::from("rad") } } diff --git a/kalk_cli/src/output.rs b/kalk_cli/src/output.rs index ba7260a..fa5d9eb 100644 --- a/kalk_cli/src/output.rs +++ b/kalk_cli/src/output.rs @@ -3,7 +3,7 @@ use kalk::parser::{self, CalcError, CalcError::*}; pub fn eval(parser: &mut parser::Context, input: &str) { match parser::eval(parser, input, 53) { - Ok(Some(result)) => { + Ok(Some((result, unit))) => { let (_, digits, exp_option) = result.to_sign_string_exp(10, None); let exp = if let Some(exp) = exp_option { exp } else { 0 }; @@ -36,9 +36,9 @@ pub fn eval(parser: &mut parser::Context, input: &str) { }; if use_sci_notation { - println!("{}{}*10^{}", sign, num, exp - 1); + println!("{}{}*10^{} {}", sign, num, exp - 1, unit); } else { - println!("{}{}", sign, num); + println!("{}{} {}", sign, num, unit); } } } @@ -94,6 +94,7 @@ fn print_calc_err(err: CalcError) { InvalidOperator => format!("Invalid operator."), InvalidUnit => format!("Invalid unit."), UnexpectedToken(kind) => format!("Unexpected token: '{:?}'.", kind), + UnableToInvert(msg) => format!("Unable to invert: {}", msg), UndefinedFn(name) => format!("Undefined function: '{}'.", name), UndefinedVar(name) => format!("Undefined variable: '{}'.", name), Unknown => format!("Unknown error."),