Make seq return a ListStream where possible (#7367)

# Description

Title.

# User-Facing Changes

Faster seq that works better with functions that take in `ListStream`s.

# Tests + Formatting

Don't forget to add tests that cover your changes.

Make sure you've run and fixed any issues with these commands:

- `cargo fmt --all -- --check` to check standard code formatting (`cargo
fmt --all` applies these changes)
- `cargo clippy --workspace -- -D warnings -D clippy::unwrap_used -A
clippy::needless_collect` to check that you're using the standard code
style
- `cargo test --workspace` to check that all tests pass

# After Submitting

If your PR had any user-facing changes, update [the
documentation](https://github.com/nushell/nushell.github.io) after the
PR is merged, if necessary. This will help us keep the docs up to date.
This commit is contained in:
pwygab 2022-12-07 10:48:03 +08:00 committed by GitHub
parent df66d9fcdf
commit 3395beaa56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 48 deletions

View File

@ -2,8 +2,8 @@ use nu_engine::CallExt;
use nu_protocol::{ use nu_protocol::{
ast::Call, ast::Call,
engine::{Command, EngineState, Stack}, engine::{Command, EngineState, Stack},
Category, Example, IntoPipelineData, PipelineData, ShellError, Signature, Span, Spanned, Category, Example, PipelineData, ShellError, Signature, Span, Spanned, SyntaxShape, Type,
SyntaxShape, Type, Value, Value,
}; };
#[derive(Clone)] #[derive(Clone)]
@ -93,6 +93,13 @@ fn seq(
let span = call.head; let span = call.head;
let rest_nums: Vec<Spanned<f64>> = call.rest(engine_state, stack, 0)?; let rest_nums: Vec<Spanned<f64>> = call.rest(engine_state, stack, 0)?;
// note that the check for int or float has to occur here. prior, the check would occur after
// everything had been generated; this does not work well with ListStreams.
// As such, the simple test is to check if this errors out: that means there is a float in the
// input, which necessarily means that parts of the output will be floats.
let rest_nums_check: Result<Vec<Spanned<i64>>, ShellError> = call.rest(engine_state, stack, 0);
let contains_decimals = rest_nums_check.is_err();
if rest_nums.is_empty() { if rest_nums.is_empty() {
return Err(ShellError::GenericError( return Err(ShellError::GenericError(
"seq requires some parameters".into(), "seq requires some parameters".into(),
@ -105,7 +112,7 @@ fn seq(
let rest_nums: Vec<f64> = rest_nums.iter().map(|n| n.item).collect(); let rest_nums: Vec<f64> = rest_nums.iter().map(|n| n.item).collect();
run_seq(rest_nums, span) run_seq(rest_nums, span, contains_decimals, engine_state)
} }
#[cfg(test)] #[cfg(test)]
@ -120,60 +127,92 @@ mod tests {
} }
} }
pub fn run_seq(free: Vec<f64>, span: Span) -> Result<PipelineData, ShellError> { pub fn run_seq(
free: Vec<f64>,
span: Span,
contains_decimals: bool,
engine_state: &EngineState,
) -> Result<PipelineData, ShellError> {
let first = free[0]; let first = free[0];
let step = if free.len() > 2 { free[1] } else { 1.0 };
let step: f64 = if free.len() > 2 { free[1] } else { 1.0 };
let last = { free[free.len() - 1] }; let last = { free[free.len() - 1] };
Ok(print_seq(first, step, last, span)) if !contains_decimals {
} // integers only
Ok(PipelineData::ListStream(
fn done_printing(next: f64, step: f64, last: f64) -> bool { nu_protocol::ListStream {
if step >= 0f64 { stream: Box::new(IntSeq {
next > last count: first as i64,
step: step as i64,
last: last as i64,
span,
}),
ctrlc: engine_state.ctrlc.clone(),
},
None,
))
} else { } else {
next < last // floats
Ok(PipelineData::ListStream(
nu_protocol::ListStream {
stream: Box::new(FloatSeq {
first,
step,
last,
index: 0,
span,
}),
ctrlc: engine_state.ctrlc.clone(),
},
None,
))
} }
} }
fn print_seq(first: f64, step: f64, last: f64, span: Span) -> PipelineData { struct FloatSeq {
let mut i = 0isize; first: f64,
let mut value = first + i as f64 * step; step: f64,
let mut ret_num = vec![]; last: f64,
index: isize,
while !done_printing(value, step, last) { span: Span,
ret_num.push(value);
i += 1;
value = first + i as f64 * step;
}
// we'd like to keep the datatype the same for the output, so check
// and see if any of the output contains values after the decimal point,
// and if so we'll make the entire output floats
let contains_decimals = vec_contains_decimals(&ret_num);
let rows: Vec<Value> = ret_num
.iter()
.map(|v| {
if contains_decimals {
Value::float(*v, span)
} else {
Value::int(*v as i64, span)
}
})
.collect();
Value::List { vals: rows, span }.into_pipeline_data()
} }
fn vec_contains_decimals(array: &[f64]) -> bool { impl Iterator for FloatSeq {
let mut found_decimal = false; type Item = Value;
for x in array { fn next(&mut self) -> Option<Value> {
if x.fract() != 0.0 { let count = self.first + self.index as f64 * self.step;
found_decimal = true; // Accuracy guaranteed as far as possible; each time, the value is re-evaluated from the
break; // base arguments
if (count > self.last && self.step >= 0.0) || (count < self.last && self.step <= 0.0) {
return None;
} }
self.index += 1;
Some(Value::Float {
val: count,
span: self.span,
})
}
}
struct IntSeq {
count: i64,
step: i64,
last: i64,
span: Span,
}
impl Iterator for IntSeq {
type Item = Value;
fn next(&mut self) -> Option<Value> {
if (self.count > self.last && self.step >= 0) || (self.count < self.last && self.step <= 0)
{
return None;
}
let ret = Some(Value::Int {
val: self.count,
span: self.span,
});
self.count += self.step;
ret
} }
found_decimal
} }

View File

@ -72,6 +72,7 @@ mod run_external;
mod save; mod save;
mod select; mod select;
mod semicolon; mod semicolon;
mod seq;
mod seq_char; mod seq_char;
mod shells; mod shells;
mod skip; mod skip;

View File

@ -0,0 +1,25 @@
use nu_test_support::{nu, pipeline};
#[test]
fn float_in_seq_leads_to_lists_of_floats() {
let actual = nu!(
cwd: "tests/fixtures/formats", pipeline(
r#"
seq 1.0 0.5 6 | describe
"#
));
assert_eq!(actual.out, "list<float>");
}
#[test]
fn ints_in_seq_leads_to_lists_of_ints() {
let actual = nu!(
cwd: "tests/fixtures/formats", pipeline(
r#"
seq 1 2 6 | describe
"#
));
assert_eq!(actual.out, "list<int>");
}