diff --git a/crates/nu-command/src/dataframe/eager/mod.rs b/crates/nu-command/src/dataframe/eager/mod.rs
index bbeb551036..80421f77f9 100644
--- a/crates/nu-command/src/dataframe/eager/mod.rs
+++ b/crates/nu-command/src/dataframe/eager/mod.rs
@@ -17,6 +17,8 @@ mod rename;
 mod sample;
 mod shape;
 mod slice;
+mod sql_context;
+mod sql_expr;
 mod take;
 mod to_arrow;
 mod to_csv;
@@ -24,6 +26,7 @@ mod to_df;
 mod to_nu;
 mod to_parquet;
 mod with_column;
+mod with_sql;
 
 use nu_protocol::engine::StateWorkingSet;
 
@@ -46,6 +49,8 @@ pub use rename::RenameDF;
 pub use sample::SampleDF;
 pub use shape::ShapeDF;
 pub use slice::SliceDF;
+pub use sql_context::SQLContext;
+pub use sql_expr::parse_sql_expr;
 pub use take::TakeDF;
 pub use to_arrow::ToArrow;
 pub use to_csv::ToCSV;
@@ -53,6 +58,7 @@ pub use to_df::ToDataFrame;
 pub use to_nu::ToNu;
 pub use to_parquet::ToParquet;
 pub use with_column::WithColumn;
