Add and use new Signals struct (#13314)

# Description
This PR introduces a new `Signals` struct to replace our adhoc passing
around of `ctrlc: Option<Arc<AtomicBool>>`. Doing so has a few benefits:
- We can better enforce when/where resetting or triggering an interrupt
is allowed.
- Consolidates `nu_utils::ctrl_c::was_pressed` and other ad-hoc
re-implementations into a single place: `Signals::check`.
- This allows us to add other types of signals later if we want. E.g.,
exiting or suspension.
- Similarly, we can more easily change the underlying implementation if
we need to in the future.
- Places that used to have a `ctrlc` of `None` now use
`Signals::empty()`, so we can double check these usages for correctness
in the future.
This commit is contained in:
Ian Manske
2024-07-07 22:29:01 +00:00
committed by GitHub
parent c6b6b1b7a8
commit 399a7c8836
246 changed files with 1332 additions and 1234 deletions

View File

@ -2,13 +2,8 @@ use crate::database::values::sqlite::{open_sqlite_db, values_to_sql};
use nu_engine::command_prelude::*;
use itertools::Itertools;
use std::{
path::Path,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use nu_protocol::Signals;
use std::path::Path;
pub const DEFAULT_TABLE_NAME: &str = "main";
@ -188,23 +183,18 @@ fn operate(
let file_name: Spanned<String> = call.req(engine_state, stack, 0)?;
let table_name: Option<Spanned<String>> = call.get_flag(engine_state, stack, "table-name")?;
let table = Table::new(&file_name, table_name)?;
let ctrl_c = engine_state.ctrlc.clone();
match action(input, table, span, ctrl_c) {
Ok(val) => Ok(val.into_pipeline_data()),
Err(e) => Err(e),
}
Ok(action(input, table, span, engine_state.signals())?.into_pipeline_data())
}
fn action(
input: PipelineData,
table: Table,
span: Span,
ctrl_c: Option<Arc<AtomicBool>>,
signals: &Signals,
) -> Result<Value, ShellError> {
match input {
PipelineData::ListStream(stream, _) => {
insert_in_transaction(stream.into_iter(), span, table, ctrl_c)
insert_in_transaction(stream.into_iter(), span, table, signals)
}
PipelineData::Value(
Value::List {
@ -212,9 +202,9 @@ fn action(
internal_span,
},
_,
) => insert_in_transaction(vals.into_iter(), internal_span, table, ctrl_c),
) => insert_in_transaction(vals.into_iter(), internal_span, table, signals),
PipelineData::Value(val, _) => {
insert_in_transaction(std::iter::once(val), span, table, ctrl_c)
insert_in_transaction(std::iter::once(val), span, table, signals)
}
_ => Err(ShellError::OnlySupportsThisInputType {
exp_input_type: "list".into(),
@ -229,7 +219,7 @@ fn insert_in_transaction(
stream: impl Iterator<Item = Value>,
span: Span,
mut table: Table,
ctrl_c: Option<Arc<AtomicBool>>,
signals: &Signals,
) -> Result<Value, ShellError> {
let mut stream = stream.peekable();
let first_val = match stream.peek() {
@ -251,17 +241,15 @@ fn insert_in_transaction(
let tx = table.try_init(&first_val)?;
for stream_value in stream {
if let Some(ref ctrlc) = ctrl_c {
if ctrlc.load(Ordering::Relaxed) {
tx.rollback().map_err(|e| ShellError::GenericError {
error: "Failed to rollback SQLite transaction".into(),
msg: e.to_string(),
span: None,
help: None,
inner: Vec::new(),
})?;
return Err(ShellError::InterruptedByUser { span: None });
}
if let Err(err) = signals.check(span) {
tx.rollback().map_err(|e| ShellError::GenericError {
error: "Failed to rollback SQLite transaction".into(),
msg: e.to_string(),
span: None,
help: None,
inner: Vec::new(),
})?;
return Err(err);
}
let val = stream_value.as_record()?;

View File

@ -2,7 +2,7 @@ use super::definitions::{
db_column::DbColumn, db_constraint::DbConstraint, db_foreignkey::DbForeignKey,
db_index::DbIndex, db_table::DbTable,
};
use nu_protocol::{CustomValue, PipelineData, Record, ShellError, Span, Spanned, Value};
use nu_protocol::{CustomValue, PipelineData, Record, ShellError, Signals, Span, Spanned, Value};
use rusqlite::{
types::ValueRef, Connection, DatabaseName, Error as SqliteError, OpenFlags, Row, Statement,
ToSql,
@ -12,7 +12,6 @@ use std::{
fs::File,
io::Read,
path::{Path, PathBuf},
sync::{atomic::AtomicBool, Arc},
};
const SQLITE_MAGIC_BYTES: &[u8] = "SQLite format 3\0".as_bytes();
@ -24,25 +23,21 @@ pub struct SQLiteDatabase {
// 1) YAGNI, 2) it's not obvious how cloning a connection could work, 3) state
// management gets tricky quick. Revisit this approach if we find a compelling use case.
pub path: PathBuf,
#[serde(skip)]
#[serde(skip, default = "Signals::empty")]
// this understandably can't be serialized. think that's OK, I'm not aware of a
// reason why a CustomValue would be serialized outside of a plugin
ctrlc: Option<Arc<AtomicBool>>,
signals: Signals,
}
impl SQLiteDatabase {
pub fn new(path: &Path, ctrlc: Option<Arc<AtomicBool>>) -> Self {
pub fn new(path: &Path, signals: Signals) -> Self {
Self {
path: PathBuf::from(path),
ctrlc,
signals,
}
}
pub fn try_from_path(
path: &Path,
span: Span,
ctrlc: Option<Arc<AtomicBool>>,
) -> Result<Self, ShellError> {
pub fn try_from_path(path: &Path, span: Span, signals: Signals) -> Result<Self, ShellError> {
let mut file = File::open(path).map_err(|e| ShellError::ReadingFile {
msg: e.to_string(),
span,
@ -56,7 +51,7 @@ impl SQLiteDatabase {
})
.and_then(|_| {
if buf == SQLITE_MAGIC_BYTES {
Ok(SQLiteDatabase::new(path, ctrlc))
Ok(SQLiteDatabase::new(path, signals))
} else {
Err(ShellError::ReadingFile {
msg: "Not a SQLite file".into(),
@ -72,7 +67,7 @@ impl SQLiteDatabase {
Value::Custom { val, .. } => match val.as_any().downcast_ref::<Self>() {
Some(db) => Ok(Self {
path: db.path.clone(),
ctrlc: db.ctrlc.clone(),
signals: db.signals.clone(),
}),
None => Err(ShellError::CantConvert {
to_type: "database".into(),
@ -107,16 +102,8 @@ impl SQLiteDatabase {
call_span: Span,
) -> Result<Value, ShellError> {
let conn = open_sqlite_db(&self.path, call_span)?;
let stream = run_sql_query(conn, sql, params, self.ctrlc.clone()).map_err(|e| {
ShellError::GenericError {
error: "Failed to query SQLite database".into(),
msg: e.to_string(),
span: Some(sql.span),
help: None,
inner: vec![],
}
})?;
let stream = run_sql_query(conn, sql, params, &self.signals)
.map_err(|e| e.into_shell_error(sql.span, "Failed to query SQLite database"))?;
Ok(stream)
}
@ -352,12 +339,7 @@ impl SQLiteDatabase {
impl CustomValue for SQLiteDatabase {
fn clone_value(&self, span: Span) -> Value {
let cloned = SQLiteDatabase {
path: self.path.clone(),
ctrlc: self.ctrlc.clone(),
};
Value::custom(Box::new(cloned), span)
Value::custom(Box::new(self.clone()), span)
}
fn type_name(&self) -> String {
@ -366,13 +348,8 @@ impl CustomValue for SQLiteDatabase {
fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
let db = open_sqlite_db(&self.path, span)?;
read_entire_sqlite_db(db, span, self.ctrlc.clone()).map_err(|e| ShellError::GenericError {
error: "Failed to read from SQLite database".into(),
msg: e.to_string(),
span: Some(span),
help: None,
inner: vec![],
})
read_entire_sqlite_db(db, span, &self.signals)
.map_err(|e| e.into_shell_error(span, "Failed to read from SQLite database"))
}
fn as_any(&self) -> &dyn std::any::Any {
@ -396,20 +373,12 @@ impl CustomValue for SQLiteDatabase {
fn follow_path_string(
&self,
_self_span: Span,
_column_name: String,
column_name: String,
path_span: Span,
) -> Result<Value, ShellError> {
let db = open_sqlite_db(&self.path, path_span)?;
read_single_table(db, _column_name, path_span, self.ctrlc.clone()).map_err(|e| {
ShellError::GenericError {
error: "Failed to read from SQLite database".into(),
msg: e.to_string(),
span: Some(path_span),
help: None,
inner: vec![],
}
})
read_single_table(db, column_name, path_span, &self.signals)
.map_err(|e| e.into_shell_error(path_span, "Failed to read from SQLite database"))
}
fn typetag_name(&self) -> &'static str {
@ -426,12 +395,12 @@ pub fn open_sqlite_db(path: &Path, call_span: Span) -> Result<Connection, ShellE
open_connection_in_memory_custom()
} else {
let path = path.to_string_lossy().to_string();
Connection::open(path).map_err(|e| ShellError::GenericError {
Connection::open(path).map_err(|err| ShellError::GenericError {
error: "Failed to open SQLite database".into(),
msg: e.to_string(),
msg: err.to_string(),
span: Some(call_span),
help: None,
inner: vec![],
inner: Vec::new(),
})
}
}
@ -440,11 +409,10 @@ fn run_sql_query(
conn: Connection,
sql: &Spanned<String>,
params: NuSqlParams,
ctrlc: Option<Arc<AtomicBool>>,
) -> Result<Value, SqliteError> {
signals: &Signals,
) -> Result<Value, SqliteOrShellError> {
let stmt = conn.prepare(&sql.item)?;
prepared_statement_to_nu_list(stmt, params, sql.span, ctrlc)
prepared_statement_to_nu_list(stmt, params, sql.span, signals)
}
// This is taken from to text local_into_string but tweaks it a bit so that certain formatting does not happen
@ -534,23 +502,56 @@ pub fn nu_value_to_params(value: Value) -> Result<NuSqlParams, ShellError> {
}
}
#[derive(Debug)]
enum SqliteOrShellError {
SqliteError(SqliteError),
ShellError(ShellError),
}
impl From<SqliteError> for SqliteOrShellError {
fn from(error: SqliteError) -> Self {
Self::SqliteError(error)
}
}
impl From<ShellError> for SqliteOrShellError {
fn from(error: ShellError) -> Self {
Self::ShellError(error)
}
}
impl SqliteOrShellError {
fn into_shell_error(self, span: Span, msg: &str) -> ShellError {
match self {
Self::SqliteError(err) => ShellError::GenericError {
error: msg.into(),
msg: err.to_string(),
span: Some(span),
help: None,
inner: Vec::new(),
},
Self::ShellError(err) => err,
}
}
}
fn read_single_table(
conn: Connection,
table_name: String,
call_span: Span,
ctrlc: Option<Arc<AtomicBool>>,
) -> Result<Value, SqliteError> {
signals: &Signals,
) -> Result<Value, SqliteOrShellError> {
// TODO: Should use params here?
let stmt = conn.prepare(&format!("SELECT * FROM [{table_name}]"))?;
prepared_statement_to_nu_list(stmt, NuSqlParams::default(), call_span, ctrlc)
prepared_statement_to_nu_list(stmt, NuSqlParams::default(), call_span, signals)
}
fn prepared_statement_to_nu_list(
mut stmt: Statement,
params: NuSqlParams,
call_span: Span,
ctrlc: Option<Arc<AtomicBool>>,
) -> Result<Value, SqliteError> {
signals: &Signals,
) -> Result<Value, SqliteOrShellError> {
let column_names = stmt
.column_names()
.into_iter()
@ -576,11 +577,7 @@ fn prepared_statement_to_nu_list(
let mut row_values = vec![];
for row_result in row_results {
if nu_utils::ctrl_c::was_pressed(&ctrlc) {
// return whatever we have so far, let the caller decide whether to use it
return Ok(Value::list(row_values, call_span));
}
signals.check(call_span)?;
if let Ok(row_value) = row_result {
row_values.push(row_value);
}
@ -606,11 +603,7 @@ fn prepared_statement_to_nu_list(
let mut row_values = vec![];
for row_result in row_results {
if nu_utils::ctrl_c::was_pressed(&ctrlc) {
// return whatever we have so far, let the caller decide whether to use it
return Ok(Value::list(row_values, call_span));
}
signals.check(call_span)?;
if let Ok(row_value) = row_result {
row_values.push(row_value);
}
@ -626,8 +619,8 @@ fn prepared_statement_to_nu_list(
fn read_entire_sqlite_db(
conn: Connection,
call_span: Span,
ctrlc: Option<Arc<AtomicBool>>,
) -> Result<Value, SqliteError> {
signals: &Signals,
) -> Result<Value, SqliteOrShellError> {
let mut tables = Record::new();
let mut get_table_names =
@ -638,12 +631,8 @@ fn read_entire_sqlite_db(
let table_name: String = row?;
// TODO: Should use params here?
let table_stmt = conn.prepare(&format!("select * from [{table_name}]"))?;
let rows = prepared_statement_to_nu_list(
table_stmt,
NuSqlParams::default(),
call_span,
ctrlc.clone(),
)?;
let rows =
prepared_statement_to_nu_list(table_stmt, NuSqlParams::default(), call_span, signals)?;
tables.push(table_name, rows);
}
@ -710,7 +699,7 @@ mod test {
#[test]
fn can_read_empty_db() {
let db = open_connection_in_memory().unwrap();
let converted_db = read_entire_sqlite_db(db, Span::test_data(), None).unwrap();
let converted_db = read_entire_sqlite_db(db, Span::test_data(), &Signals::empty()).unwrap();
let expected = Value::test_record(Record::new());
@ -730,7 +719,7 @@ mod test {
[],
)
.unwrap();
let converted_db = read_entire_sqlite_db(db, Span::test_data(), None).unwrap();
let converted_db = read_entire_sqlite_db(db, Span::test_data(), &Signals::empty()).unwrap();
let expected = Value::test_record(record! {
"person" => Value::test_list(vec![]),
@ -759,7 +748,7 @@ mod test {
db.execute("INSERT INTO item (id, name) VALUES (456, 'foo bar')", [])
.unwrap();
let converted_db = read_entire_sqlite_db(db, span, None).unwrap();
let converted_db = read_entire_sqlite_db(db, span, &Signals::empty()).unwrap();
let expected = Value::test_record(record! {
"item" => Value::test_list(