Vector comprehensions

This commit is contained in:
PaddiM8 2022-01-16 00:33:26 +01:00
parent 20f61efa7f
commit 67ef28bd7f
8 changed files with 271 additions and 11 deletions

View File

@ -1,5 +1,5 @@
use crate::{
ast::{ConditionalPiece, Expr, Identifier, Stmt},
ast::{ConditionalPiece, Expr, Identifier, RangedVar, Stmt},
inverter,
lexer::TokenKind,
parser::{self, CalcError},
@ -13,9 +13,12 @@ pub(crate) struct Context<'a> {
current_function_parameters: Option<Vec<String>>,
equation_variable: Option<String>,
in_integral: bool,
in_sum_prod: bool,
in_unit_decl: bool,
in_conditional: bool,
in_equation: bool,
in_comprehension: bool,
comprehension_vars: Option<Vec<RangedVar>>,
}
pub(crate) fn analyse_stmt(
@ -28,9 +31,12 @@ pub(crate) fn analyse_stmt(
current_function_parameters: None,
equation_variable: None,
in_integral: false,
in_sum_prod: false,
in_unit_decl: false,
in_conditional: false,
in_equation: false,
in_comprehension: false,
comprehension_vars: None,
};
Ok(match statement {
@ -256,6 +262,7 @@ fn analyse_expr(context: &mut Context, expr: Expr) -> Result<Expr, CalcError> {
Expr::Indexer(Box::new(analyse_expr(context, *value)?), analysed_indexes)
}
Expr::Comprehension(left, right, vars) => Expr::Comprehension(left, right, vars),
})
}
@ -270,8 +277,8 @@ fn analyse_binary<'a>(
context.in_conditional = true;
}
let result = match (&left, &op) {
(_, TokenKind::Equals) if !context.in_conditional => {
let result = match (&left, &op, &right) {
(_, TokenKind::Equals, _) if !context.in_conditional => {
// Equation
context.in_equation = true;
let left = analyse_expr(context, left)?;
@ -319,20 +326,78 @@ fn analyse_binary<'a>(
return Ok(inverted);
}
(Expr::Var(_), TokenKind::Star) => {
(Expr::Var(_), TokenKind::Star, _) => {
if let Expr::Var(identifier) = left {
analyse_var(context, identifier, Some(right), None)
} else {
unreachable!()
}
}
(Expr::Var(_), TokenKind::Power) => {
(Expr::Var(_), TokenKind::Power, _) => {
if let Expr::Var(identifier) = left {
analyse_var(context, identifier, None, Some(right))
} else {
unreachable!()
}
}
(_, TokenKind::Colon, _) => {
context.in_comprehension = true;
context.in_conditional = true;
context.comprehension_vars = Some(Vec::new());
let mut conditions = vec![right];
let mut has_comma = false;
while let Expr::Binary(_, TokenKind::Comma, _) = conditions.last().unwrap() {
has_comma = true;
if let Expr::Binary(left_condition, _, right_condition) = conditions.pop().unwrap()
{
conditions.push(analyse_expr(context, *left_condition.to_owned())?);
conditions.push(analyse_expr(context, *right_condition.to_owned())?);
}
}
if !has_comma {
let analysed_condition = analyse_expr(context, conditions.pop().unwrap())?;
conditions.push(analysed_condition);
}
context.in_comprehension = false;
context.in_conditional = false;
let left = analyse_expr(context, left)?;
let result = Expr::Comprehension(
Box::new(left),
conditions,
context.comprehension_vars.take().unwrap(),
);
Ok(result)
}
(
Expr::Var(_),
TokenKind::GreaterThan
| TokenKind::LessThan
| TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals,
_,
) => analyse_comparison_with_var(context, left, op, right),
(
_,
TokenKind::GreaterThan
| TokenKind::LessThan
| TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals,
Expr::Var(_),
) => {
let inv_op = match op {
TokenKind::GreaterThan => TokenKind::LessThan,
TokenKind::LessThan => TokenKind::GreaterThan,
TokenKind::GreaterOrEquals => TokenKind::LessOrEquals,
TokenKind::LessOrEquals => TokenKind::GreaterOrEquals,
_ => unreachable!(),
};
analyse_comparison_with_var(context, right, inv_op, left)
}
_ => Ok(Expr::Binary(
Box::new(analyse_expr(context, left)?),
op,
@ -345,6 +410,69 @@ fn analyse_binary<'a>(
result
}
fn analyse_comparison_with_var(
context: &mut Context,
var: Expr,
op: TokenKind,
right: Expr,
) -> Result<Expr, CalcError> {
let right = analyse_expr(context, right)?;
if context.comprehension_vars.is_none() {
return Ok(Expr::Binary(
Box::new(analyse_expr(context, var)?),
op,
Box::new(right),
));
}
// Make sure any comprehension variables
// are added to context.comprehension_variables.
let analysed_var = analyse_expr(context, var)?;
let var_name = if let Expr::Var(identifier) = &analysed_var {
&identifier.pure_name
} else {
unreachable!("Expected Expr::Var");
};
let vars = context.comprehension_vars.as_mut().unwrap();
for ranged_var in vars {
if &ranged_var.name == var_name {
match op {
TokenKind::GreaterThan => {
ranged_var.min = Expr::Binary(
Box::new(right.clone()),
TokenKind::Plus,
Box::new(Expr::Literal(1f64)),
);
}
TokenKind::LessThan => {
ranged_var.max = right.clone();
}
TokenKind::GreaterOrEquals => {
ranged_var.min = right.clone();
}
TokenKind::LessOrEquals => {
ranged_var.max = Expr::Binary(
Box::new(right.clone()),
TokenKind::Plus,
Box::new(Expr::Literal(1f64)),
);
}
_ => unreachable!(),
}
break;
}
}
Ok(Expr::Binary(
Box::new(Expr::Literal(0f64)),
TokenKind::Equals,
Box::new(Expr::Literal(0f64)),
))
}
fn analyse_var(
context: &mut Context,
identifier: Identifier,
@ -380,7 +508,15 @@ fn analyse_var(
None
};
if context.symbol_table.contains_var(&identifier.pure_name)
let is_comprehension_var = if let Some(vars) = &context.comprehension_vars {
vars.iter().any(|x| x.name == identifier.pure_name)
} else {
false
};
if is_comprehension_var {
with_adjacent(Expr::Var(identifier), adjacent_factor, adjacent_exponent)
} else if context.symbol_table.contains_var(&identifier.pure_name)
|| (identifier.pure_name.len() == 1 && !context.in_equation)
{
with_adjacent(
@ -473,8 +609,13 @@ fn build_fn_call(
context.in_integral = true;
}
let is_sum_prod = identifier.pure_name == "sum" || identifier.pure_name == "prod";
if is_sum_prod {
context.in_sum_prod = true;
}
// Don't perform equation solving on special functions
if is_integral || identifier.pure_name == "sum" || identifier.pure_name == "prod" {
if is_integral || is_sum_prod {
context.in_equation = false;
}
@ -508,6 +649,10 @@ fn build_fn_call(
context.in_integral = false;
}
if is_sum_prod {
context.in_sum_prod = false;
}
return Ok(Expr::FnCall(identifier, arguments));
}
@ -615,9 +760,24 @@ fn build_var(context: &mut Context, name: &str) -> Expr {
}
}
if context.in_equation && !context.symbol_table.contains_var(name) {
if context.in_sum_prod && name == "n" {
return Expr::Var(Identifier::from_full_name(name));
}
let var_exists = context.symbol_table.contains_var(name);
if context.in_equation && !var_exists {
context.equation_variable = Some(name.to_string());
}
if context.in_comprehension && !var_exists {
if let Some(vars) = context.comprehension_vars.as_mut() {
vars.push(RangedVar {
name: name.to_string(),
max: Expr::Literal(0f64),
min: Expr::Literal(0f64),
});
}
}
Expr::Var(Identifier::from_full_name(name))
}

View File

@ -24,6 +24,7 @@ pub enum Expr {
Vector(Vec<Expr>),
Matrix(Vec<Vec<Expr>>),
Indexer(Box<Expr>, Vec<Expr>),
Comprehension(Box<Expr>, Vec<Expr>, Vec<RangedVar>),
}
#[derive(Debug, Clone, PartialEq)]
@ -32,6 +33,13 @@ pub struct ConditionalPiece {
pub condition: Expr,
}
#[derive(Debug, Clone, PartialEq)]
pub struct RangedVar {
pub name: String,
pub max: Expr,
pub min: Expr,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Identifier {
pub full_name: String,

View File

@ -42,6 +42,7 @@ mod tests {
}
#[test_case("basics")]
#[test_case("comprehensions")]
#[test_case("derivation")]
#[test_case("functions")]
#[test_case("groups")]

View File

@ -1,5 +1,5 @@
use crate::ast::Identifier;
use crate::ast::{Expr, Stmt};
use crate::ast::{Identifier, RangedVar};
use crate::calculation_result::CalculationResult;
use crate::kalk_value::KalkValue;
use crate::lexer::TokenKind;
@ -127,6 +127,9 @@ pub(crate) fn eval_expr(
Expr::Vector(values) => eval_vector(context, values),
Expr::Matrix(rows) => eval_matrix(context, rows),
Expr::Indexer(var, indexes) => eval_indexer(context, var, indexes, unit),
Expr::Comprehension(left, conditions, vars) => Ok(KalkValue::Vector(eval_comprehension(
context, left, conditions, vars,
)?)),
}
}
@ -634,6 +637,53 @@ fn eval_indexer(
}
}
fn eval_comprehension(
context: &mut Context,
left: &Expr,
conditions: &[Expr],
vars: &[RangedVar],
) -> Result<Vec<KalkValue>, CalcError> {
if vars.len() != conditions.len() {
return Err(CalcError::InvalidComprehension(String::from("Expected a new variable to be introduced for every condition (conditions are comma separated).")));
}
let condition = conditions.first().unwrap();
let var = vars.first().unwrap();
context.symbol_table.insert(Stmt::VarDecl(
Identifier::from_full_name(&var.name),
Box::new(Expr::Literal(0f64)),
));
let min = eval_expr(context, &var.min, "")?.to_f64() as i32;
let max = eval_expr(context, &var.max, "")?.to_f64() as i32;
let mut values = Vec::new();
for i in min..max {
context.symbol_table.set(Stmt::VarDecl(
Identifier::from_full_name(&var.name),
Box::new(Expr::Literal(i as f64)),
));
if conditions.len() > 1 {
let x = eval_comprehension(context, left, &conditions[1..], &vars[1..])?;
for value in x {
values.push(value);
}
}
let condition = eval_expr(context, condition, "")?;
if let KalkValue::Boolean(boolean) = condition {
if boolean && vars.len() == 1 {
values.push(eval_expr(context, left, "")?);
}
}
}
context.symbol_table.get_and_remove_var(&var.name);
Ok(values)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -90,6 +90,9 @@ fn invert(
Expr::Vector(_) => Err(CalcError::UnableToInvert(String::from("Vector"))),
Expr::Matrix(_) => Err(CalcError::UnableToInvert(String::from("Matrix"))),
Expr::Indexer(_, _) => Err(CalcError::UnableToInvert(String::from("Inverter"))),
Expr::Comprehension(_, _, _) => {
Err(CalcError::UnableToInvert(String::from("Comprehension")))
}
}
}
@ -396,6 +399,7 @@ pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) ->
.iter()
.any(|row| row.iter().any(|x| contains_var(symbol_table, x, var_name))),
Expr::Indexer(_, _) => false,
Expr::Comprehension(_, _, _) => false,
}
}

View File

@ -45,6 +45,7 @@ pub enum TokenKind {
OpenBrace,
ClosedBrace,
Comma,
Colon,
Semicolon,
Newline,
@ -149,6 +150,7 @@ impl<'a> Lexer<'a> {
'∧' => build(TokenKind::And, "", span),
'' => build(TokenKind::Or, "", span),
',' => build(TokenKind::Comma, "", span),
':' => build(TokenKind::Colon, "", span),
';' => build(TokenKind::Semicolon, "", span),
'\n' => build(TokenKind::Newline, "", span),
'%' => build(TokenKind::Percent, "", span),
@ -388,7 +390,7 @@ fn is_valid_identifier(c: Option<&char>) -> bool {
match c {
'+' | '-' | '/' | '*' | '%' | '^' | '!' | '(' | ')' | '=' | '.' | ',' | ';' | '|'
| '⌊' | '⌋' | '⌈' | '⌉' | '[' | ']' | '{' | '}' | 'π' | '√' | 'τ' | 'ϕ' | 'Γ' | '<'
| '>' | '≠' | '≥' | '≤' | '×' | '÷' | '⋅' | '⟦' | '⟧' | '∧' | '' | '\n' => {
| '>' | '≠' | '≥' | '≤' | '×' | '÷' | '⋅' | '⟦' | '⟧' | '∧' | '' | ':' | '\n' => {
false
}
_ => !c.is_digit(10) || is_superscript(c) || is_subscript(c),

View File

@ -97,6 +97,7 @@ pub enum CalcError {
IncorrectAmountOfIndexes(usize, usize),
ItemOfIndexDoesNotExist(Vec<usize>),
InconsistentColumnWidths,
InvalidComprehension(String),
InvalidNumberLiteral(String),
InvalidOperator,
InvalidUnit,
@ -131,6 +132,7 @@ impl ToString for CalcError {
),
CalcError::ItemOfIndexDoesNotExist(indexes) => format!("Item of index ⟦{}⟧ does not exist.", indexes.iter().map(|x| x.to_string()).collect::<Vec<String>>().join(", ")),
CalcError::InconsistentColumnWidths => format!("Inconsistent column widths. Matrix columns must be the same size."),
CalcError::InvalidComprehension(x) => format!("Invalid comprehension: {}", x),
CalcError::InvalidNumberLiteral(x) => format!("Invalid number literal: '{}'.", x),
CalcError::InvalidOperator => format!("Invalid operator."),
CalcError::InvalidUnit => format!("Invalid unit."),
@ -310,6 +312,32 @@ fn parse_expr(context: &mut Context) -> Result<Expr, CalcError> {
Ok(parse_or(context)?)
}
fn parse_comprehension(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_or(context)?;
if match_token(context, TokenKind::Colon) {
let op = advance(context).kind;
skip_newlines(context);
let right = Box::new(parse_comprehension_comma(context)?);
return Ok(Expr::Binary(Box::new(left), op, right));
}
Ok(left)
}
fn parse_comprehension_comma(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_or(context)?;
if match_token(context, TokenKind::Comma) {
let op = advance(context).kind;
skip_newlines(context);
let right = Box::new(parse_comprehension_comma(context)?);
return Ok(Expr::Binary(Box::new(left), op, right));
}
Ok(left)
}
fn parse_or(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_and(context)?;
@ -545,7 +573,12 @@ fn parse_vector(context: &mut Context) -> Result<Expr, CalcError> {
skip_newlines(context);
}
let mut rows = vec![vec![parse_expr(context)?]];
let first_expr = if kind == TokenKind::OpenBracket {
parse_comprehension(context)?
} else {
parse_expr(context)?
};
let mut rows = vec![vec![first_expr]];
let mut column_count = None;
let mut items_in_row = 1;
while match_token(context, TokenKind::Comma)

View File

@ -0,0 +1,2 @@
[x : 0 ≤ x and 5 > x] = (0, 1, 2, 3, 4) and
[(x, y) : x > 0 and x <= 3, y > 0 and y <= 2] = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)]