Vector comprehensions

This commit is contained in:
PaddiM8 2022-01-16 00:33:26 +01:00
parent 20f61efa7f
commit 67ef28bd7f
8 changed files with 271 additions and 11 deletions

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
ast::{ConditionalPiece, Expr, Identifier, Stmt}, ast::{ConditionalPiece, Expr, Identifier, RangedVar, Stmt},
inverter, inverter,
lexer::TokenKind, lexer::TokenKind,
parser::{self, CalcError}, parser::{self, CalcError},
@ -13,9 +13,12 @@ pub(crate) struct Context<'a> {
current_function_parameters: Option<Vec<String>>, current_function_parameters: Option<Vec<String>>,
equation_variable: Option<String>, equation_variable: Option<String>,
in_integral: bool, in_integral: bool,
in_sum_prod: bool,
in_unit_decl: bool, in_unit_decl: bool,
in_conditional: bool, in_conditional: bool,
in_equation: bool, in_equation: bool,
in_comprehension: bool,
comprehension_vars: Option<Vec<RangedVar>>,
} }
pub(crate) fn analyse_stmt( pub(crate) fn analyse_stmt(
@ -28,9 +31,12 @@ pub(crate) fn analyse_stmt(
current_function_parameters: None, current_function_parameters: None,
equation_variable: None, equation_variable: None,
in_integral: false, in_integral: false,
in_sum_prod: false,
in_unit_decl: false, in_unit_decl: false,
in_conditional: false, in_conditional: false,
in_equation: false, in_equation: false,
in_comprehension: false,
comprehension_vars: None,
}; };
Ok(match statement { Ok(match statement {
@ -256,6 +262,7 @@ fn analyse_expr(context: &mut Context, expr: Expr) -> Result<Expr, CalcError> {
Expr::Indexer(Box::new(analyse_expr(context, *value)?), analysed_indexes) Expr::Indexer(Box::new(analyse_expr(context, *value)?), analysed_indexes)
} }
Expr::Comprehension(left, right, vars) => Expr::Comprehension(left, right, vars),
}) })
} }
@ -270,8 +277,8 @@ fn analyse_binary<'a>(
context.in_conditional = true; context.in_conditional = true;
} }
let result = match (&left, &op) { let result = match (&left, &op, &right) {
(_, TokenKind::Equals) if !context.in_conditional => { (_, TokenKind::Equals, _) if !context.in_conditional => {
// Equation // Equation
context.in_equation = true; context.in_equation = true;
let left = analyse_expr(context, left)?; let left = analyse_expr(context, left)?;
@ -319,20 +326,78 @@ fn analyse_binary<'a>(
return Ok(inverted); return Ok(inverted);
} }
(Expr::Var(_), TokenKind::Star) => { (Expr::Var(_), TokenKind::Star, _) => {
if let Expr::Var(identifier) = left { if let Expr::Var(identifier) = left {
analyse_var(context, identifier, Some(right), None) analyse_var(context, identifier, Some(right), None)
} else { } else {
unreachable!() unreachable!()
} }
} }
(Expr::Var(_), TokenKind::Power) => { (Expr::Var(_), TokenKind::Power, _) => {
if let Expr::Var(identifier) = left { if let Expr::Var(identifier) = left {
analyse_var(context, identifier, None, Some(right)) analyse_var(context, identifier, None, Some(right))
} else { } else {
unreachable!() unreachable!()
} }
} }
(_, TokenKind::Colon, _) => {
context.in_comprehension = true;
context.in_conditional = true;
context.comprehension_vars = Some(Vec::new());
let mut conditions = vec![right];
let mut has_comma = false;
while let Expr::Binary(_, TokenKind::Comma, _) = conditions.last().unwrap() {
has_comma = true;
if let Expr::Binary(left_condition, _, right_condition) = conditions.pop().unwrap()
{
conditions.push(analyse_expr(context, *left_condition.to_owned())?);
conditions.push(analyse_expr(context, *right_condition.to_owned())?);
}
}
if !has_comma {
let analysed_condition = analyse_expr(context, conditions.pop().unwrap())?;
conditions.push(analysed_condition);
}
context.in_comprehension = false;
context.in_conditional = false;
let left = analyse_expr(context, left)?;
let result = Expr::Comprehension(
Box::new(left),
conditions,
context.comprehension_vars.take().unwrap(),
);
Ok(result)
}
(
Expr::Var(_),
TokenKind::GreaterThan
| TokenKind::LessThan
| TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals,
_,
) => analyse_comparison_with_var(context, left, op, right),
(
_,
TokenKind::GreaterThan
| TokenKind::LessThan
| TokenKind::GreaterOrEquals
| TokenKind::LessOrEquals,
Expr::Var(_),
) => {
let inv_op = match op {
TokenKind::GreaterThan => TokenKind::LessThan,
TokenKind::LessThan => TokenKind::GreaterThan,
TokenKind::GreaterOrEquals => TokenKind::LessOrEquals,
TokenKind::LessOrEquals => TokenKind::GreaterOrEquals,
_ => unreachable!(),
};
analyse_comparison_with_var(context, right, inv_op, left)
}
_ => Ok(Expr::Binary( _ => Ok(Expr::Binary(
Box::new(analyse_expr(context, left)?), Box::new(analyse_expr(context, left)?),
op, op,
@ -345,6 +410,69 @@ fn analyse_binary<'a>(
result result
} }
fn analyse_comparison_with_var(
context: &mut Context,
var: Expr,
op: TokenKind,
right: Expr,
) -> Result<Expr, CalcError> {
let right = analyse_expr(context, right)?;
if context.comprehension_vars.is_none() {
return Ok(Expr::Binary(
Box::new(analyse_expr(context, var)?),
op,
Box::new(right),
));
}
// Make sure any comprehension variables
// are added to context.comprehension_variables.
let analysed_var = analyse_expr(context, var)?;
let var_name = if let Expr::Var(identifier) = &analysed_var {
&identifier.pure_name
} else {
unreachable!("Expected Expr::Var");
};
let vars = context.comprehension_vars.as_mut().unwrap();
for ranged_var in vars {
if &ranged_var.name == var_name {
match op {
TokenKind::GreaterThan => {
ranged_var.min = Expr::Binary(
Box::new(right.clone()),
TokenKind::Plus,
Box::new(Expr::Literal(1f64)),
);
}
TokenKind::LessThan => {
ranged_var.max = right.clone();
}
TokenKind::GreaterOrEquals => {
ranged_var.min = right.clone();
}
TokenKind::LessOrEquals => {
ranged_var.max = Expr::Binary(
Box::new(right.clone()),
TokenKind::Plus,
Box::new(Expr::Literal(1f64)),
);
}
_ => unreachable!(),
}
break;
}
}
Ok(Expr::Binary(
Box::new(Expr::Literal(0f64)),
TokenKind::Equals,
Box::new(Expr::Literal(0f64)),
))
}
fn analyse_var( fn analyse_var(
context: &mut Context, context: &mut Context,
identifier: Identifier, identifier: Identifier,
@ -380,7 +508,15 @@ fn analyse_var(
None None
}; };
if context.symbol_table.contains_var(&identifier.pure_name) let is_comprehension_var = if let Some(vars) = &context.comprehension_vars {
vars.iter().any(|x| x.name == identifier.pure_name)
} else {
false
};
if is_comprehension_var {
with_adjacent(Expr::Var(identifier), adjacent_factor, adjacent_exponent)
} else if context.symbol_table.contains_var(&identifier.pure_name)
|| (identifier.pure_name.len() == 1 && !context.in_equation) || (identifier.pure_name.len() == 1 && !context.in_equation)
{ {
with_adjacent( with_adjacent(
@ -473,8 +609,13 @@ fn build_fn_call(
context.in_integral = true; context.in_integral = true;
} }
let is_sum_prod = identifier.pure_name == "sum" || identifier.pure_name == "prod";
if is_sum_prod {
context.in_sum_prod = true;
}
// Don't perform equation solving on special functions // Don't perform equation solving on special functions
if is_integral || identifier.pure_name == "sum" || identifier.pure_name == "prod" { if is_integral || is_sum_prod {
context.in_equation = false; context.in_equation = false;
} }
@ -508,6 +649,10 @@ fn build_fn_call(
context.in_integral = false; context.in_integral = false;
} }
if is_sum_prod {
context.in_sum_prod = false;
}
return Ok(Expr::FnCall(identifier, arguments)); return Ok(Expr::FnCall(identifier, arguments));
} }
@ -615,9 +760,24 @@ fn build_var(context: &mut Context, name: &str) -> Expr {
} }
} }
if context.in_equation && !context.symbol_table.contains_var(name) { if context.in_sum_prod && name == "n" {
return Expr::Var(Identifier::from_full_name(name));
}
let var_exists = context.symbol_table.contains_var(name);
if context.in_equation && !var_exists {
context.equation_variable = Some(name.to_string()); context.equation_variable = Some(name.to_string());
} }
if context.in_comprehension && !var_exists {
if let Some(vars) = context.comprehension_vars.as_mut() {
vars.push(RangedVar {
name: name.to_string(),
max: Expr::Literal(0f64),
min: Expr::Literal(0f64),
});
}
}
Expr::Var(Identifier::from_full_name(name)) Expr::Var(Identifier::from_full_name(name))
} }