+pub use with_sql::WithSql;
 
 pub fn add_eager_decls(working_set: &mut StateWorkingSet) {
     macro_rules! bind_command {
@@ -91,6 +97,7 @@ pub fn add_eager_decls(working_set: &mut StateWorkingSet) {
         ToDataFrame,
         ToNu,
         ToParquet,
-        WithColumn
+        WithColumn,
+        WithSql
     );
 }
diff --git a/crates/nu-command/src/dataframe/eager/sql_context.rs b/crates/nu-command/src/dataframe/eager/sql_context.rs
new file mode 100644
index 0000000000..f8c2acd153
--- /dev/null
+++ b/crates/nu-command/src/dataframe/eager/sql_context.rs
@@ -0,0 +1,220 @@
+use crate::dataframe::eager::sql_expr::parse_sql_expr;
+use polars::error::PolarsError;
+use polars::prelude::{col, DataFrame, IntoLazy, LazyFrame};
+use sqlparser::ast::{
+    Expr as SqlExpr, Select, SelectItem, SetExpr, Statement, TableFactor, Value as SQLValue,
+};
+use sqlparser::dialect::GenericDialect;
+use sqlparser::parser::Parser;
+use std::collections::HashMap;
+
+/// Minimal SQL front end that plans `SELECT` statements over registered polars frames.
+#[derive(Default)]
+pub struct SQLContext {
+    /// Registered tables, keyed by the name used in `FROM`.
+    table_map: HashMap<String, LazyFrame>,
+    dialect: GenericDialect,
+}
+
+impl SQLContext {
+    /// Creates a context with no registered tables.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Registers a lazy copy of `df` under `name`, replacing any earlier table of that name.
+    pub fn register(&mut self, name: &str, df: &DataFrame) {
+        self.table_map.insert(name.to_owned(), df.clone().lazy());
+    }
+
+    /// Plans a single `SELECT`: optional `WHERE` filter, then either a plain projection or a
+    /// `GROUP BY` aggregation whose columns are put back into select-list order.
+    ///
+    /// Only the first `FROM` entry is read, and it must be a registered bare table name.
+    fn execute_select(&self, select_stmt: &Select) -> Result<LazyFrame, PolarsError> {
+        // Determine involved dataframe
+        // Implicit joins need more work in the query parser; only one table is read for now.
+        let tbl = select_stmt.from.first().ok_or_else(|| {
+            PolarsError::NotFound("No table found in select statement".to_string())
+        })?;
+        let tbl_name = match &tbl.relation {
+            // A table alias may be present, but nothing resolves names through it yet.
+            TableFactor::Table { name, .. } => name
+                .0
+                .first()
+                .ok_or_else(|| {
+                    PolarsError::NotFound("No table found in select statement".to_string())
+                })?
+                .value
+                .as_str(),
+            // Support bare table, optional with alias for now
+            _ => return Err(PolarsError::ComputeError("Not implemented".into())),
+        };
+        let df = self.table_map.get(tbl_name).ok_or_else(|| {
+            PolarsError::ComputeError(format!("Table name {tbl_name} was not found").into())
+        })?;
+        // Filter Expression
+        let df = match select_stmt.selection.as_ref() {
+            Some(expr) => df.clone().filter(parse_sql_expr(expr)?),
+            None => df.clone(),
+        };
+        // Column Projections
+        // Debug rendering of each un-aliased select expression -> its select-list position,
+        // used below to recognise GROUP BY expressions that are also projected.
+        let mut raw_projection_before_alias: HashMap<String, usize> = HashMap::new();
+        let mut contain_wildcard = false;
+        let projection = select_stmt
+            .projection
+            .iter()
+            .enumerate()
+            .map(|(i, select_item)| {
+                Ok(match select_item {
+                    SelectItem::UnnamedExpr(expr) => {
+                        let expr = parse_sql_expr(expr)?;
+                        raw_projection_before_alias.insert(format!("{:?}", expr), i);
+                        expr
+                    }
+                    SelectItem::ExprWithAlias { expr, alias } => {
+                        let expr = parse_sql_expr(expr)?;
+                        raw_projection_before_alias.insert(format!("{:?}", expr), i);
+                        expr.alias(&alias.value)
+                    }
+                    SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => {
+                        contain_wildcard = true;
+                        col("*")
+                    }
+                })
+            })
+            .collect::<Result<Vec<_>, PolarsError>>()?;
+        // Check for group by
+        // After projection since there might be number.
+        let group_by = select_stmt
+            .group_by
+            .iter()
+            .map(|e| match e {
+                SqlExpr::Value(SQLValue::Number(idx, _)) => {
+                    // SQL ordinals are 1-based: `GROUP BY 1` names the first select item.
+                    let position = match idx.parse::<usize>() {
+                        Ok(0) | Err(_) => {
+                            return Err(PolarsError::ComputeError(
+                                format!("Group By Error: Only positive number or expression are supported, got {idx}").into(),
+                            ))
+                        }
+                        Ok(position) => position,
+                    };
+                    projection.get(position - 1).cloned().ok_or_else(|| {
+                        PolarsError::ComputeError(
+                            format!("Group By Error: position {position} is not in the select list").into(),
+                        )
+                    })
+                }
+                SqlExpr::Value(_) => Err(PolarsError::ComputeError(
+                    "Group By Error: Only positive number or expression are supported".into(),
+                )),
+                _ => parse_sql_expr(e),
+            })
+            .collect::<Result<Vec<_>, PolarsError>>()?;
+
+        let df = if group_by.is_empty() {
+            df.select(projection)
+        } else {
+            // check groupby and projection due to difference between SQL and polars
+            // Return error on wild card, shouldn't process this
+            if contain_wildcard {
+                return Err(PolarsError::ComputeError(
+                    "Group By Error: Can't processed wildcard in groupby".into(),
+                ));
+            }
+            // Default polars group by will have group by columns at the front
+            // need some container to contain position of group by columns and its position
+            // at the final agg projection, check the schema for the existence of group by column
+            // and its projections columns, keeping the original index
+            let (exclude_expr, groupby_pos): (Vec<_>, Vec<_>) = group_by
+                .iter()
+                .enumerate()
+                .filter_map(|(gb_p, expr)| {
+                    raw_projection_before_alias
+                        .get(&format!("{:?}", expr))
+                        .map(|&proj_p| (proj_p, (proj_p, gb_p)))
+                })
+                .unzip();
+            let (agg_projection, agg_proj_pos): (Vec<_>, Vec<_>) = projection
+                .iter()
+                .enumerate()
+                .filter(|(i, _)| !exclude_expr.contains(i))
+                .enumerate()
+                .map(|(agg_pj, (proj_p, expr))| (expr.clone(), (proj_p, agg_pj + group_by.len())))
+                .unzip();
+            let agg_df = df.groupby(group_by).agg(agg_projection);
+            let mut final_proj_pos = groupby_pos
+                .into_iter()
+                .chain(agg_proj_pos.into_iter())
+                .collect::<Vec<_>>();
+
+            final_proj_pos.sort_by(|(proj_pa, _), (proj_pb, _)| proj_pa.cmp(proj_pb));
+            // The aggregated frame is collected here only to learn its output column names.
+            // Do it once and surface a failure, instead of once per column with errors hidden.
+            let agg_schema = agg_df.clone().collect()?.schema();
+            let final_proj = final_proj_pos
+                .into_iter()
+                .map(|(_, shm_p)| {
+                    agg_schema
+                        .get_index(shm_p)
+                        .map(|(name, _)| col(name))
+                        .ok_or_else(|| {
+                            PolarsError::ComputeError(
+                                format!("Group By Error: no column at position {shm_p} in aggregated frame").into(),
+                            )
+                        })
+                })
+                .collect::<Result<Vec<_>, PolarsError>>()?;
+            agg_df.select(final_proj)
+        };
+        Ok(df)
+    }
+
+    /// Parses `query` and plans it as a [`LazyFrame`].
+    ///
+    /// Exactly one statement is accepted, it must be a plain `SELECT` query, and `LIMIT`
+    /// (when present) must be a numeric literal.
+    pub fn execute(&self, query: &str) -> Result<LazyFrame, PolarsError> {
+        let ast = Parser::parse_sql(&self.dialect, query)
+            .map_err(|e| PolarsError::ComputeError(format!("{:?}", e).into()))?;
+        let statement = match ast.as_slice() {
+            [statement] => statement,
+            _ => {
+                return Err(PolarsError::ComputeError(
+                    "One and only one statement at a time please".into(),
+                ))
+            }
+        };
+        let select_query = match statement {
+            Statement::Query(select_query) => select_query,
+            _ => {
+                return Err(PolarsError::ComputeError(
+                    format!("Statement type {:?} is not supported", statement).into(),
+                ))
+            }
+        };
+        let rs = match &select_query.body {
+            SetExpr::Select(select_stmt) => self.execute_select(select_stmt)?,
+            _ => {
+                return Err(PolarsError::ComputeError(
+                    "INSERT, UPDATE is not supported for polars".into(),
+                ))
+            }
+        };
+        match &select_query.limit {
+            Some(SqlExpr::Value(SQLValue::Number(nrow, _))) => {
+                let nrow = nrow.parse().map_err(|err| {
+                    PolarsError::ComputeError(format!("Conversion Error: {:?}", err).into())
+                })?;
+                Ok(rs.limit(nrow))
+            }
+            None => Ok(rs),
+            _ => Err(PolarsError::ComputeError(
+                "Only support number argument to LIMIT clause".into(),
+            )),
+        }
+    }
+}
diff --git a/crates/nu-command/src/dataframe/eager/sql_expr.rs b/crates/nu-command/src/dataframe/eager/sql_expr.rs
new file mode 100644
index 0000000000..d434acd284
--- /dev/null
+++ 
b/crates/nu-command/src/dataframe/eager/sql_expr.rs
@@ -0,0 +1,214 @@
+use polars::error::PolarsError;
+use polars::prelude::{col, lit, DataType, Expr, LiteralValue, Result, TimeUnit};
+
+use sqlparser::ast::{
+    BinaryOperator as SQLBinaryOperator, DataType as SQLDataType, Expr as SqlExpr,
+    Function as SQLFunction, Value as SqlValue, WindowSpec,
+};
+
+/// Maps a SQL data type to the polars [`DataType`] a `CAST` to it produces.
+fn map_sql_polars_datatype(data_type: &SQLDataType) -> Result<DataType> {
+    Ok(match data_type {
+        SQLDataType::Char(_)
+        | SQLDataType::Varchar(_)
+        | SQLDataType::Uuid
+        | SQLDataType::Clob(_)
+        | SQLDataType::Text
+        | SQLDataType::String => DataType::Utf8,
+        SQLDataType::Float(_) => DataType::Float32,
+        SQLDataType::Real => DataType::Float32,
+        SQLDataType::Double => DataType::Float64,
+        SQLDataType::TinyInt(_) => DataType::Int8,
+        SQLDataType::UnsignedTinyInt(_) => DataType::UInt8,
+        SQLDataType::SmallInt(_) => DataType::Int16,
+        SQLDataType::UnsignedSmallInt(_) => DataType::UInt16,
+        SQLDataType::Int(_) => DataType::Int32,
+        SQLDataType::UnsignedInt(_) => DataType::UInt32,
+        SQLDataType::BigInt(_) => DataType::Int64,
+        SQLDataType::UnsignedBigInt(_) => DataType::UInt64,
+
+        SQLDataType::Boolean => DataType::Boolean,
+        SQLDataType::Date => DataType::Date,
+        SQLDataType::Time => DataType::Time,
+        SQLDataType::Timestamp => DataType::Datetime(TimeUnit::Milliseconds, None),
+        SQLDataType::Interval => DataType::Duration(TimeUnit::Milliseconds),
+        SQLDataType::Array(inner_type) => {
+            DataType::List(Box::new(map_sql_polars_datatype(inner_type)?))
+        }
+        _ => {
+            return Err(PolarsError::ComputeError(
+                format!(
+                    "SQL Datatype {:?} was not supported in polars-sql yet!",
+                    data_type
+                )
+                .into(),
+            ))
+        }
+    })
+}
+
+/// Casts `expr` to the polars equivalent of the SQL type `data_type`.
+fn cast_(expr: Expr, data_type: &SQLDataType) -> Result<Expr> {
+    let polars_type = map_sql_polars_datatype(data_type)?;
+    Ok(expr.cast(polars_type))
+}
+
+/// Combines two translated operands with the SQL binary operator `op`.
+fn binary_op_(left: Expr, right: Expr, op: &SQLBinaryOperator) -> Result<Expr> {
+    Ok(match op {
+        SQLBinaryOperator::Plus => left + right,
+        SQLBinaryOperator::Minus => left - right,
+        SQLBinaryOperator::Multiply => left * right,
+        SQLBinaryOperator::Divide => left / right,
+        SQLBinaryOperator::Modulo => left % right,
+        SQLBinaryOperator::StringConcat => left.cast(DataType::Utf8) + right.cast(DataType::Utf8),
+        SQLBinaryOperator::Gt => left.gt(right),
+        SQLBinaryOperator::Lt => left.lt(right),
+        SQLBinaryOperator::GtEq => left.gt_eq(right),
+        SQLBinaryOperator::LtEq => left.lt_eq(right),
+        SQLBinaryOperator::Eq => left.eq(right),
+        SQLBinaryOperator::NotEq => left.neq(right),
+        SQLBinaryOperator::And => left.and(right),
+        SQLBinaryOperator::Or => left.or(right),
+        SQLBinaryOperator::Xor => left.xor(right),
+        _ => {
+            return Err(PolarsError::ComputeError(
+                format!("SQL Operator {:?} was not supported in polars-sql yet!", op).into(),
+            ))
+        }
+    })
+}
+
+/// Translates a SQL literal into a polars literal expression.
+///
+/// A number containing `.` is parsed as `f64`, any other number as `i64`.
+fn literal_expr(value: &SqlValue) -> Result<Expr> {
+    Ok(match value {
+        SqlValue::Number(s, _) => {
+            // Check for existence of decimal separator dot
+            let parsed = if s.contains('.') {
+                s.parse::<f64>().map(lit).ok()
+            } else {
+                s.parse::<i64>().map(lit).ok()
+            };
+            parsed.ok_or_else(|| {
+                PolarsError::ComputeError(format!("Can't parse literal {:?}", s).into())
+            })?
+        }
+        SqlValue::SingleQuotedString(s) => lit(s.clone()),
+        SqlValue::NationalStringLiteral(s) => lit(s.clone()),
+        SqlValue::HexStringLiteral(s) => lit(s.clone()),
+        SqlValue::DoubleQuotedString(s) => lit(s.clone()),
+        SqlValue::Boolean(b) => lit(*b),
+        SqlValue::Null => Expr::Literal(LiteralValue::Null),
+        _ => {
+            return Err(PolarsError::ComputeError(
+                format!(
+                    "Parsing SQL Value {:?} was not supported in polars-sql yet!",
+                    value
+                )
+                .into(),
+            ))
+        }
+    })
+}
+
+/// Translates a SQL expression tree into a polars [`Expr`].
+///
+/// Identifiers, binary operators, function calls, `CAST`, parentheses and literals are handled;
+/// any other expression kind is reported as a `ComputeError`.
+pub fn parse_sql_expr(expr: &SqlExpr) -> Result<Expr> {
+    Ok(match expr {
+        SqlExpr::Identifier(e) => col(&e.value),
+        SqlExpr::BinaryOp { left, op, right } => {
+            let left = parse_sql_expr(left)?;
+            let right = parse_sql_expr(right)?;
+            binary_op_(left, right, op)?
+        }
+        SqlExpr::Function(sql_function) => parse_sql_function(sql_function)?,
+        SqlExpr::Cast { expr, data_type } => cast_(parse_sql_expr(expr)?, data_type)?,
+        SqlExpr::Nested(expr) => parse_sql_expr(expr)?,
+        SqlExpr::Value(value) => literal_expr(value)?,
+        _ => {
+            return Err(PolarsError::ComputeError(
+                format!(
+                    "Expression: {:?} was not supported in polars-sql yet!",
+                    expr
+                )
+                .into(),
+            ))
+        }
+    })
+}
+
+/// Applies the `PARTITION BY` expressions of an `OVER` clause to `expr`.
+///
+/// Without an `OVER` clause `expr` is returned unchanged.
+fn apply_window_spec(expr: Expr, window_spec: &Option<WindowSpec>) -> Result<Expr> {
+    Ok(match window_spec {
+        Some(window_spec) => {
+            // Process for simple window specification, partition by first
+            let partition_by = window_spec
+                .partition_by
+                .iter()
+                .map(parse_sql_expr)
+                .collect::<Result<Vec<_>>>()?;
+            expr.over(partition_by)
+            // Order by and Row range may not be supported at the moment
+        }
+        None => expr,
+    })
+}
+
+/// Translates a SQL function call into a polars expression.
+///
+/// Recognised forms: `sum(x)`, `count(x)`, `count(DISTINCT x)` and `count(*)`.
+fn parse_sql_function(sql_function: &SQLFunction) -> Result<Expr> {
+    use sqlparser::ast::{FunctionArg, FunctionArgExpr};
+    // Function names are mostly not namespaced, so the first name part is used.
+    let function_name = sql_function
+        .name
+        .0
+        .first()
+        .ok_or_else(|| PolarsError::ComputeError("Function call without a name".into()))?
+        .value
+        .to_lowercase();
+    let args = sql_function
+        .args
+        .iter()
+        .map(|arg| match arg {
+            FunctionArg::Named { arg, .. } => arg,
+            FunctionArg::Unnamed(arg) => arg,
+        })
+        .collect::<Vec<_>>();
+    Ok(
+        match (
+            function_name.as_str(),
+            args.as_slice(),
+            sql_function.distinct,
+        ) {
+            // NOTE(review): the window is applied before the aggregate in the arms below;
+            // confirm that this gives per-partition results for `OVER (PARTITION BY ..)`.
+            ("sum", [FunctionArgExpr::Expr(expr)], false) => {
+                apply_window_spec(parse_sql_expr(expr)?, &sql_function.over)?.sum()
+            }
+            ("count", [FunctionArgExpr::Expr(expr)], false) => {
+                apply_window_spec(parse_sql_expr(expr)?, &sql_function.over)?.count()
+            }
+            ("count", [FunctionArgExpr::Expr(expr)], true) => {
+                apply_window_spec(parse_sql_expr(expr)?, &sql_function.over)?.n_unique()
+            }
+            // Special case for wildcard args to count function.
+            ("count", [FunctionArgExpr::Wildcard], false) => lit(1i32).count(),
+            _ => {
+                return Err(PolarsError::ComputeError(
+                    format!(
+                        "Function {:?} with args {:?} was not supported in polars-sql yet!",
+                        function_name, args
+                    )
+                    .into(),
+                ))
+            }
+        },
+    )
+}
diff --git a/crates/nu-command/src/dataframe/eager/with_sql.rs b/crates/nu-command/src/dataframe/eager/with_sql.rs
new file mode 100644
index 0000000000..90abfeacaf
--- /dev/null
+++ b/crates/nu-command/src/dataframe/eager/with_sql.rs
@@ -0,0 +1,105 @@
+use super::super::values::NuDataFrame;
+use crate::dataframe::values::Column;
+use crate::dataframe::{eager::SQLContext, values::NuLazyFrame};
+use nu_engine::CallExt;
+use nu_protocol::{
+    ast::Call,
+    engine::{Command, EngineState, Stack},
+    Category, Example, PipelineData, ShellError, Signature, Span, SyntaxShape, Type, Value,
+};
+
+// attribution:
+// sql_context.rs, and sql_expr.rs were copied from polars-sql. thank you.
+// maybe we should just use the crate at some point but it's not published yet.
+// https://github.com/pola-rs/polars/tree/master/polars-sql
+
+/// The `with-sql` dataframe command.
+#[derive(Clone)]
+pub struct WithSql;
+
+impl Command for WithSql {
+    fn name(&self) -> &str {
+        "with-sql"
+    }
+
+    fn usage(&self) -> &str {
+        "Query dataframe using SQL. Note: The dataframe is always named df in your query."
+    }
+
+    fn signature(&self) -> Signature {
+        Signature::build(self.name())
+            .required("sql", SyntaxShape::String, "sql query")
+            .input_type(Type::Custom("dataframe".into()))
+            .output_type(Type::Custom("dataframe".into()))
+            .category(Category::Custom("dataframe".into()))
+    }
+
+    fn examples(&self) -> Vec<Example> {
+        vec![Example {
+            description: "Query dataframe using SQL",
+            example: "[[a b]; [1 2] [3 4]] | into df | with-sql 'select a from df'",
+            result: Some(
+                NuDataFrame::try_from_columns(vec![Column::new(
+                    "a".to_string(),
+                    vec![Value::test_int(1), Value::test_int(3)],
+                )])
+                .expect("simple df for test should not fail")
+                .into_value(Span::test_data()),
+            ),
+        }]
+    }
+
+    fn run(
+        &self,
+        engine_state: &EngineState,
+        stack: &mut Stack,
+        call: &Call,
+        input: PipelineData,
+    ) -> Result<PipelineData, ShellError> {
+        command(engine_state, stack, call, input)
+    }
+}
+
+/// Runs the SQL text given as the first argument against the input dataframe, which is
+/// registered as table `df`, and returns the collected result as a dataframe value.
+fn command(
+    engine_state: &EngineState,
+    stack: &mut Stack,
+    call: &Call,
+    input: PipelineData,
+) -> Result<PipelineData, ShellError> {
+    let sql_query: String = call.req(engine_state, stack, 0)?;
+    let df = NuDataFrame::try_from_pipeline(input, call.head)?;
+
+    let mut ctx = SQLContext::new();
+    ctx.register("df", &df.df);
+    let df_sql = ctx.execute(&sql_query).map_err(|e| {
+        ShellError::GenericError(
+            "Dataframe Error".into(),
+            e.to_string(),
+            Some(call.head),
+            None,
+            Vec::new(),
+        )
+    })?;
+    let lazy = NuLazyFrame::new(false, df_sql);
+
+    let eager = lazy.collect(call.head)?;
+    let value = Value::CustomValue {
+        val: Box::new(eager),
+        span: call.head,
+    };
+
+    Ok(PipelineData::Value(value, None))
+}
+
+#[cfg(test)]
+mod test {
+    use super::super::super::test_dataframe::test_dataframe;
+    use super::*;
+
+    #[test]
+    fn test_examples() {
+        test_dataframe(vec![Box::new(WithSql {})])
+    }
+}