mirror of
https://github.com/nushell/nushell.git
synced 2025-08-09 22:37:45 +02:00
Groupby operations on dataframes (#3473)
* Added PolarsStruct enum to implement groupby * template groupby * groupby operations on dataframes
This commit is contained in:
@ -236,7 +236,7 @@ pub fn autoview(args: CommandArgs) -> Result<OutputStream, ShellError> {
|
||||
}
|
||||
#[cfg(feature = "dataframe")]
|
||||
Value {
|
||||
value: UntaggedValue::Dataframe(df),
|
||||
value: UntaggedValue::DataFrame(df),
|
||||
..
|
||||
} => {
|
||||
if let Some(table) = table {
|
||||
|
@ -23,7 +23,7 @@ impl WholeStreamCommand for Command {
|
||||
let args = args.evaluate_once()?;
|
||||
|
||||
let df = NuDataFrame::try_from_iter(args.input, &tag)?;
|
||||
let init = InputStream::one(UntaggedValue::Dataframe(df).into_value(&tag));
|
||||
let init = InputStream::one(UntaggedValue::DataFrame(df).into_value(&tag));
|
||||
|
||||
Ok(init.to_output_stream())
|
||||
}
|
||||
|
262
crates/nu-command/src/commands/dataframe/groupby.rs
Normal file
262
crates/nu-command/src/commands/dataframe/groupby.rs
Normal file
@ -0,0 +1,262 @@
|
||||
use crate::prelude::*;
|
||||
use nu_engine::WholeStreamCommand;
|
||||
use nu_errors::ShellError;
|
||||
use nu_protocol::{
|
||||
dataframe::NuDataFrame, Primitive, Signature, SyntaxShape, UntaggedValue, Value,
|
||||
};
|
||||
use nu_source::Tagged;
|
||||
use polars::frame::groupby::GroupBy;
|
||||
|
||||
/// Aggregation that can be applied to the selected columns of a groupby.
///
/// Each variant corresponds to one operation name accepted by
/// `Operation::from_tagged` and to one aggregation call made in
/// `perform_aggregation`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Operation {
    Mean,
    Sum,
    Min,
    Max,
    First,
    Last,
    Nunique,
    /// Quantile aggregation; carries the requested quantile value.
    Quantile(f64),
    Median,
    // Variance and standard deviation are currently disabled.
    //Var,
    //Std,
    Count,
}
|
||||
|
||||
impl Operation {
|
||||
fn from_tagged(
|
||||
name: &Tagged<String>,
|
||||
quantile: Option<Tagged<f64>>,
|
||||
) -> Result<Operation, ShellError> {
|
||||
match name.item.as_ref() {
|
||||
"mean" => Ok(Operation::Mean),
|
||||
"sum" => Ok(Operation::Sum),
|
||||
"min" => Ok(Operation::Min),
|
||||
"max" => Ok(Operation::Max),
|
||||
"first" => Ok(Operation::First),
|
||||
"last" => Ok(Operation::Last),
|
||||
"nunique" => Ok(Operation::Nunique),
|
||||
"quantile" => {
|
||||
match quantile {
|
||||
None => Err(ShellError::labeled_error(
|
||||
"Quantile value not fount",
|
||||
"Quantile operation requires quantile value",
|
||||
&name.tag,
|
||||
)),
|
||||
Some(value ) => {
|
||||
if (value.item < 0.0) | (value.item > 1.0) {
|
||||
Err(ShellError::labeled_error(
|
||||
"Inappropriate quantile",
|
||||
"Quantile value should be between 0.0 and 1.0",
|
||||
&value.tag,
|
||||
))
|
||||
} else {
|
||||
Ok(Operation::Quantile(value.item))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"median" => Ok(Operation::Median),
|
||||
//"var" => Ok(Operation::Var),
|
||||
//"std" => Ok(Operation::Std),
|
||||
"count" => Ok(Operation::Count),
|
||||
_ => Err(ShellError::labeled_error_with_secondary(
|
||||
"Operation not fount",
|
||||
"Operation does not exist",
|
||||
&name.tag,
|
||||
"Perhaps you want: mean, sum, min, max, first, last, nunique, quantile, median, count",
|
||||
&name.tag,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DataFrame;
|
||||
|
||||
impl WholeStreamCommand for DataFrame {
|
||||
fn name(&self) -> &str {
|
||||
"dataframe groupby"
|
||||
}
|
||||
|
||||
fn usage(&self) -> &str {
|
||||
"Creates a groupby operation on a dataframe"
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::build("dataframe groupby")
|
||||
.required("columns", SyntaxShape::Table, "groupby columns")
|
||||
.required(
|
||||
"aggregation columns",
|
||||
SyntaxShape::Table,
|
||||
"columns to perform aggregation",
|
||||
)
|
||||
.required("operation", SyntaxShape::String, "aggregate operation")
|
||||
.named(
|
||||
"quantile",
|
||||
SyntaxShape::Number,
|
||||
"auantile value for quantile operation",
|
||||
Some('q'),
|
||||
)
|
||||
}
|
||||
|
||||
fn run(&self, args: CommandArgs) -> Result<OutputStream, ShellError> {
|
||||
groupby(args)
|
||||
}
|
||||
|
||||
fn examples(&self) -> Vec<Example> {
|
||||
vec![Example {
|
||||
description: "",
|
||||
example: "",
|
||||
result: None,
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
fn groupby(args: CommandArgs) -> Result<OutputStream, ShellError> {
|
||||
let tag = args.call_info.name_tag.clone();
|
||||
let mut args = args.evaluate_once()?;
|
||||
|
||||
let quantile: Option<Tagged<f64>> = args.get_flag("quantile")?;
|
||||
let operation: Tagged<String> = args.req(2)?;
|
||||
let op = Operation::from_tagged(&operation, quantile)?;
|
||||
|
||||
// Extracting the names of the columns to perform the groupby
|
||||
let columns: Vec<Value> = args.req(0)?;
|
||||
|
||||
// Extracting the first tag from the groupby column names
|
||||
let mut col_span = match columns
|
||||
.iter()
|
||||
.nth(0)
|
||||
.map(|v| Span::new(v.tag.span.start(), v.tag.span.end()))
|
||||
{
|
||||
Some(span) => span,
|
||||
None => {
|
||||
return Err(ShellError::labeled_error(
|
||||
"Empty groupby names list",
|
||||
"Empty list for groupby column names",
|
||||
&tag,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let columns_string = columns
|
||||
.into_iter()
|
||||
.map(|value| match value.value {
|
||||
UntaggedValue::Primitive(Primitive::String(s)) => {
|
||||
col_span = col_span.until(value.tag.span);
|
||||
Ok(s)
|
||||
}
|
||||
_ => Err(ShellError::labeled_error(
|
||||
"Incorrect column format",
|
||||
"Only string as column name",
|
||||
&value.tag,
|
||||
)),
|
||||
})
|
||||
.collect::<Result<Vec<String>, _>>()?;
|
||||
|
||||
// Extracting the names of the columns to perform the aggregation
|
||||
let agg_cols: Vec<Value> = args.req(1)?;
|
||||
|
||||
// Extracting the first tag from the aggregation column names
|
||||
let mut agg_span = match agg_cols
|
||||
.iter()
|
||||
.nth(0)
|
||||
.map(|v| Span::new(v.tag.span.start(), v.tag.span.end()))
|
||||
{
|
||||
Some(span) => span,
|
||||
None => {
|
||||
return Err(ShellError::labeled_error(
|
||||
"Empty aggregation names list",
|
||||
"Empty list for aggregation column names",
|
||||
&tag,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let aggregation_string = agg_cols
|
||||
.into_iter()
|
||||
.map(|value| match value.value {
|
||||
UntaggedValue::Primitive(Primitive::String(s)) => {
|
||||
agg_span = agg_span.until(value.tag.span);
|
||||
Ok(s)
|
||||
}
|
||||
_ => Err(ShellError::labeled_error(
|
||||
"Incorrect column format",
|
||||
"Only string as column name",
|
||||
value.tag,
|
||||
)),
|
||||
})
|
||||
.collect::<Result<Vec<String>, _>>()?;
|
||||
|
||||
// The operation is only done in one dataframe. Only one input is
|
||||
// expected from the InputStream
|
||||
match args.input.next() {
|
||||
None => Err(ShellError::labeled_error(
|
||||
"No input received",
|
||||
"missing dataframe input from stream",
|
||||
&tag,
|
||||
)),
|
||||
Some(value) => {
|
||||
if let UntaggedValue::DataFrame(NuDataFrame {
|
||||
dataframe: Some(df),
|
||||
..
|
||||
}) = value.value
|
||||
{
|
||||
let groupby = df
|
||||
.groupby(&columns_string)
|
||||
.map_err(|e| {
|
||||
ShellError::labeled_error("Groupby error", format!("{}", e), col_span)
|
||||
})?
|
||||
.select(&aggregation_string);
|
||||
|
||||
let res = perform_aggregation(groupby, op, &operation.tag, &agg_span)?;
|
||||
|
||||
let final_df = Value {
|
||||
tag,
|
||||
value: UntaggedValue::DataFrame(NuDataFrame {
|
||||
dataframe: Some(res),
|
||||
name: "agg result".to_string(),
|
||||
}),
|
||||
};
|
||||
|
||||
Ok(OutputStream::one(final_df))
|
||||
} else {
|
||||
Err(ShellError::labeled_error(
|
||||
"No dataframe in stream",
|
||||
"no dataframe found in input stream",
|
||||
&tag,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn perform_aggregation(
|
||||
groupby: GroupBy,
|
||||
operation: Operation,
|
||||
operation_tag: &Tag,
|
||||
agg_span: &Span,
|
||||
) -> Result<polars::prelude::DataFrame, ShellError> {
|
||||
match operation {
|
||||
Operation::Mean => groupby.mean(),
|
||||
Operation::Sum => groupby.sum(),
|
||||
Operation::Min => groupby.min(),
|
||||
Operation::Max => groupby.max(),
|
||||
Operation::First => groupby.first(),
|
||||
Operation::Last => groupby.last(),
|
||||
Operation::Nunique => groupby.n_unique(),
|
||||
Operation::Quantile(quantile) => groupby.quantile(quantile),
|
||||
Operation::Median => groupby.median(),
|
||||
//Operation::Var => groupby.var(),
|
||||
//Operation::Std => groupby.std(),
|
||||
Operation::Count => groupby.count(),
|
||||
}
|
||||
.map_err(|e| {
|
||||
let span = if e.to_string().contains("Not found") {
|
||||
agg_span
|
||||
} else {
|
||||
&operation_tag.span
|
||||
};
|
||||
|
||||
ShellError::labeled_error("Aggregation error", format!("{}", e), span)
|
||||
})
|
||||
}
|
@ -3,9 +3,9 @@ use nu_engine::WholeStreamCommand;
|
||||
use nu_errors::ShellError;
|
||||
use nu_protocol::{Signature, TaggedDictBuilder, UntaggedValue, Value};
|
||||
|
||||
pub struct Dataframe;
|
||||
pub struct DataFrame;
|
||||
|
||||
impl WholeStreamCommand for Dataframe {
|
||||
impl WholeStreamCommand for DataFrame {
|
||||
fn name(&self) -> &str {
|
||||
"dataframe list"
|
||||
}
|
||||
@ -23,7 +23,7 @@ impl WholeStreamCommand for Dataframe {
|
||||
|
||||
let mut dataframes: Vec<Value> = Vec::new();
|
||||
for (name, value) in args.context.scope.get_vars() {
|
||||
if let UntaggedValue::Dataframe(df) = value.value {
|
||||
if let UntaggedValue::DataFrame(df) = value.value {
|
||||
let mut data = TaggedDictBuilder::new(value.tag);
|
||||
|
||||
let polars_df = df.dataframe.unwrap();
|
||||
|
@ -11,9 +11,9 @@ use nu_source::Tagged;
|
||||
use polars::prelude::{CsvReader, JsonReader, ParquetReader, SerReader};
|
||||
use std::fs::File;
|
||||
|
||||
pub struct Dataframe;
|
||||
pub struct DataFrame;
|
||||
|
||||
impl WholeStreamCommand for Dataframe {
|
||||
impl WholeStreamCommand for DataFrame {
|
||||
fn name(&self) -> &str {
|
||||
"dataframe load"
|
||||
}
|
||||
@ -112,7 +112,7 @@ fn create_from_file(args: CommandArgs) -> Result<OutputStream, ShellError> {
|
||||
name: file_name,
|
||||
};
|
||||
|
||||
let init = InputStream::one(UntaggedValue::Dataframe(nu_dataframe).into_value(&tag));
|
||||
let init = InputStream::one(UntaggedValue::DataFrame(nu_dataframe).into_value(&tag));
|
||||
|
||||
Ok(init.to_output_stream())
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
pub mod command;
|
||||
pub mod groupby;
|
||||
pub mod list;
|
||||
pub mod load;
|
||||
|
||||
pub use command::Command as Dataframe;
|
||||
pub use list::Dataframe as DataframeList;
|
||||
pub use load::Dataframe as DataframeLoad;
|
||||
pub use command::Command as DataFrame;
|
||||
pub use groupby::DataFrame as DataFrameGroupBy;
|
||||
pub use list::DataFrame as DataFrameList;
|
||||
pub use load::DataFrame as DataFrameLoad;
|
||||
|
@ -253,11 +253,13 @@ pub fn create_default_context(interactive: bool) -> Result<EvaluationContext, Bo
|
||||
whole_stream_command(SeqDates),
|
||||
whole_stream_command(TermSize),
|
||||
#[cfg(feature = "dataframe")]
|
||||
whole_stream_command(Dataframe),
|
||||
whole_stream_command(DataFrame),
|
||||
#[cfg(feature = "dataframe")]
|
||||
whole_stream_command(DataframeLoad),
|
||||
whole_stream_command(DataFrameLoad),
|
||||
#[cfg(feature = "dataframe")]
|
||||
whole_stream_command(DataframeList),
|
||||
whole_stream_command(DataFrameList),
|
||||
#[cfg(feature = "dataframe")]
|
||||
whole_stream_command(DataFrameGroupBy),
|
||||
]);
|
||||
|
||||
#[cfg(feature = "clipboard-cli")]
|
||||
|
@ -115,7 +115,7 @@ pub fn value_to_json_value(v: &Value) -> Result<serde_json::Value, ShellError> {
|
||||
serde_json::Value::Null
|
||||
}
|
||||
#[cfg(feature = "dataframe")]
|
||||
UntaggedValue::Dataframe(_) => serde_json::Value::Null,
|
||||
UntaggedValue::DataFrame(_) => serde_json::Value::Null,
|
||||
UntaggedValue::Primitive(Primitive::Binary(b)) => serde_json::Value::Array(
|
||||
b.iter()
|
||||
.map(|x| {
|
||||
|
@ -74,7 +74,7 @@ fn helper(v: &Value) -> Result<toml::Value, ShellError> {
|
||||
UntaggedValue::Error(e) => return Err(e.clone()),
|
||||
UntaggedValue::Block(_) => toml::Value::String("<Block>".to_string()),
|
||||
#[cfg(feature = "dataframe")]
|
||||
UntaggedValue::Dataframe(_) => toml::Value::String("<Dataframe>".to_string()),
|
||||
UntaggedValue::DataFrame(_) => toml::Value::String("<Data>".to_string()),
|
||||
UntaggedValue::Primitive(Primitive::Range(_)) => toml::Value::String("<Range>".to_string()),
|
||||
UntaggedValue::Primitive(Primitive::Binary(b)) => {
|
||||
toml::Value::Array(b.iter().map(|x| toml::Value::Integer(*x as i64)).collect())
|
||||
|
@ -96,7 +96,7 @@ pub fn value_to_yaml_value(v: &Value) -> Result<serde_yaml::Value, ShellError> {
|
||||
serde_yaml::Value::Null
|
||||
}
|
||||
#[cfg(feature = "dataframe")]
|
||||
UntaggedValue::Dataframe(_) => serde_yaml::Value::Null,
|
||||
UntaggedValue::DataFrame(_) => serde_yaml::Value::Null,
|
||||
UntaggedValue::Primitive(Primitive::Binary(b)) => serde_yaml::Value::Sequence(
|
||||
b.iter()
|
||||
.map(|x| serde_yaml::Value::Number(serde_yaml::Number::from(*x)))
|
||||
|
@ -156,9 +156,9 @@ fn uniq(args: CommandArgs) -> Result<ActionStream, ShellError> {
|
||||
))
|
||||
}
|
||||
#[cfg(feature = "dataframe")]
|
||||
UntaggedValue::Dataframe(_) => {
|
||||
UntaggedValue::DataFrame(_) => {
|
||||
return Err(ShellError::labeled_error(
|
||||
"uniq -c cannot operate on dataframes.",
|
||||
"uniq -c cannot operate on data structs",
|
||||
"source",
|
||||
item.0.tag.span,
|
||||
))
|
||||
|
Reference in New Issue
Block a user