"[11611] fixing dataframe column comparisons" (#11676)

fixes #11611

Co-authored-by: Jack Wright <jack.wright@disqo.com>
This commit is contained in:
Jack Wright 2024-01-29 15:28:12 -08:00 committed by GitHub
parent 798ae7b251
commit 175dab4898
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -84,27 +84,27 @@ pub(super) fn compute_between_series(
}
Operator::Comparison(Comparison::NotEqual) => {
let name = format!("neq_{}_{}", lhs.name(), rhs.name());
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::equal)?;
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::not_equal)?;
NuDataFrame::series_to_value(res, operation_span)
}
Operator::Comparison(Comparison::LessThan) => {
let name = format!("lt_{}_{}", lhs.name(), rhs.name());
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::equal)?;
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::lt)?;
NuDataFrame::series_to_value(res, operation_span)
}
Operator::Comparison(Comparison::LessThanOrEqual) => {
let name = format!("lte_{}_{}", lhs.name(), rhs.name());
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::equal)?;
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::lt_eq)?;
NuDataFrame::series_to_value(res, operation_span)
}
Operator::Comparison(Comparison::GreaterThan) => {
let name = format!("gt_{}_{}", lhs.name(), rhs.name());
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::equal)?;
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::gt)?;
NuDataFrame::series_to_value(res, operation_span)
}
Operator::Comparison(Comparison::GreaterThanOrEqual) => {
let name = format!("gte_{}_{}", lhs.name(), rhs.name());
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::equal)?;
let res = compare_series(lhs, rhs, name.as_str(), right.span(), Series::gt_eq)?;
NuDataFrame::series_to_value(res, operation_span)
}
Operator::Boolean(Boolean::And) => match lhs.dtype() {
@ -768,3 +768,117 @@ fn add_string_to_series(series: &Series, pat: &str, span: Span) -> Result<Value,
}),
}
}
#[cfg(test)]
mod test {
use super::*;
use nu_protocol::Span;
use polars::{prelude::NamedFrom, series::Series};
use crate::dataframe::values::NuDataFrame;
#[test]
fn test_compute_between_series_comparisons() {
let series = Series::new("c", &[1, 2]);
let df = NuDataFrame::try_from_series(vec![series], Span::test_data())
.expect("should be able to create a simple dataframe");
let c0 = df
.column("c", Span::test_data())
.expect("should be able to get column c");
let c0_series = c0
.as_series(Span::test_data())
.expect("should be able to get series");
let c0_value = c0.into_value(Span::test_data());
let c1 = df
.column("c", Span::test_data())
.expect("should be able to get column c");
let c1_series = c1
.as_series(Span::test_data())
.expect("should be able to get series");
let c1_value = c1.into_value(Span::test_data());
let op = Spanned {
item: Operator::Comparison(Comparison::NotEqual),
span: Span::test_data(),
};
let result = compute_between_series(op, &c0_value, &c0_series, &c1_value, &c1_series)
.expect("compare should not fail");
let result = NuDataFrame::try_from_value(result)
.expect("should be able to create a dataframe from a value");
let result = result
.as_series(Span::test_data())
.expect("should be convert to a series");
assert_eq!(result, Series::new("neq_c_c", &[false, false]));
let op = Spanned {
item: Operator::Comparison(Comparison::Equal),
span: Span::test_data(),
};
let result = compute_between_series(op, &c0_value, &c0_series, &c1_value, &c1_series)
.expect("compare should not fail");
let result = NuDataFrame::try_from_value(result)
.expect("should be able to create a dataframe from a value");
let result = result
.as_series(Span::test_data())
.expect("should be convert to a series");
assert_eq!(result, Series::new("eq_c_c", &[true, true]));
let op = Spanned {
item: Operator::Comparison(Comparison::LessThan),
span: Span::test_data(),
};
let result = compute_between_series(op, &c0_value, &c0_series, &c1_value, &c1_series)
.expect("compare should not fail");
let result = NuDataFrame::try_from_value(result)
.expect("should be able to create a dataframe from a value");
let result = result
.as_series(Span::test_data())
.expect("should be convert to a series");
assert_eq!(result, Series::new("lt_c_c", &[false, false]));
let op = Spanned {
item: Operator::Comparison(Comparison::LessThanOrEqual),
span: Span::test_data(),
};
let result = compute_between_series(op, &c0_value, &c0_series, &c1_value, &c1_series)
.expect("compare should not fail");
let result = NuDataFrame::try_from_value(result)
.expect("should be able to create a dataframe from a value");
let result = result
.as_series(Span::test_data())
.expect("should be convert to a series");
assert_eq!(result, Series::new("lte_c_c", &[true, true]));
let op = Spanned {
item: Operator::Comparison(Comparison::GreaterThan),
span: Span::test_data(),
};
let result = compute_between_series(op, &c0_value, &c0_series, &c1_value, &c1_series)
.expect("compare should not fail");
let result = NuDataFrame::try_from_value(result)
.expect("should be able to create a dataframe from a value");
let result = result
.as_series(Span::test_data())
.expect("should be convert to a series");
assert_eq!(result, Series::new("gt_c_c", &[false, false]));
let op = Spanned {
item: Operator::Comparison(Comparison::GreaterThanOrEqual),
span: Span::test_data(),
};
let result = compute_between_series(op, &c0_value, &c0_series, &c1_value, &c1_series)
.expect("compare should not fail");
let result = NuDataFrame::try_from_value(result)
.expect("should be able to create a dataframe from a value");
let result = result
.as_series(Span::test_data())
.expect("should be convert to a series");
assert_eq!(result, Series::new("gte_c_c", &[true, true]));
}
}