View File

@ -24,6 +24,7 @@ pub enum Expr {
Vector(Vec<Expr>), Vector(Vec<Expr>),
Matrix(Vec<Vec<Expr>>), Matrix(Vec<Vec<Expr>>),
Indexer(Box<Expr>, Vec<Expr>), Indexer(Box<Expr>, Vec<Expr>),
Comprehension(Box<Expr>, Vec<Expr>, Vec<RangedVar>),
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -32,6 +33,13 @@ pub struct ConditionalPiece {
pub condition: Expr, pub condition: Expr,
} }
#[derive(Debug, Clone, PartialEq)]
pub struct RangedVar {
pub name: String,
pub max: Expr,
pub min: Expr,
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Identifier { pub struct Identifier {
pub full_name: String, pub full_name: String,

View File

@ -42,6 +42,7 @@ mod tests {
} }
#[test_case("basics")] #[test_case("basics")]
#[test_case("comprehensions")]
#[test_case("derivation")] #[test_case("derivation")]
#[test_case("functions")] #[test_case("functions")]
#[test_case("groups")] #[test_case("groups")]

View File

@ -1,5 +1,5 @@
use crate::ast::Identifier;
use crate::ast::{Expr, Stmt}; use crate::ast::{Expr, Stmt};
use crate::ast::{Identifier, RangedVar};
use crate::calculation_result::CalculationResult; use crate::calculation_result::CalculationResult;
use crate::kalk_value::KalkValue; use crate::kalk_value::KalkValue;
use crate::lexer::TokenKind; use crate::lexer::TokenKind;
@ -127,6 +127,9 @@ pub(crate) fn eval_expr(
Expr::Vector(values) => eval_vector(context, values), Expr::Vector(values) => eval_vector(context, values),
Expr::Matrix(rows) => eval_matrix(context, rows), Expr::Matrix(rows) => eval_matrix(context, rows),
Expr::Indexer(var, indexes) => eval_indexer(context, var, indexes, unit), Expr::Indexer(var, indexes) => eval_indexer(context, var, indexes, unit),
Expr::Comprehension(left, conditions, vars) => Ok(KalkValue::Vector(eval_comprehension(
context, left, conditions, vars,
)?)),
} }
} }
@ -634,6 +637,53 @@ fn eval_indexer(
} }
} }
fn eval_comprehension(
context: &mut Context,
left: &Expr,
conditions: &[Expr],
vars: &[RangedVar],
) -> Result<Vec<KalkValue>, CalcError> {
if vars.len() != conditions.len() {
return Err(CalcError::InvalidComprehension(String::from("Expected a new variable to be introduced for every condition (conditions are comma separated).")));
}
let condition = conditions.first().unwrap();
let var = vars.first().unwrap();
context.symbol_table.insert(Stmt::VarDecl(
Identifier::from_full_name(&var.name),
Box::new(Expr::Literal(0f64)),
));
let min = eval_expr(context, &var.min, "")?.to_f64() as i32;
let max = eval_expr(context, &var.max, "")?.to_f64() as i32;
let mut values = Vec::new();
for i in min..max {
context.symbol_table.set(Stmt::VarDecl(
Identifier::from_full_name(&var.name),
Box::new(Expr::Literal(i as f64)),
));
if conditions.len() > 1 {
let x = eval_comprehension(context, left, &conditions[1..], &vars[1..])?;
for value in x {
values.push(value);
}
}
let condition = eval_expr(context, condition, "")?;
if let KalkValue::Boolean(boolean) = condition {
if boolean && vars.len() == 1 {
values.push(eval_expr(context, left, "")?);
}
}
}
context.symbol_table.get_and_remove_var(&var.name);
Ok(values)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -90,6 +90,9 @@ fn invert(
Expr::Vector(_) => Err(CalcError::UnableToInvert(String::from("Vector"))), Expr::Vector(_) => Err(CalcError::UnableToInvert(String::from("Vector"))),
Expr::Matrix(_) => Err(CalcError::UnableToInvert(String::from("Matrix"))), Expr::Matrix(_) => Err(CalcError::UnableToInvert(String::from("Matrix"))),
Expr::Indexer(_, _) => Err(CalcError::UnableToInvert(String::from("Inverter"))), Expr::Indexer(_, _) => Err(CalcError::UnableToInvert(String::from("Inverter"))),
Expr::Comprehension(_, _, _) => {
Err(CalcError::UnableToInvert(String::from("Comprehension")))
}
} }
} }
@ -396,6 +399,7 @@ pub fn contains_var(symbol_table: &SymbolTable, expr: &Expr, var_name: &str) ->
.iter() .iter()
.any(|row| row.iter().any(|x| contains_var(symbol_table, x, var_name))), .any(|row| row.iter().any(|x| contains_var(symbol_table, x, var_name))),
Expr::Indexer(_, _) => false, Expr::Indexer(_, _) => false,
Expr::Comprehension(_, _, _) => false,
} }
} }

