diff --git a/crates/nu-command/src/bytes/at.rs b/crates/nu-command/src/bytes/at.rs index a5b411de60..85f6dbe4c2 100644 --- a/crates/nu-command/src/bytes/at.rs +++ b/crates/nu-command/src/bytes/at.rs @@ -1,11 +1,9 @@ -use itertools::Itertools; use nu_cmd_base::{ input_handler::{operate, CmdArgument}, util, }; use nu_engine::command_prelude::*; use nu_protocol::Range; -use std::io::{Read, Write}; #[derive(Clone)] pub struct BytesAt; @@ -87,7 +85,15 @@ impl Command for BytesAt { }; if let PipelineData::ByteStream(stream, metadata) = input { - handle_byte_stream(&args, stream, call, metadata, engine_state) + match stream.slice( + call.head, + call.arguments_span(), + args.indexes.0, + args.indexes.1, + ) { + Ok(stream) => Ok(PipelineData::ByteStream(stream, metadata)), + Err(err) => Err(err), + } } else { operate(map_value, args, input, call.head, engine_state.signals()) } @@ -104,8 +110,8 @@ impl Command for BytesAt { }, Example { description: "Slice out `0x[10 01 13]` from `0x[33 44 55 10 01 13]`", - example: "0x[33 44 55 10 01 13 10] | bytes at 3..6", - result: Some(Value::test_binary(vec![0x10, 0x01, 0x13, 0x10])), + example: "0x[33 44 55 10 01 13] | bytes at 3..6", + result: Some(Value::test_binary(vec![0x10, 0x01, 0x13])), }, Example { description: "Extract bytes from the start up to a specific index", @@ -145,7 +151,7 @@ fn map_value(input: &Value, args: &Arguments, head: Span) -> Value { match input { Value::Binary { val, .. } => { let (start, end) = resolve_relative_range(range, val.len()); - let iter = val.iter().map(|x| *x); + let iter = val.iter().copied(); let bytes: Vec = if start > end { vec![] @@ -170,53 +176,6 @@ fn map_value(input: &Value, args: &Arguments, head: Span) -> Value { } } -fn handle_byte_stream( - args: &Arguments, - stream: ByteStream, - call: &Call, - metadata: Option, - engine_state: &EngineState, -) -> Result { - match stream.reader() { - Some(reader) => { - let iter = reader.bytes().filter_map(Result::ok); - let Subbytes { 0: start, 1: end } = args.indexes; - - let mut iter = if start < 0 || end < 0 { - match iter.try_len() { - Ok(len) => { - let (start, end) = resolve_relative_range(&args.indexes, len); - iter.skip(start).take(end - start + 1) - }, - Err(_) => return Err(ShellError::IncorrectValue { - msg: "Negative range values cannot be used with streams that don't specify a length".into(), - val_span: call.head, - call_span: call.arguments_span(), - }), - } - } else { - iter.skip(start as usize).take((end - start) as usize) - }; - - let stream = ByteStream::from_fn( - call.head, - engine_state.signals().clone(), - ByteStreamType::Binary, - move |buf| match iter.next() { - Some(n) if n > 0 => match buf.write(&[n]) { - Ok(_) => Ok(true), - Err(err) => Err(err.into()), - }, - _ => Ok(false), - }, - ); - - Ok(PipelineData::ByteStream(stream, metadata)) - } - None => Ok(PipelineData::empty()), - } -} - fn resolve_relative_range(range: &Subbytes, len: usize) -> (usize, usize) { let start = match range.0 { start if start < 0 => match len as isize + start { @@ -233,3 +192,14 @@ fn resolve_relative_range(range: &Subbytes, len: usize) -> (usize, usize) { (start, end) } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_examples() { + use crate::test_examples; + test_examples(BytesAt {}) + } +} diff --git a/crates/nu-command/src/filters/skip/skip_.rs b/crates/nu-command/src/filters/skip/skip_.rs index fade5a00ba..183fb9872a 100644 --- a/crates/nu-command/src/filters/skip/skip_.rs +++ b/crates/nu-command/src/filters/skip/skip_.rs @@ -1,6 +1,4 @@ use nu_engine::command_prelude::*; -use nu_protocol::Signals; -use std::io::{self, Read}; #[derive(Clone)] pub struct Skip; @@ -96,21 +94,9 @@ impl Command for Skip { PipelineData::ByteStream(stream, metadata) => { if stream.type_().is_binary_coercible() { let span = stream.span(); - if let Some(mut reader) = stream.reader() { - // Copy the number of skipped bytes into the sink before proceeding - io::copy(&mut (&mut reader).take(n as u64), &mut io::sink()) - .err_span(span)?; - Ok(PipelineData::ByteStream( - ByteStream::read( - reader, - call.head, - Signals::empty(), - ByteStreamType::Binary, - ), - metadata, - )) - } else { - Ok(PipelineData::Empty) + match stream.skip(span, n as u64) { + Ok(stream) => Ok(PipelineData::ByteStream(stream, metadata)), + Err(err) => Err(err), } } else { Err(ShellError::OnlySupportsThisInputType { diff --git a/crates/nu-command/src/filters/take/take_.rs b/crates/nu-command/src/filters/take/take_.rs index 85a94fcd4b..4f48a8aa5f 100644 --- a/crates/nu-command/src/filters/take/take_.rs +++ b/crates/nu-command/src/filters/take/take_.rs @@ -1,6 +1,5 @@ use nu_engine::command_prelude::*; use nu_protocol::Signals; -use std::io::Read; #[derive(Clone)] pub struct Take; @@ -89,19 +88,10 @@ impl Command for Take { )), PipelineData::ByteStream(stream, metadata) => { if stream.type_().is_binary_coercible() { - if let Some(reader) = stream.reader() { - // Just take 'rows' bytes off the stream, mimicking the binary behavior - Ok(PipelineData::ByteStream( - ByteStream::read( - reader.take(rows_desired as u64), - head, - Signals::empty(), - ByteStreamType::Binary, - ), - metadata, - )) - } else { - Ok(PipelineData::Empty) + let span = stream.span(); + match stream.take(span, rows_desired as u64) { + Ok(stream) => Ok(PipelineData::ByteStream(stream, metadata)), + Err(err) => Err(err), } } else { Err(ShellError::OnlySupportsThisInputType { diff --git a/crates/nu-command/tests/commands/bytes/at.rs b/crates/nu-command/tests/commands/bytes/at.rs index 8e107ce631..32514b36fe 100644 --- a/crates/nu-command/tests/commands/bytes/at.rs +++ b/crates/nu-command/tests/commands/bytes/at.rs @@ -1,21 +1,63 @@ use nu_test_support::nu; #[test] -fn returns_error_for_relative_range_on_infinite_stream() { +pub fn returns_error_for_relative_range_on_infinite_stream() { let actual = nu!("nu --testbin iecho 3 | bytes at ..-3"); assert!( actual.err.contains( - "Negative range values cannot be used with streams that don't specify a length" + "Relative range values cannot be used with streams that don't specify a length" ), "Expected error message for negative range with infinite stream" ); } #[test] -fn returns_bytes_for_fixed_range_on_infinite_stream() { +pub fn returns_bytes_for_fixed_range_on_infinite_stream() { let actual = nu!("nu --testbin iecho 3 | bytes at ..10 | decode"); assert_eq!( actual.out, "33333", "Expected bytes from index 1 to 10, but got different output" ); } + +#[test] +pub fn test_string_returns_correct_slice_for_simple_positive_slice() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at ..4 | decode"); + assert_eq!(actual.out, "Hello"); +} + +#[test] +pub fn test_string_returns_correct_slice_for_negative_start() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at 6..11 | decode"); + assert_eq!(actual.out, "World"); +} + +#[test] +pub fn test_string_returns_correct_slice_for_negative_end() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at ..-7 | decode"); + assert_eq!(actual.out, "Hello"); +} + +#[test] +pub fn test_string_returns_correct_slice_for_empty_slice() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at 5..<5 | decode"); + assert_eq!(actual.out, ""); +} + +#[test] +pub fn test_string_returns_correct_slice_for_out_of_bounds() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at 0..20 | decode"); + assert_eq!(actual.out, "Hello World"); +} + +#[test] +pub fn test_string_returns_correct_slice_for_invalid_range() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at 11..5 | decode"); + assert_eq!(actual.out, ""); +} + +#[test] +pub fn test_string_returns_correct_slice_for_max_end() { + let actual = nu!("\"Hello World\" | encode utf8 | bytes at 6..<11 | decode"); + assert_eq!(actual.out, "World"); +} diff --git a/crates/nu-protocol/src/pipeline/byte_stream.rs b/crates/nu-protocol/src/pipeline/byte_stream.rs index 2b7c93495e..a2e6b9fd2a 100644 --- a/crates/nu-protocol/src/pipeline/byte_stream.rs +++ b/crates/nu-protocol/src/pipeline/byte_stream.rs @@ -220,6 +220,81 @@ impl ByteStream { ) } + pub fn skip(self, span: Span, n: u64) -> Result { + if let Some(mut reader) = self.reader() { + // Copy the number of skipped bytes into the sink before proceeding + io::copy(&mut (&mut reader).take(n as u64), &mut io::sink()).err_span(span)?; + Ok(ByteStream::read( + reader, + span, + Signals::empty(), + ByteStreamType::Binary, + )) + } else { + return Err(ShellError::TypeMismatch { + err_message: "expected readable stream".into(), + span, + }); + } + } + + pub fn take(self, span: Span, n: u64) -> Result { + if let Some(reader) = self.reader() { + Ok(ByteStream::read( + reader.take(n), + span, + Signals::empty(), + ByteStreamType::Binary, + )) + } else { + return Err(ShellError::TypeMismatch { + err_message: "expected readable stream".into(), + span, + }); + } + } + + pub fn slice( + self, + val_span: Span, + call_span: Span, + start: isize, + end: isize, + ) -> Result { + match self.known_size { + Some(len) => { + let absolute_start = match start { + start if start < 0 => (len as isize + start).max(0) as usize, + start => start.min(len as isize) as usize, + }; + + self.skip(val_span, absolute_start as u64) + .and_then(|stream| { + let absolute_end = match end { + end if end < 0 => (len as isize + end).max(0) as usize, + end => end.min(len as isize) as usize, + }; + + if absolute_end < absolute_start { + stream.take(val_span, 0) + } else { + stream.take(val_span, (absolute_end - absolute_start) as u64) + } + }) + } + None if start < 0 || end < 0 => Err(ShellError::IncorrectValue { + msg: + "Negative range values cannot be used with streams that don't specify a length" + .into(), + val_span, + call_span, + }), + None => self + .skip(val_span, start as u64) + .and_then(|stream| stream.take(val_span, end as u64)), + } + } + /// Create a [`ByteStream`] from a string. The type of the stream is always `String`. pub fn read_string(string: String, span: Span, signals: Signals) -> Self { let len = string.len(); diff --git a/crates/nu-protocol/tests/mod.rs b/crates/nu-protocol/tests/mod.rs new file mode 100644 index 0000000000..eab2e1f414 --- /dev/null +++ b/crates/nu-protocol/tests/mod.rs @@ -0,0 +1 @@ +mod pipeline; diff --git a/crates/nu-protocol/tests/pipeline/byte_stream.rs b/crates/nu-protocol/tests/pipeline/byte_stream.rs new file mode 100644 index 0000000000..6889ead173 --- /dev/null +++ b/crates/nu-protocol/tests/pipeline/byte_stream.rs @@ -0,0 +1,78 @@ +use nu_protocol::{ByteStream, Signals, Span}; + +#[test] +pub fn test_simple_positive_slice() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), 0, 5) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, b"Hello"); +} + +#[test] +pub fn test_negative_start() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), -5, 11) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, b"World"); +} + +#[test] +pub fn test_negative_end() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), 0, -6) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, b"Hello"); +} + +#[test] +pub fn test_empty_slice() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), 5, 5) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, Vec::::new()); +} + +#[test] +pub fn test_out_of_bounds() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), 0, 20) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, b"Hello World"); +} + +#[test] +pub fn test_invalid_range() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), 11, 5) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, Vec::::new()); +} + +#[test] +pub fn test_max_end() { + let data = b"Hello World".to_vec(); + let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty()); + let sliced = stream + .slice(Span::test_data(), Span::test_data(), 6, isize::MAX) + .unwrap(); + let result = sliced.into_bytes().unwrap(); + assert_eq!(result, b"World"); +} diff --git a/crates/nu-protocol/tests/pipeline/mod.rs b/crates/nu-protocol/tests/pipeline/mod.rs new file mode 100644 index 0000000000..a3003e8031 --- /dev/null +++ b/crates/nu-protocol/tests/pipeline/mod.rs @@ -0,0 +1 @@ +mod byte_stream;