Fixed comparison expressions not working in functions

This commit is contained in:
bakk 2022-04-24 23:57:46 +02:00
parent 4d6ef5e8d7
commit 23edc83577
5 changed files with 55 additions and 32 deletions

View File

@ -280,9 +280,10 @@ fn analyse_binary(
let left = analyse_expr(context, left)?; let left = analyse_expr(context, left)?;
let right = analyse_expr(context, right)?; let right = analyse_expr(context, right)?;
// If it has already been set to false manually somewhere else, // If it has already been set to false manually somewhere else
// or if there is no equation variable,
// abort and analyse as a comparison instead. // abort and analyse as a comparison instead.
if !context.in_equation { if !context.in_equation || context.equation_variable.is_none() {
context.in_conditional = true; context.in_conditional = true;
let result = analyse_binary(context, left, op, right); let result = analyse_binary(context, left, op, right);
context.in_conditional = previous_in_conditional; context.in_conditional = previous_in_conditional;
@ -295,11 +296,7 @@ fn analyse_binary(
let var_name = if let Some(var_name) = &context.equation_variable { let var_name = if let Some(var_name) = &context.equation_variable {
var_name var_name
} else { } else {
context.in_conditional = true; unreachable!()
let result = analyse_binary(context, left, op, right);
context.in_conditional = previous_in_conditional;
return result;
}; };
let identifier = Identifier::from_full_name(var_name); let identifier = Identifier::from_full_name(var_name);
context.equation_variable = None; context.equation_variable = None;
@ -513,6 +510,19 @@ fn analyse_var(
} }
if context.in_equation { if context.in_equation {
let is_parameter = if let (Some(fn_name), Some(parameters)) = (
&context.current_function_name,
&context.current_function_parameters,
) {
parameters.contains(&identifier.full_name)
|| parameters.contains(
&Identifier::parameter_from_name(&identifier.full_name, fn_name).full_name,
)
} else {
false
};
if !is_parameter {
context.equation_variable = Some(identifier.full_name.clone()); context.equation_variable = Some(identifier.full_name.clone());
return with_adjacent( return with_adjacent(
build_var(context, &identifier.full_name), build_var(context, &identifier.full_name),
@ -520,6 +530,7 @@ fn analyse_var(
adjacent_exponent, adjacent_exponent,
); );
} }
}
let mut identifier_without_dx: Vec<char> = identifier.full_name.chars().collect(); let mut identifier_without_dx: Vec<char> = identifier.full_name.chars().collect();
let last_char = identifier_without_dx.pop().unwrap_or_default(); let last_char = identifier_without_dx.pop().unwrap_or_default();
@ -651,7 +662,7 @@ fn build_split_up_vars(
left = Expr::Binary(Box::new(left), TokenKind::Star, Box::new(right)) left = Expr::Binary(Box::new(left), TokenKind::Star, Box::new(right))
} }
Ok(left) with_adjacent(left, adjacent_factor, adjacent_exponent)
} }
fn build_var(context: &mut Context, name: &str) -> Expr { fn build_var(context: &mut Context, name: &str) -> Expr {

View File

@ -33,6 +33,7 @@ pub enum KalkError {
UnableToParseExpression, UnableToParseExpression,
UnrecognizedBase, UnrecognizedBase,
Unknown, Unknown,
WasStmt(crate::ast::Stmt),
} }
impl ToString for KalkError { impl ToString for KalkError {
@ -77,7 +78,7 @@ impl ToString for KalkError {
KalkError::UnableToSolveEquation => String::from("Unable to solve equation."), KalkError::UnableToSolveEquation => String::from("Unable to solve equation."),
KalkError::UnableToOverrideConstant(name) => format!("Unable to override constant: '{}'.", name), KalkError::UnableToOverrideConstant(name) => format!("Unable to override constant: '{}'.", name),
KalkError::UnrecognizedBase => String::from("Unrecognized base."), KalkError::UnrecognizedBase => String::from("Unrecognized base."),
KalkError::Unknown => String::from("Unknown error."), KalkError::Unknown | KalkError::WasStmt(_) => String::from("Unknown error."),
} }
} }
} }