View File

@ -45,6 +45,7 @@ pub enum TokenKind {
OpenBrace, OpenBrace,
ClosedBrace, ClosedBrace,
Comma, Comma,
Colon,
Semicolon, Semicolon,
Newline, Newline,
@ -149,6 +150,7 @@ impl<'a> Lexer<'a> {
'∧' => build(TokenKind::And, "", span), '∧' => build(TokenKind::And, "", span),
'' => build(TokenKind::Or, "", span), '' => build(TokenKind::Or, "", span),
',' => build(TokenKind::Comma, "", span), ',' => build(TokenKind::Comma, "", span),
':' => build(TokenKind::Colon, "", span),
';' => build(TokenKind::Semicolon, "", span), ';' => build(TokenKind::Semicolon, "", span),
'\n' => build(TokenKind::Newline, "", span), '\n' => build(TokenKind::Newline, "", span),
'%' => build(TokenKind::Percent, "", span), '%' => build(TokenKind::Percent, "", span),
@ -388,7 +390,7 @@ fn is_valid_identifier(c: Option<&char>) -> bool {
match c { match c {
'+' | '-' | '/' | '*' | '%' | '^' | '!' | '(' | ')' | '=' | '.' | ',' | ';' | '|' '+' | '-' | '/' | '*' | '%' | '^' | '!' | '(' | ')' | '=' | '.' | ',' | ';' | '|'
| '⌊' | '⌋' | '⌈' | '⌉' | '[' | ']' | '{' | '}' | 'π' | '√' | 'τ' | 'ϕ' | 'Γ' | '<' | '⌊' | '⌋' | '⌈' | '⌉' | '[' | ']' | '{' | '}' | 'π' | '√' | 'τ' | 'ϕ' | 'Γ' | '<'
| '>' | '≠' | '≥' | '≤' | '×' | '÷' | '⋅' | '⟦' | '⟧' | '∧' | '' | '\n' => { | '>' | '≠' | '≥' | '≤' | '×' | '÷' | '⋅' | '⟦' | '⟧' | '∧' | '' | ':' | '\n' => {
false false
} }
_ => !c.is_digit(10) || is_superscript(c) || is_subscript(c), _ => !c.is_digit(10) || is_superscript(c) || is_subscript(c),

View File

@ -97,6 +97,7 @@ pub enum CalcError {
IncorrectAmountOfIndexes(usize, usize), IncorrectAmountOfIndexes(usize, usize),
ItemOfIndexDoesNotExist(Vec<usize>), ItemOfIndexDoesNotExist(Vec<usize>),
InconsistentColumnWidths, InconsistentColumnWidths,
InvalidComprehension(String),
InvalidNumberLiteral(String), InvalidNumberLiteral(String),
InvalidOperator, InvalidOperator,
InvalidUnit, InvalidUnit,
@ -131,6 +132,7 @@ impl ToString for CalcError {
), ),
CalcError::ItemOfIndexDoesNotExist(indexes) => format!("Item of index ⟦{}⟧ does not exist.", indexes.iter().map(|x| x.to_string()).collect::<Vec<String>>().join(", ")), CalcError::ItemOfIndexDoesNotExist(indexes) => format!("Item of index ⟦{}⟧ does not exist.", indexes.iter().map(|x| x.to_string()).collect::<Vec<String>>().join(", ")),
CalcError::InconsistentColumnWidths => format!("Inconsistent column widths. Matrix columns must be the same size."), CalcError::InconsistentColumnWidths => format!("Inconsistent column widths. Matrix columns must be the same size."),
CalcError::InvalidComprehension(x) => format!("Invalid comprehension: {}", x),
CalcError::InvalidNumberLiteral(x) => format!("Invalid number literal: '{}'.", x), CalcError::InvalidNumberLiteral(x) => format!("Invalid number literal: '{}'.", x),
CalcError::InvalidOperator => format!("Invalid operator."), CalcError::InvalidOperator => format!("Invalid operator."),
CalcError::InvalidUnit => format!("Invalid unit."), CalcError::InvalidUnit => format!("Invalid unit."),
@ -310,6 +312,32 @@ fn parse_expr(context: &mut Context) -> Result<Expr, CalcError> {
Ok(parse_or(context)?) Ok(parse_or(context)?)
} }
fn parse_comprehension(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_or(context)?;
if match_token(context, TokenKind::Colon) {
let op = advance(context).kind;
skip_newlines(context);
let right = Box::new(parse_comprehension_comma(context)?);
return Ok(Expr::Binary(Box::new(left), op, right));
}
Ok(left)
}
fn parse_comprehension_comma(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_or(context)?;
if match_token(context, TokenKind::Comma) {
let op = advance(context).kind;
skip_newlines(context);
let right = Box::new(parse_comprehension_comma(context)?);
return Ok(Expr::Binary(Box::new(left), op, right));
}
Ok(left)
}
fn parse_or(context: &mut Context) -> Result<Expr, CalcError> { fn parse_or(context: &mut Context) -> Result<Expr, CalcError> {
let left = parse_and(context)?; let left = parse_and(context)?;
@ -545,7 +573,12 @@ fn parse_vector(context: &mut Context) -> Result<Expr, CalcError> {
skip_newlines(context); skip_newlines(context);
} }
let mut rows = vec![vec![parse_expr(context)?]]; let first_expr = if kind == TokenKind::OpenBracket {
parse_comprehension(context)?
} else {
parse_expr(context)?
};
let mut rows = vec![vec![first_expr]];
let mut column_count = None; let mut column_count = None;
let mut items_in_row = 1; let mut items_in_row = 1;
while match_token(context, TokenKind::Comma) while match_token(context, TokenKind::Comma)

View File

@ -0,0 +1,2 @@
[x : 0 ≤ x and 5 > x] = (0, 1, 2, 3, 4) and
[(x, y) : x > 0 and x <= 3, y > 0 and y <= 2] = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)]