diff --git a/crates/nu-command/src/bytes/at.rs b/crates/nu-command/src/bytes/at.rs index b9eb76ed74..a5b411de60 100644 --- a/crates/nu-command/src/bytes/at.rs +++ b/crates/nu-command/src/bytes/at.rs @@ -4,8 +4,8 @@ use nu_cmd_base::{ util, }; use nu_engine::command_prelude::*; -use nu_protocol::{Range, Reader}; -use std::io::{Bytes, Read, Write}; +use nu_protocol::Range; +use std::io::{Read, Write}; #[derive(Clone)] pub struct BytesAt; @@ -72,7 +72,7 @@ impl Command for BytesAt { input: PipelineData, ) -> Result { let range: Range = call.req(engine_state, stack, 0)?; - let indexes = match util::process_range(&range) { + let indexes: Subbytes = match util::process_range(&range) { Ok(idxs) => idxs.into(), Err(processing_error) => { return Err(processing_error("could not perform subbytes", call.head)); @@ -89,7 +89,7 @@ impl Command for BytesAt { if let PipelineData::ByteStream(stream, metadata) = input { handle_byte_stream(&args, stream, call, metadata, engine_state) } else { - operate(action, args, input, call.head, engine_state.signals()) + operate(map_value, args, input, call.head, engine_state.signals()) } } @@ -104,7 +104,7 @@ impl Command for BytesAt { }, Example { description: "Slice out `0x[10 01 13]` from `0x[33 44 55 10 01 13]`", - example: "0x[33 44 55 10 01 13] | bytes at 3..6", + example: "0x[33 44 55 10 01 13 10] | bytes at 3..6", result: Some(Value::test_binary(vec![0x10, 0x01, 0x13, 0x10])), }, Example { @@ -140,10 +140,23 @@ impl Command for BytesAt { } } -fn action(input: &Value, args: &Arguments, head: Span) -> Value { +fn map_value(input: &Value, args: &Arguments, head: Span) -> Value { let range = &args.indexes; match input { - Value::Binary { val, .. } => read_bytes(val, range, head), + Value::Binary { val, .. } => { + let (start, end) = resolve_relative_range(range, val.len()); + let iter = val.iter().map(|x| *x); + + let bytes: Vec = if start > end { + vec![] + } else if end == usize::MAX { + iter.skip(start).collect() + } else { + iter.skip(start).take(end - start + 1).collect() + }; + + Value::binary(bytes, head) + } Value::Error { .. } => input.clone(), other => Value::error( ShellError::UnsupportedInput { @@ -164,76 +177,59 @@ fn handle_byte_stream( metadata: Option, engine_state: &EngineState, ) -> Result { - let idxs = args.indexes; match stream.reader() { Some(reader) => { - let iter = reader.bytes(); + let iter = reader.bytes().filter_map(Result::ok); + let Subbytes { 0: start, 1: end } = args.indexes; - if idxs.0 < 0 || idxs.1 < 0 { + let mut iter = if start < 0 || end < 0 { match iter.try_len() { - Ok(_) => { - let vec = iter.filter_map(Result::ok).collect::>(); - Ok(read_bytes(&vec, &idxs, call.head).into_pipeline_data_with_metadata(metadata)) - } - _ => Err(ShellError::IncorrectValue { - msg: - "Negative range values cannot be used with streams that don't specify a length" - .into(), + Ok(len) => { + let (start, end) = resolve_relative_range(&args.indexes, len); + iter.skip(start).take(end - start + 1) + }, + Err(_) => return Err(ShellError::IncorrectValue { + msg: "Negative range values cannot be used with streams that don't specify a length".into(), val_span: call.head, call_span: call.arguments_span(), }), } } else { - Ok(read_stream(iter, idxs, call, engine_state, metadata)) - } + iter.skip(start as usize).take((end - start) as usize) + }; + + let stream = ByteStream::from_fn( + call.head, + engine_state.signals().clone(), + ByteStreamType::Binary, + move |buf| match iter.next() { + Some(n) if n > 0 => match buf.write(&[n]) { + Ok(_) => Ok(true), + Err(err) => Err(err.into()), + }, + _ => Ok(false), + }, + ); + + Ok(PipelineData::ByteStream(stream, metadata)) } None => Ok(PipelineData::empty()), } } -fn read_bytes(val: &[u8], range: &Subbytes, head: Span) -> Value { - let len = val.len() as isize; - let start = if range.0 < 0 { range.0 + len } else { range.0 }; - let end = if range.1 < 0 { range.1 + len } else { range.1 }; - - if start > end { - Value::binary(vec![], head) - } else { - let val_iter = val.iter().skip(start as usize); - Value::binary( - if end == isize::MAX { - val_iter.copied().collect::>() - } else { - val_iter.take((end - start + 1) as usize).copied().collect() - }, - head, - ) - } -} - -fn read_stream( - iter: Bytes, - range: Subbytes, - call: &Call, - engine_state: &EngineState, - metadata: Option, -) -> PipelineData { - let start = range.0 as usize; - let end = (range.1 - range.0) as usize; - let mut iter = iter.skip(start).take(end); - - let stream = ByteStream::from_fn( - call.head, - engine_state.signals().clone(), - ByteStreamType::Binary, - move |buf| match iter.next() { - Some(Ok(n)) if n > 0 => match buf.write(&[n]) { - Ok(_) => Ok(true), - Err(err) => Err(err.into()), - }, - _ => Ok(false), +fn resolve_relative_range(range: &Subbytes, len: usize) -> (usize, usize) { + let start = match range.0 { + start if start < 0 => match len as isize + start { + start if start < 0 => 0, + start => start as usize, }, - ); + start => start as usize, + }; - PipelineData::ByteStream(stream, metadata) + let end = match range.1 { + end if end < 0 => (len as isize + end) as usize, + end => end as usize, + }; + + (start, end) }