Abstracted away Skip/Take for ByteStream using existing filter code

This commit is contained in:
simon-curtis 2024-11-05 02:22:06 +00:00 committed by Simon Curtis
parent 2a934e75d6
commit 88a785a504
8 changed files with 230 additions and 87 deletions

View File

@ -1,11 +1,9 @@
use itertools::Itertools;
use nu_cmd_base::{
input_handler::{operate, CmdArgument},
util,
};
use nu_engine::command_prelude::*;
use nu_protocol::Range;
use std::io::{Read, Write};
#[derive(Clone)]
pub struct BytesAt;
@ -87,7 +85,15 @@ impl Command for BytesAt {
};
if let PipelineData::ByteStream(stream, metadata) = input {
handle_byte_stream(&args, stream, call, metadata, engine_state)
match stream.slice(
call.head,
call.arguments_span(),
args.indexes.0,
args.indexes.1,
) {
Ok(stream) => Ok(PipelineData::ByteStream(stream, metadata)),
Err(err) => Err(err),
}
} else {
operate(map_value, args, input, call.head, engine_state.signals())
}
@ -104,8 +110,8 @@ impl Command for BytesAt {
},
Example {
description: "Slice out `0x[10 01 13]` from `0x[33 44 55 10 01 13]`",
example: "0x[33 44 55 10 01 13 10] | bytes at 3..6",
result: Some(Value::test_binary(vec![0x10, 0x01, 0x13, 0x10])),
example: "0x[33 44 55 10 01 13] | bytes at 3..6",
result: Some(Value::test_binary(vec![0x10, 0x01, 0x13])),
},
Example {
description: "Extract bytes from the start up to a specific index",
@ -145,7 +151,7 @@ fn map_value(input: &Value, args: &Arguments, head: Span) -> Value {
match input {
Value::Binary { val, .. } => {
let (start, end) = resolve_relative_range(range, val.len());
let iter = val.iter().map(|x| *x);
let iter = val.iter().copied();
let bytes: Vec<u8> = if start > end {
vec![]
@ -170,53 +176,6 @@ fn map_value(input: &Value, args: &Arguments, head: Span) -> Value {
}
}
fn handle_byte_stream(
args: &Arguments,
stream: ByteStream,
call: &Call,
metadata: Option<nu_protocol::PipelineMetadata>,
engine_state: &EngineState,
) -> Result<PipelineData, ShellError> {
match stream.reader() {
Some(reader) => {
let iter = reader.bytes().filter_map(Result::ok);
let Subbytes { 0: start, 1: end } = args.indexes;
let mut iter = if start < 0 || end < 0 {
match iter.try_len() {
Ok(len) => {
let (start, end) = resolve_relative_range(&args.indexes, len);
iter.skip(start).take(end - start + 1)
},
Err(_) => return Err(ShellError::IncorrectValue {
msg: "Negative range values cannot be used with streams that don't specify a length".into(),
val_span: call.head,
call_span: call.arguments_span(),
}),
}
} else {
iter.skip(start as usize).take((end - start) as usize)
};
let stream = ByteStream::from_fn(
call.head,
engine_state.signals().clone(),
ByteStreamType::Binary,
move |buf| match iter.next() {
Some(n) if n > 0 => match buf.write(&[n]) {
Ok(_) => Ok(true),
Err(err) => Err(err.into()),
},
_ => Ok(false),
},
);
Ok(PipelineData::ByteStream(stream, metadata))
}
None => Ok(PipelineData::empty()),
}
}
fn resolve_relative_range(range: &Subbytes, len: usize) -> (usize, usize) {
let start = match range.0 {
start if start < 0 => match len as isize + start {
@ -233,3 +192,14 @@ fn resolve_relative_range(range: &Subbytes, len: usize) -> (usize, usize) {
(start, end)
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_examples() {
use crate::test_examples;
test_examples(BytesAt {})
}
}

View File

@ -1,6 +1,4 @@
use nu_engine::command_prelude::*;
use nu_protocol::Signals;
use std::io::{self, Read};
#[derive(Clone)]
pub struct Skip;
@ -96,21 +94,9 @@ impl Command for Skip {
PipelineData::ByteStream(stream, metadata) => {
if stream.type_().is_binary_coercible() {
let span = stream.span();
if let Some(mut reader) = stream.reader() {
// Copy the number of skipped bytes into the sink before proceeding
io::copy(&mut (&mut reader).take(n as u64), &mut io::sink())
.err_span(span)?;
Ok(PipelineData::ByteStream(
ByteStream::read(
reader,
call.head,
Signals::empty(),
ByteStreamType::Binary,
),
metadata,
))
} else {
Ok(PipelineData::Empty)
match stream.skip(span, n as u64) {
Ok(stream) => Ok(PipelineData::ByteStream(stream, metadata)),
Err(err) => Err(err),
}
} else {
Err(ShellError::OnlySupportsThisInputType {

View File

@ -1,6 +1,5 @@
use nu_engine::command_prelude::*;
use nu_protocol::Signals;
use std::io::Read;
#[derive(Clone)]
pub struct Take;
@ -89,19 +88,10 @@ impl Command for Take {
)),
PipelineData::ByteStream(stream, metadata) => {
if stream.type_().is_binary_coercible() {
if let Some(reader) = stream.reader() {
// Just take 'rows' bytes off the stream, mimicking the binary behavior
Ok(PipelineData::ByteStream(
ByteStream::read(
reader.take(rows_desired as u64),
head,
Signals::empty(),
ByteStreamType::Binary,
),
metadata,
))
} else {
Ok(PipelineData::Empty)
let span = stream.span();
match stream.take(span, rows_desired as u64) {
Ok(stream) => Ok(PipelineData::ByteStream(stream, metadata)),
Err(err) => Err(err),
}
} else {
Err(ShellError::OnlySupportsThisInputType {

View File

@ -1,21 +1,63 @@
use nu_test_support::nu;
#[test]
fn returns_error_for_relative_range_on_infinite_stream() {
pub fn returns_error_for_relative_range_on_infinite_stream() {
let actual = nu!("nu --testbin iecho 3 | bytes at ..-3");
assert!(
actual.err.contains(
"Negative range values cannot be used with streams that don't specify a length"
"Relative range values cannot be used with streams that don't specify a length"
),
"Expected error message for negative range with infinite stream"
);
}
#[test]
fn returns_bytes_for_fixed_range_on_infinite_stream() {
pub fn returns_bytes_for_fixed_range_on_infinite_stream() {
let actual = nu!("nu --testbin iecho 3 | bytes at ..10 | decode");
assert_eq!(
actual.out, "33333",
"Expected bytes from index 1 to 10, but got different output"
);
}
#[test]
pub fn test_string_returns_correct_slice_for_simple_positive_slice() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at ..4 | decode");
assert_eq!(actual.out, "Hello");
}
#[test]
pub fn test_string_returns_correct_slice_for_negative_start() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at 6..11 | decode");
assert_eq!(actual.out, "World");
}
#[test]
pub fn test_string_returns_correct_slice_for_negative_end() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at ..-7 | decode");
assert_eq!(actual.out, "Hello");
}
#[test]
pub fn test_string_returns_correct_slice_for_empty_slice() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at 5..<5 | decode");
assert_eq!(actual.out, "");
}
#[test]
pub fn test_string_returns_correct_slice_for_out_of_bounds() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at 0..20 | decode");
assert_eq!(actual.out, "Hello World");
}
#[test]
pub fn test_string_returns_correct_slice_for_invalid_range() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at 11..5 | decode");
assert_eq!(actual.out, "");
}
#[test]
pub fn test_string_returns_correct_slice_for_max_end() {
let actual = nu!("\"Hello World\" | encode utf8 | bytes at 6..<11 | decode");
assert_eq!(actual.out, "World");
}

View File

@ -220,6 +220,81 @@ impl ByteStream {
)
}
pub fn skip(self, span: Span, n: u64) -> Result<Self, ShellError> {
if let Some(mut reader) = self.reader() {
// Copy the number of skipped bytes into the sink before proceeding
io::copy(&mut (&mut reader).take(n as u64), &mut io::sink()).err_span(span)?;
Ok(ByteStream::read(
reader,
span,
Signals::empty(),
ByteStreamType::Binary,
))
} else {
return Err(ShellError::TypeMismatch {
err_message: "expected readable stream".into(),
span,
});
}
}
pub fn take(self, span: Span, n: u64) -> Result<Self, ShellError> {
if let Some(reader) = self.reader() {
Ok(ByteStream::read(
reader.take(n),
span,
Signals::empty(),
ByteStreamType::Binary,
))
} else {
return Err(ShellError::TypeMismatch {
err_message: "expected readable stream".into(),
span,
});
}
}
pub fn slice(
self,
val_span: Span,
call_span: Span,
start: isize,
end: isize,
) -> Result<Self, ShellError> {
match self.known_size {
Some(len) => {
let absolute_start = match start {
start if start < 0 => (len as isize + start).max(0) as usize,
start => start.min(len as isize) as usize,
};
self.skip(val_span, absolute_start as u64)
.and_then(|stream| {
let absolute_end = match end {
end if end < 0 => (len as isize + end).max(0) as usize,
end => end.min(len as isize) as usize,
};
if absolute_end < absolute_start {
stream.take(val_span, 0)
} else {
stream.take(val_span, (absolute_end - absolute_start) as u64)
}
})
}
None if start < 0 || end < 0 => Err(ShellError::IncorrectValue {
msg:
"Negative range values cannot be used with streams that don't specify a length"
.into(),
val_span,
call_span,
}),
None => self
.skip(val_span, start as u64)
.and_then(|stream| stream.take(val_span, end as u64)),
}
}
/// Create a [`ByteStream`] from a string. The type of the stream is always `String`.
pub fn read_string(string: String, span: Span, signals: Signals) -> Self {
let len = string.len();

View File

@ -0,0 +1 @@
mod pipeline;

View File

@ -0,0 +1,78 @@
use nu_protocol::{ByteStream, Signals, Span};
#[test]
pub fn test_simple_positive_slice() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), 0, 5)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, b"Hello");
}
#[test]
pub fn test_negative_start() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), -5, 11)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, b"World");
}
#[test]
pub fn test_negative_end() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), 0, -6)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, b"Hello");
}
#[test]
pub fn test_empty_slice() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), 5, 5)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, Vec::<u8>::new());
}
#[test]
pub fn test_out_of_bounds() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), 0, 20)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, b"Hello World");
}
#[test]
pub fn test_invalid_range() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), 11, 5)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, Vec::<u8>::new());
}
#[test]
pub fn test_max_end() {
let data = b"Hello World".to_vec();
let stream = ByteStream::read_binary(data, Span::test_data(), Signals::empty());
let sliced = stream
.slice(Span::test_data(), Span::test_data(), 6, isize::MAX)
.unwrap();
let result = sliced.into_bytes().unwrap();
assert_eq!(result, b"World");
}

View File

@ -0,0 +1 @@
mod byte_stream;