View File

@ -41,6 +41,7 @@ mod tests {
} }
} }
#[test_case("ambiguities/comparison_in_function")]
#[test_case("basics")] #[test_case("basics")]
#[test_case("comparisons")] #[test_case("comparisons")]
#[test_case("comprehensions")] #[test_case("comprehensions")]

View File

@ -127,7 +127,11 @@ pub fn parse(context: &mut Context, input: &str) -> Result<Vec<Stmt>, KalkError>
let mut statements: Vec<Stmt> = Vec::new(); let mut statements: Vec<Stmt> = Vec::new();
while !is_at_end(context) { while !is_at_end(context) {
let parsed = parse_stmt(context)?; let parsed = match parse_stmt(context) {
Ok(stmt) => stmt,
Err(KalkError::WasStmt(stmt)) => stmt,
Err(err) => return Err(err),
};
let symbol_table = context.symbol_table.get_mut(); let symbol_table = context.symbol_table.get_mut();
let analysed = analysis::analyse_stmt(symbol_table, parsed)?; let analysed = analysis::analyse_stmt(symbol_table, parsed)?;
statements.push(analysed); statements.push(analysed);
@ -300,26 +304,26 @@ fn parse_comparison(context: &mut Context) -> Result<Expr, KalkError> {
|| match_token(context, TokenKind::GreaterOrEquals) || match_token(context, TokenKind::GreaterOrEquals)
|| match_token(context, TokenKind::LessOrEquals) || match_token(context, TokenKind::LessOrEquals)
{ {
let op = peek(context).kind; let op = advance(context).kind;
advance(context);
let is_fn_decl = if let Some((identifier, parameters)) = analysis::is_fn_decl(&left) { if let Some((identifier, parameters)) = analysis::is_fn_decl(&left) {
context.symbol_table.get_mut().set(Stmt::FnDecl( context.symbol_table.get_mut().set(Stmt::FnDecl(
identifier, identifier.clone(),
parameters, parameters.clone(),
Box::new(Expr::Literal(0f64)), Box::new(Expr::Literal(1f64)),
)); ));
let right = if match_token(context, TokenKind::OpenBrace) {
true
} else {
false
};
let right = if op == TokenKind::Equals && match_token(context, TokenKind::OpenBrace) {
parse_piecewise(context)? parse_piecewise(context)?
} else { } else {
parse_comparison(context)? parse_expr(context)?
}; };
let fn_decl = Stmt::FnDecl(identifier, parameters, Box::new(right));
// Hack to return a statement...
return Err(KalkError::WasStmt(fn_decl));
};
let right = parse_comparison(context)?;
left = match right { left = match right {
Expr::Binary( Expr::Binary(
@ -331,7 +335,7 @@ fn parse_comparison(context: &mut Context) -> Result<Expr, KalkError> {
| TokenKind::GreaterOrEquals | TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals), | TokenKind::LessOrEquals),
inner_right, inner_right,
) if !is_fn_decl => Expr::Binary( ) => Expr::Binary(
Box::new(Expr::Binary( Box::new(Expr::Binary(
Box::new(left), Box::new(left),
op, op,
@ -730,7 +734,11 @@ mod tests {
context.tokens = tokens; context.tokens = tokens;
context.pos = 0; context.pos = 0;
let parsed = parse_stmt(&mut context)?; let parsed = match parse_stmt(&mut context) {
Ok(stmt) => stmt,
Err(KalkError::WasStmt(stmt)) => stmt,
Err(err) => return Err(err),
};
let symbol_table = context.symbol_table.get_mut(); let symbol_table = context.symbol_table.get_mut();
analysis::analyse_stmt(symbol_table, parsed) analysis::analyse_stmt(symbol_table, parsed)
} }

View File

@ -0,0 +1,2 @@
f(a, b, c) = (a * b = c)
f(2, 2, 4)