add a threads parameter to par_each (#8679)

# Description

This PR lets you control the number of threads that `par-each` uses via
a `--threads` (`-t`) parameter. When `--threads` is not specified,
`par-each` uses the default thread count, which is the number of
available CPUs on your system.


![image](https://user-images.githubusercontent.com/343840/228935152-eca5b06b-4e8d-41be-82c4-ecd49cdf1fe1.png)

closes #4407

# User-Facing Changes

New `--threads` (`-t`) parameter for `par-each`.

# Tests + Formatting

Don't forget to add tests that cover your changes.

Make sure you've run and fixed any issues with these commands:

- `cargo fmt --all -- --check` to check standard code formatting (`cargo
fmt --all` applies these changes)
- `cargo clippy --workspace -- -D warnings -D clippy::unwrap_used -A
clippy::needless_collect` to check that you're using the standard code
style
- `cargo test --workspace` to check that all tests pass
- `cargo run -- crates/nu-utils/standard_library/tests.nu` to run the
tests for the standard library

> **Note**
> from `nushell` you can also use the `toolkit` as follows
> ```bash
> use toolkit.nu # or use an `env_change` hook to activate it automatically
> toolkit check pr
> ```

# After Submitting

If your PR had any user-facing changes, update [the
documentation](https://github.com/nushell/nushell.github.io) after the
PR is merged, if necessary. This will help us keep the docs up to date.
This commit is contained in:
Darren Schroeder 2023-03-30 16:39:40 -05:00 committed by GitHub
parent 0e496f900d
commit 09276db2a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -30,6 +30,12 @@ impl Command for ParEach {
), ),
(Type::Table(vec![]), Type::List(Box::new(Type::Any))), (Type::Table(vec![]), Type::List(Box::new(Type::Any))),
]) ])
.named(
"threads",
SyntaxShape::Int,
"the number of threads to use",
Some('t'),
)
.required( .required(
"closure", "closure",
SyntaxShape::Closure(Some(vec![SyntaxShape::Any, SyntaxShape::Int])), SyntaxShape::Closure(Some(vec![SyntaxShape::Any, SyntaxShape::Int])),
@ -85,8 +91,27 @@ impl Command for ParEach {
call: &Call, call: &Call,
input: PipelineData, input: PipelineData,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
let capture_block: Closure = call.req(engine_state, stack, 0)?; fn create_pool(num_threads: usize) -> Result<rayon::ThreadPool, ShellError> {
match rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
{
Err(e) => Err(e).map_err(|e| {
ShellError::GenericError(
"Error creating thread pool".into(),
e.to_string(),
Some(Span::unknown()),
None,
Vec::new(),
)
}),
Ok(pool) => Ok(pool),
}
}
let capture_block: Closure = call.req(engine_state, stack, 0)?;
let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
let max_threads = threads.unwrap_or(0);
let metadata = input.metadata(); let metadata = input.metadata();
let ctrlc = engine_state.ctrlc.clone(); let ctrlc = engine_state.ctrlc.clone();
let block_id = capture_block.block_id; let block_id = capture_block.block_id;
@ -96,156 +121,165 @@ impl Command for ParEach {
match input { match input {
PipelineData::Empty => Ok(PipelineData::Empty), PipelineData::Empty => Ok(PipelineData::Empty),
PipelineData::Value(Value::Range { val, .. }, ..) => Ok(val PipelineData::Value(Value::Range { val, .. }, ..) => Ok(create_pool(max_threads)?
.into_range_iter(ctrlc.clone())? .install(|| {
.par_bridge() val.into_range_iter(ctrlc.clone())
.map(move |x| { .expect("unable to create a range iterator")
let block = engine_state.get_block(block_id); .par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone(); let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) { if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id { if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone()); stack.add_var(*var_id, x.clone());
}
}
let val_span = x.span();
match eval_block_with_early_return(
engine_state,
&mut stack,
block,
x.into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)
})),
PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(create_pool(max_threads)?
.install(|| {
val.par_iter()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
}
}
let val_span = x.span();
match eval_block_with_early_return(
engine_state,
&mut stack,
block,
x.clone().into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)
})),
PipelineData::ListStream(stream, ..) => Ok(create_pool(max_threads)?.install(|| {
stream
.par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
}
} }
}
let val_span = x.span(); let val_span = x.span();
match eval_block_with_early_return( match eval_block_with_early_return(
engine_state, engine_state,
&mut stack, &mut stack,
block, block,
x.into_pipeline_data(), x.into_pipeline_data(),
redirect_stdout, redirect_stdout,
redirect_stderr, redirect_stderr,
) { ) {
Ok(v) => v, Ok(v) => v,
Err(error) => Value::Error { Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)), error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
} }
.into_pipeline_data(), })
} .collect::<Vec<_>>()
}) .into_iter()
.collect::<Vec<_>>() .flatten()
.into_iter() .into_pipeline_data(ctrlc)
.flatten() })),
.into_pipeline_data(ctrlc)),
PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(val
.into_iter()
.par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
}
}
let val_span = x.span();
match eval_block_with_early_return(
engine_state,
&mut stack,
block,
x.into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
PipelineData::ListStream(stream, ..) => Ok(stream
.par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
}
}
let val_span = x.span();
match eval_block_with_early_return(
engine_state,
&mut stack,
block,
x.into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
PipelineData::ExternalStream { stdout: None, .. } => Ok(PipelineData::empty()), PipelineData::ExternalStream { stdout: None, .. } => Ok(PipelineData::empty()),
PipelineData::ExternalStream { PipelineData::ExternalStream {
stdout: Some(stream), stdout: Some(stream),
.. ..
} => Ok(stream } => Ok(create_pool(max_threads)?.install(|| {
.par_bridge() stream
.map(move |x| { .par_bridge()
let x = match x { .map(move |x| {
Ok(x) => x, let x = match x {
Err(err) => { Ok(x) => x,
return Value::Error { Err(err) => {
error: Box::new(err), return Value::Error {
error: Box::new(err),
}
.into_pipeline_data()
}
};
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
} }
.into_pipeline_data()
} }
};
let block = engine_state.get_block(block_id); match eval_block_with_early_return(
engine_state,
let mut stack = stack.clone(); &mut stack,
block,
if let Some(var) = block.signature.get_positional(0) { x.into_pipeline_data(),
if let Some(var_id) = &var.var_id { redirect_stdout,
stack.add_var(*var_id, x.clone()); redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(error),
}
.into_pipeline_data(),
} }
} })
.collect::<Vec<_>>()
match eval_block_with_early_return( .into_iter()
engine_state, .flatten()
&mut stack, .into_pipeline_data(ctrlc)
block, })),
x.into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(error),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
// This match allows non-iterables to be accepted, // This match allows non-iterables to be accepted,
// which is currently considered undesirable (Nov 2022). // which is currently considered undesirable (Nov 2022).
PipelineData::Value(x, ..) => { PipelineData::Value(x, ..) => {
eprint!("value");
let block = engine_state.get_block(block_id); let block = engine_state.get_block(block_id);
if let Some(var) = block.signature.get_positional(0) { if let Some(var) = block.signature.get_positional(0) {