Fixed comparison expressions not working in functions

This commit is contained in:
bakk 2022-04-24 23:57:46 +02:00
parent 4d6ef5e8d7
commit 23edc83577
5 changed files with 55 additions and 32 deletions

View File

@ -280,9 +280,10 @@ fn analyse_binary(
let left = analyse_expr(context, left)?;
let right = analyse_expr(context, right)?;
// If it has already been set to false manually somewhere else,
// If it has already been set to false manually somewhere else
// or if there is no equation variable,
// abort and analyse as a comparison instead.
if !context.in_equation {
if !context.in_equation || context.equation_variable.is_none() {
context.in_conditional = true;
let result = analyse_binary(context, left, op, right);
context.in_conditional = previous_in_conditional;
@ -295,11 +296,7 @@ fn analyse_binary(
let var_name = if let Some(var_name) = &context.equation_variable {
var_name
} else {
context.in_conditional = true;
let result = analyse_binary(context, left, op, right);
context.in_conditional = previous_in_conditional;
return result;
unreachable!()
};
let identifier = Identifier::from_full_name(var_name);
context.equation_variable = None;
@ -513,12 +510,26 @@ fn analyse_var(
}
if context.in_equation {
context.equation_variable = Some(identifier.full_name.clone());
return with_adjacent(
build_var(context, &identifier.full_name),
adjacent_factor,
adjacent_exponent,
);
let is_parameter = if let (Some(fn_name), Some(parameters)) = (
&context.current_function_name,
&context.current_function_parameters,
) {
parameters.contains(&identifier.full_name)
|| parameters.contains(
&Identifier::parameter_from_name(&identifier.full_name, fn_name).full_name,
)
} else {
false
};
if !is_parameter {
context.equation_variable = Some(identifier.full_name.clone());
return with_adjacent(
build_var(context, &identifier.full_name),
adjacent_factor,
adjacent_exponent,
);
}
}
let mut identifier_without_dx: Vec<char> = identifier.full_name.chars().collect();
@ -651,7 +662,7 @@ fn build_split_up_vars(
left = Expr::Binary(Box::new(left), TokenKind::Star, Box::new(right))
}
Ok(left)
with_adjacent(left, adjacent_factor, adjacent_exponent)
}
fn build_var(context: &mut Context, name: &str) -> Expr {

View File

@ -33,6 +33,7 @@ pub enum KalkError {
UnableToParseExpression,
UnrecognizedBase,
Unknown,
WasStmt(crate::ast::Stmt),
}
impl ToString for KalkError {
@ -77,7 +78,7 @@ impl ToString for KalkError {
KalkError::UnableToSolveEquation => String::from("Unable to solve equation."),
KalkError::UnableToOverrideConstant(name) => format!("Unable to override constant: '{}'.", name),
KalkError::UnrecognizedBase => String::from("Unrecognized base."),
KalkError::Unknown => String::from("Unknown error."),
KalkError::Unknown | KalkError::WasStmt(_) => String::from("Unknown error."),
}
}
}

View File

@ -41,6 +41,7 @@ mod tests {
}
}
#[test_case("ambiguities/comparison_in_function")]
#[test_case("basics")]
#[test_case("comparisons")]
#[test_case("comprehensions")]

View File

@ -127,7 +127,11 @@ pub fn parse(context: &mut Context, input: &str) -> Result<Vec<Stmt>, KalkError>
let mut statements: Vec<Stmt> = Vec::new();
while !is_at_end(context) {
let parsed = parse_stmt(context)?;
let parsed = match parse_stmt(context) {
Ok(stmt) => stmt,
Err(KalkError::WasStmt(stmt)) => stmt,
Err(err) => return Err(err),
};
let symbol_table = context.symbol_table.get_mut();
let analysed = analysis::analyse_stmt(symbol_table, parsed)?;
statements.push(analysed);
@ -300,26 +304,26 @@ fn parse_comparison(context: &mut Context) -> Result<Expr, KalkError> {
|| match_token(context, TokenKind::GreaterOrEquals)
|| match_token(context, TokenKind::LessOrEquals)
{
let op = peek(context).kind;
advance(context);
let op = advance(context).kind;
let is_fn_decl = if let Some((identifier, parameters)) = analysis::is_fn_decl(&left) {
if let Some((identifier, parameters)) = analysis::is_fn_decl(&left) {
context.symbol_table.get_mut().set(Stmt::FnDecl(
identifier,
parameters,
Box::new(Expr::Literal(0f64)),
identifier.clone(),
parameters.clone(),
Box::new(Expr::Literal(1f64)),
));
let right = if match_token(context, TokenKind::OpenBrace) {
parse_piecewise(context)?
} else {
parse_expr(context)?
};
let fn_decl = Stmt::FnDecl(identifier, parameters, Box::new(right));
true
} else {
false
// Hack to return a statement...
return Err(KalkError::WasStmt(fn_decl));
};
let right = if op == TokenKind::Equals && match_token(context, TokenKind::OpenBrace) {
parse_piecewise(context)?
} else {
parse_comparison(context)?
};
let right = parse_comparison(context)?;
left = match right {
Expr::Binary(
@ -331,7 +335,7 @@ fn parse_comparison(context: &mut Context) -> Result<Expr, KalkError> {
| TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals),
inner_right,
) if !is_fn_decl => Expr::Binary(
) => Expr::Binary(
Box::new(Expr::Binary(
Box::new(left),
op,
@ -730,7 +734,11 @@ mod tests {
context.tokens = tokens;
context.pos = 0;
let parsed = parse_stmt(&mut context)?;
let parsed = match parse_stmt(&mut context) {
Ok(stmt) => stmt,
Err(KalkError::WasStmt(stmt)) => stmt,
Err(err) => return Err(err),
};
let symbol_table = context.symbol_table.get_mut();
analysis::analyse_stmt(symbol_table, parsed)
}

View File

@ -0,0 +1,2 @@
f(a, b, c) = (a * b = c)
f(2, 2, 4)