feat: support storing, syncing and executing scripts (#2644)

* feat: add atuin-scripts crate

* initial

* define record types

* wip

* wip

* mvp

* add show command, make stdin work

* rewrite execution to use shebang and script file ALWAYS

* rename show -> get, allow fetching script only

* fmt

* clippy

* a bunch of fixes to the edits

* update lock

* variables

* fmt

* clippy

* pr feedback

* fmt
This commit is contained in:
Ellie Huxtable
2025-04-07 14:17:19 +01:00
committed by GitHub
parent 2935a5a6bd
commit f162d641a7
22 changed files with 1844 additions and 1 deletions

View File

@ -0,0 +1,33 @@
[package]
name = "atuin-scripts"
edition = "2024"
version = { workspace = true }
description = "The scripts crate for Atuin"
authors.workspace = true
rust-version.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
readme.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
atuin-client = { path = "../atuin-client", version = "18.5.0-beta.1" }
atuin-common = { path = "../atuin-common", version = "18.5.0-beta.1" }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
rmp = { version = "0.8.14" }
uuid = { workspace = true }
eyre = { workspace = true }
tokio = { workspace = true }
serde = { workspace = true }
typed-builder = { workspace = true }
pretty_assertions = { workspace = true }
sql-builder = { workspace = true }
sqlx = { workspace = true }
tempfile = { workspace = true }
minijinja = { workspace = true }
serde_json = { workspace = true }

View File

@ -0,0 +1,2 @@
DROP TABLE scripts;
DROP TABLE script_tags;

View File

@ -0,0 +1,17 @@
-- Add up migration script here
CREATE TABLE scripts (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT NOT NULL,
shebang TEXT NOT NULL,
script TEXT NOT NULL,
inserted_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
);
CREATE TABLE script_tags (
id INTEGER PRIMARY KEY,
script_id TEXT NOT NULL,
tag TEXT NOT NULL
);
CREATE UNIQUE INDEX idx_script_tags ON script_tags (script_id, tag);

View File

@ -0,0 +1,2 @@
-- Add down migration script here
alter table scripts drop index name_uniq_idx;

View File

@ -0,0 +1,2 @@
-- Add up migration script here
create unique index name_uniq_idx ON scripts(name);

View File

@ -0,0 +1,358 @@
use std::{path::Path, str::FromStr, time::Duration};
use atuin_common::utils;
use sqlx::{
Result, Row,
sqlite::{
SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
SqliteSynchronous,
},
};
use tokio::fs;
use tracing::debug;
use uuid::Uuid;
use crate::store::script::Script;
#[derive(Debug, Clone)]
pub struct Database {
pub pool: SqlitePool,
}
impl Database {
pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
let path = path.as_ref();
debug!("opening script sqlite database at {:?}", path);
if utils::broken_symlink(path) {
eprintln!(
"Atuin: Script sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
);
std::process::exit(1);
}
if !path.exists() {
if let Some(dir) = path.parent() {
fs::create_dir_all(dir).await?;
}
}
let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
.journal_mode(SqliteJournalMode::Wal)
.optimize_on_close(true, None)
.synchronous(SqliteSynchronous::Normal)
.with_regexp()
.foreign_keys(true)
.create_if_missing(true);
let pool = SqlitePoolOptions::new()
.acquire_timeout(Duration::from_secs_f64(timeout))
.connect_with(opts)
.await?;
Self::setup_db(&pool).await?;
Ok(Self { pool })
}
pub async fn sqlite_version(&self) -> Result<String> {
sqlx::query_scalar("SELECT sqlite_version()")
.fetch_one(&self.pool)
.await
}
async fn setup_db(pool: &SqlitePool) -> Result<()> {
debug!("running sqlite database setup");
sqlx::migrate!("./migrations").run(pool).await?;
Ok(())
}
async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, s: &Script) -> Result<()> {
sqlx::query(
"insert or ignore into scripts(id, name, description, shebang, script)
values(?1, ?2, ?3, ?4, ?5)",
)
.bind(s.id.to_string())
.bind(s.name.as_str())
.bind(s.description.as_str())
.bind(s.shebang.as_str())
.bind(s.script.as_str())
.execute(&mut **tx)
.await?;
for tag in s.tags.iter() {
sqlx::query(
"insert or ignore into script_tags(script_id, tag)
values(?1, ?2)",
)
.bind(s.id.to_string())
.bind(tag)
.execute(&mut **tx)
.await?;
}
Ok(())
}
pub async fn save(&self, s: &Script) -> Result<()> {
debug!("saving script to sqlite");
let mut tx = self.pool.begin().await?;
Self::save_raw(&mut tx, s).await?;
tx.commit().await?;
Ok(())
}
pub async fn save_bulk(&self, s: &[Script]) -> Result<()> {
debug!("saving scripts to sqlite");
let mut tx = self.pool.begin().await?;
for i in s {
Self::save_raw(&mut tx, i).await?;
}
tx.commit().await?;
Ok(())
}
fn query_script(row: SqliteRow) -> Script {
let id = row.get("id");
let name = row.get("name");
let description = row.get("description");
let shebang = row.get("shebang");
let script = row.get("script");
let id = Uuid::parse_str(id).unwrap();
Script {
id,
name,
description,
shebang,
script,
tags: vec![],
}
}
fn query_script_tags(row: SqliteRow) -> String {
row.get("tag")
}
#[allow(dead_code)]
async fn load(&self, id: &str) -> Result<Option<Script>> {
debug!("loading script item {}", id);
let res = sqlx::query("select * from scripts where id = ?1")
.bind(id)
.map(Self::query_script)
.fetch_optional(&self.pool)
.await?;
// intentionally not joining, don't want to duplicate the script data in memory a whole bunch.
if let Some(mut script) = res {
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
.bind(id)
.map(Self::query_script_tags)
.fetch_all(&self.pool)
.await?;
script.tags = tags;
Ok(Some(script))
} else {
Ok(None)
}
}
pub async fn list(&self) -> Result<Vec<Script>> {
debug!("listing scripts");
let mut res = sqlx::query("select * from scripts")
.map(Self::query_script)
.fetch_all(&self.pool)
.await?;
// Fetch all the tags for each script
for script in res.iter_mut() {
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
.bind(script.id.to_string())
.map(Self::query_script_tags)
.fetch_all(&self.pool)
.await?;
script.tags = tags;
}
Ok(res)
}
pub async fn delete(&self, id: &str) -> Result<()> {
debug!("deleting script {}", id);
sqlx::query("delete from scripts where id = ?1")
.bind(id)
.execute(&self.pool)
.await?;
// delete all the tags for the script
sqlx::query("delete from script_tags where script_id = ?1")
.bind(id)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn update(&self, s: &Script) -> Result<()> {
debug!("updating script {:?}", s);
let mut tx = self.pool.begin().await?;
// Update the script's base fields
sqlx::query("update scripts set name = ?1, description = ?2, shebang = ?3, script = ?4 where id = ?5")
.bind(s.name.as_str())
.bind(s.description.as_str())
.bind(s.shebang.as_str())
.bind(s.script.as_str())
.bind(s.id.to_string())
.execute(&mut *tx)
.await?;
// Delete all existing tags for this script
sqlx::query("delete from script_tags where script_id = ?1")
.bind(s.id.to_string())
.execute(&mut *tx)
.await?;
// Insert new tags
for tag in s.tags.iter() {
sqlx::query(
"insert or ignore into script_tags(script_id, tag)
values(?1, ?2)",
)
.bind(s.id.to_string())
.bind(tag)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
let res = sqlx::query("select * from scripts where name = ?1")
.bind(name)
.map(Self::query_script)
.fetch_optional(&self.pool)
.await?;
let script = if let Some(mut script) = res {
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
.bind(script.id.to_string())
.map(Self::query_script_tags)
.fetch_all(&self.pool)
.await?;
script.tags = tags;
Some(script)
} else {
None
};
Ok(script)
}
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn test_list() {
let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
let scripts = db.list().await.unwrap();
assert_eq!(scripts.len(), 0);
let script = Script::builder()
.name("test".to_string())
.description("test".to_string())
.shebang("test".to_string())
.script("test".to_string())
.build();
db.save(&script).await.unwrap();
let scripts = db.list().await.unwrap();
assert_eq!(scripts.len(), 1);
assert_eq!(scripts[0].name, "test");
}
#[tokio::test]
async fn test_save_load() {
let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
let script = Script::builder()
.name("test name".to_string())
.description("test description".to_string())
.shebang("test shebang".to_string())
.script("test script".to_string())
.build();
db.save(&script).await.unwrap();
let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();
assert_eq!(loaded, script);
}
#[tokio::test]
async fn test_save_bulk() {
let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
let scripts = vec![
Script::builder()
.name("test name".to_string())
.description("test description".to_string())
.shebang("test shebang".to_string())
.script("test script".to_string())
.build(),
Script::builder()
.name("test name 2".to_string())
.description("test description 2".to_string())
.shebang("test shebang 2".to_string())
.script("test script 2".to_string())
.build(),
];
db.save_bulk(&scripts).await.unwrap();
let loaded = db.list().await.unwrap();
assert_eq!(loaded.len(), 2);
assert_eq!(loaded[0].name, "test name");
assert_eq!(loaded[1].name, "test name 2");
}
#[tokio::test]
async fn test_delete() {
let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
let script = Script::builder()
.name("test name".to_string())
.description("test description".to_string())
.shebang("test shebang".to_string())
.script("test script".to_string())
.build();
db.save(&script).await.unwrap();
assert_eq!(db.list().await.unwrap().len(), 1);
db.delete(&script.id.to_string()).await.unwrap();
let loaded = db.list().await.unwrap();
assert_eq!(loaded.len(), 0);
}
}

View File

@ -0,0 +1,287 @@
use crate::store::script::Script;
use eyre::Result;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::process::Stdio;
use tempfile::NamedTempFile;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::sync::mpsc;
use tokio::task;
use tracing::debug;
// Helper function to build a complete script with shebang
pub fn build_executable_script(script: String, shebang: String) -> String {
if shebang.is_empty() {
// Default to bash if no shebang is provided
format!("#!/usr/bin/env bash\n{}", script)
} else if script.starts_with("#!") {
format!("{}\n{}", shebang, script)
} else {
format!("#!{}\n{}", shebang, script)
}
}
/// Represents the communication channels for an interactive script
pub struct ScriptSession {
/// Channel to send input to the script
pub stdin_tx: mpsc::Sender<String>,
/// Exit code of the process once it completes
pub exit_code_rx: mpsc::Receiver<i32>,
}
impl ScriptSession {
/// Send input to the running script
pub async fn send_input(&self, input: String) -> Result<(), mpsc::error::SendError<String>> {
self.stdin_tx.send(input).await
}
/// Wait for the script to complete and get the exit code
pub async fn wait_for_exit(&mut self) -> Option<i32> {
self.exit_code_rx.recv().await
}
}
fn setup_template(script: &Script) -> Result<minijinja::Environment> {
let mut env = minijinja::Environment::new();
env.set_trim_blocks(true);
env.add_template("script", script.script.as_str())?;
Ok(env)
}
/// Template a script with the given context
pub fn template_script(
script: &Script,
context: &HashMap<String, serde_json::Value>,
) -> Result<String> {
let env = setup_template(script)?;
let template = env.get_template("script")?;
let rendered = template.render(context)?;
Ok(rendered)
}
/// Get the variables that need to be templated in a script
pub fn template_variables(script: &Script) -> Result<HashSet<String>> {
let env = setup_template(script)?;
let template = env.get_template("script")?;
Ok(template.undeclared_variables(true))
}
/// Execute a script interactively, allowing for ongoing stdin/stdout interaction
pub async fn execute_script_interactive(
script: String,
shebang: String,
) -> Result<ScriptSession, Box<dyn std::error::Error + Send + Sync>> {
// Create a temporary file for the script
let temp_file = NamedTempFile::new()?;
let temp_path = temp_file.path().to_path_buf();
debug!("creating temp file at {}", temp_path.display());
// Extract interpreter from shebang for fallback execution
let interpreter = if !shebang.is_empty() {
shebang.trim_start_matches("#!").trim().to_string()
} else {
"/usr/bin/env bash".to_string()
};
// Write script content to the temp file, including the shebang
let full_script_content = build_executable_script(script.clone(), shebang.clone());
debug!("writing script content to temp file");
tokio::fs::write(&temp_path, &full_script_content).await?;
// Make it executable on Unix systems
#[cfg(unix)]
{
debug!("making script executable");
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(&temp_path)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&temp_path, perms)?;
}
// Store the temp_file to prevent it from being dropped
// This ensures it won't be deleted while the script is running
let _keep_temp_file = temp_file;
debug!("attempting direct script execution");
let mut child_result = tokio::process::Command::new(temp_path.to_str().unwrap())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn();
// If direct execution fails, try using the interpreter
if let Err(e) = &child_result {
debug!("direct execution failed: {}, trying with interpreter", e);
// When falling back to interpreter, remove the shebang from the file
// Some interpreters don't handle scripts with shebangs well
debug!("writing script content without shebang for interpreter execution");
tokio::fs::write(&temp_path, &script).await?;
// Parse the interpreter command
let parts: Vec<&str> = interpreter.split_whitespace().collect();
if !parts.is_empty() {
let mut cmd = tokio::process::Command::new(parts[0]);
// Add any interpreter args
for i in parts.iter().skip(1) {
cmd.arg(i);
}
// Add the script path
cmd.arg(temp_path.to_str().unwrap());
// Try with the interpreter
child_result = cmd
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn();
}
}
// If it still fails, return the error
let mut child = match child_result {
Ok(child) => child,
Err(e) => {
return Err(format!("Failed to execute script: {}", e).into());
}
};
// Get handles to stdin, stdout, stderr
let mut stdin = child
.stdin
.take()
.ok_or_else(|| "Failed to open child process stdin".to_string())?;
let stdout = child
.stdout
.take()
.ok_or_else(|| "Failed to open child process stdout".to_string())?;
let stderr = child
.stderr
.take()
.ok_or_else(|| "Failed to open child process stderr".to_string())?;
// Create channels for the interactive session
let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(32);
let (exit_code_tx, exit_code_rx) = mpsc::channel::<i32>(1);
// handle user stdin
debug!("spawning stdin handler");
tokio::spawn(async move {
while let Some(input) = stdin_rx.recv().await {
if let Err(e) = stdin.write_all(input.as_bytes()).await {
eprintln!("Error writing to stdin: {}", e);
break;
}
if let Err(e) = stdin.flush().await {
eprintln!("Error flushing stdin: {}", e);
break;
}
}
// when the channel closes (sender dropped), we let stdin close naturally
});
// handle stdout
debug!("spawning stdout handler");
let stdout_handle = task::spawn(async move {
let mut stdout_reader = BufReader::new(stdout);
let mut buffer = [0u8; 1024];
let mut stdout_writer = tokio::io::stdout();
loop {
match stdout_reader.read(&mut buffer).await {
Ok(0) => break, // End of stdout
Ok(n) => {
if let Err(e) = stdout_writer.write_all(&buffer[0..n]).await {
eprintln!("Error writing to stdout: {}", e);
break;
}
if let Err(e) = stdout_writer.flush().await {
eprintln!("Error flushing stdout: {}", e);
break;
}
}
Err(e) => {
eprintln!("Error reading from process stdout: {}", e);
break;
}
}
}
});
// Process stderr in a separate task
debug!("spawning stderr handler");
let stderr_handle = task::spawn(async move {
let mut stderr_reader = BufReader::new(stderr);
let mut buffer = [0u8; 1024];
let mut stderr_writer = tokio::io::stderr();
loop {
match stderr_reader.read(&mut buffer).await {
Ok(0) => break, // End of stderr
Ok(n) => {
if let Err(e) = stderr_writer.write_all(&buffer[0..n]).await {
eprintln!("Error writing to stderr: {}", e);
break;
}
if let Err(e) = stderr_writer.flush().await {
eprintln!("Error flushing stderr: {}", e);
break;
}
}
Err(e) => {
eprintln!("Error reading from process stderr: {}", e);
break;
}
}
}
});
// Spawn a task to wait for the child process to complete
debug!("spawning exit code handler");
let _keep_temp_file_clone = _keep_temp_file;
tokio::spawn(async move {
// Keep the temp file alive until the process completes
let _temp_file_ref = _keep_temp_file_clone;
// Wait for the child process to complete
let status = match child.wait().await {
Ok(status) => {
debug!("Process exited with status: {:?}", status);
status
}
Err(e) => {
eprintln!("Error waiting for child process: {}", e);
// Send a default error code
let _ = exit_code_tx.send(-1).await;
return;
}
};
// Wait for stdout/stderr tasks to complete
if let Err(e) = stdout_handle.await {
eprintln!("Error joining stdout task: {}", e);
}
if let Err(e) = stderr_handle.await {
eprintln!("Error joining stderr task: {}", e);
}
// Send the exit code
let exit_code = status.code().unwrap_or(-1);
debug!("Sending exit code: {}", exit_code);
let _ = exit_code_tx.send(exit_code).await;
});
// Return the communication channels as a ScriptSession
Ok(ScriptSession {
stdin_tx,
exit_code_rx,
})
}

View File

@ -0,0 +1,4 @@
pub mod database;
pub mod execution;
pub mod settings;
pub mod store;

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,109 @@
use eyre::{Result, bail};
use atuin_client::record::sqlite_store::SqliteStore;
use atuin_client::record::{encryption::PASETO_V4, store::Store};
use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx};
use record::ScriptRecord;
use script::{SCRIPT_TAG, SCRIPT_VERSION, Script};
use crate::database::Database;
pub mod record;
pub mod script;
#[derive(Debug, Clone)]
pub struct ScriptStore {
pub store: SqliteStore,
pub host_id: HostId,
pub encryption_key: [u8; 32],
}
impl ScriptStore {
pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self {
ScriptStore {
store,
host_id,
encryption_key,
}
}
async fn push_record(&self, record: ScriptRecord) -> Result<(RecordId, RecordIdx)> {
let bytes = record.serialize()?;
let idx = self
.store
.last(self.host_id, SCRIPT_TAG)
.await?
.map_or(0, |p| p.idx + 1);
let record = Record::builder()
.host(Host::new(self.host_id))
.version(SCRIPT_VERSION.to_string())
.tag(SCRIPT_TAG.to_string())
.idx(idx)
.data(bytes)
.build();
let id = record.id;
self.store
.push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
.await?;
Ok((id, idx))
}
pub async fn create(&self, script: Script) -> Result<()> {
let record = ScriptRecord::Create(script);
self.push_record(record).await?;
Ok(())
}
pub async fn update(&self, script: Script) -> Result<()> {
let record = ScriptRecord::Update(script);
self.push_record(record).await?;
Ok(())
}
pub async fn delete(&self, script_id: uuid::Uuid) -> Result<()> {
let record = ScriptRecord::Delete(script_id);
self.push_record(record).await?;
Ok(())
}
pub async fn scripts(&self) -> Result<Vec<ScriptRecord>> {
let records = self.store.all_tagged(SCRIPT_TAG).await?;
let mut ret = Vec::with_capacity(records.len());
for record in records.into_iter() {
let script = match record.version.as_str() {
SCRIPT_VERSION => {
let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?;
ScriptRecord::deserialize(&decrypted.data, SCRIPT_VERSION)
}
version => bail!("unknown history version {version:?}"),
}?;
ret.push(script);
}
Ok(ret)
}
pub async fn build(&self, database: Database) -> Result<()> {
// Get all the scripts from the database - they are already sorted by timestamp
let scripts = self.scripts().await?;
for script in scripts {
match script {
ScriptRecord::Create(script) => {
database.save(&script).await?;
}
ScriptRecord::Update(script) => database.update(&script).await?,
ScriptRecord::Delete(id) => database.delete(&id.to_string()).await?,
}
}
Ok(())
}
}

View File

@ -0,0 +1,215 @@
use atuin_common::record::DecryptedData;
use eyre::{Result, eyre};
use uuid::Uuid;
use crate::store::script::SCRIPT_VERSION;
use super::script::Script;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScriptRecord {
Create(Script),
Update(Script),
Delete(Uuid),
}
impl ScriptRecord {
pub fn serialize(&self) -> Result<DecryptedData> {
use rmp::encode;
let mut output = vec![];
match self {
ScriptRecord::Create(script) => {
// 0 -> a script create
encode::write_u8(&mut output, 0)?;
let bytes = script.serialize()?;
encode::write_bin(&mut output, &bytes.0)?;
}
ScriptRecord::Delete(id) => {
// 1 -> a script delete
encode::write_u8(&mut output, 1)?;
encode::write_str(&mut output, id.to_string().as_str())?;
}
ScriptRecord::Update(script) => {
// 2 -> a script update
encode::write_u8(&mut output, 2)?;
let bytes = script.serialize()?;
encode::write_bin(&mut output, &bytes.0)?;
}
};
Ok(DecryptedData(output))
}
pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
use rmp::decode;
fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
eyre!("{err:?}")
}
match version {
SCRIPT_VERSION => {
let mut bytes = decode::Bytes::new(&data.0);
let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
match record_type {
// create
0 => {
// written by encode::write_bin above
let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
let script = Script::deserialize(bytes.remaining_slice())?;
Ok(ScriptRecord::Create(script))
}
// delete
1 => {
let bytes = bytes.remaining_slice();
let (id, _) = decode::read_str_from_slice(bytes).map_err(error_report)?;
Ok(ScriptRecord::Delete(Uuid::parse_str(id)?))
}
// update
2 => {
// written by encode::write_bin above
let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
let script = Script::deserialize(bytes.remaining_slice())?;
Ok(ScriptRecord::Update(script))
}
_ => Err(eyre!("unknown script record type {record_type}")),
}
}
_ => Err(eyre!("unknown version {version:?}")),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_create() {
let script = Script::builder()
.id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap())
.name("test".to_string())
.description("test".to_string())
.shebang("test".to_string())
.tags(vec!["test".to_string()])
.script("test".to_string())
.build();
let record = ScriptRecord::Create(script);
let serialized = record.serialize().unwrap();
assert_eq!(
serialized.0,
vec![
204, 0, 196, 65, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102,
45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99,
98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115,
116, 145, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116
]
);
}
#[test]
fn test_serialize_delete() {
let record = ScriptRecord::Delete(
uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
);
let serialized = record.serialize().unwrap();
assert_eq!(
serialized.0,
vec![
204, 1, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57,
56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54
]
);
}
#[test]
fn test_serialize_update() {
let script = Script::builder()
.id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap())
.name(String::from("test"))
.description(String::from("test"))
.shebang(String::from("test"))
.tags(vec![String::from("test"), String::from("test2")])
.script(String::from("test"))
.build();
let record = ScriptRecord::Update(script);
let serialized = record.serialize().unwrap();
assert_eq!(
serialized.0,
vec![
204, 2, 196, 71, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102,
45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99,
98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115,
116, 146, 164, 116, 101, 115, 116, 165, 116, 101, 115, 116, 50, 164, 116, 101, 115,
116
],
);
}
#[test]
fn test_serialize_deserialize_create() {
let script = Script::builder()
.name("test".to_string())
.description("test".to_string())
.shebang("test".to_string())
.tags(vec!["test".to_string()])
.script("test".to_string())
.build();
let record = ScriptRecord::Create(script);
let serialized = record.serialize().unwrap();
let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
assert_eq!(record, deserialized);
}
#[test]
fn test_serialize_deserialize_delete() {
let record = ScriptRecord::Delete(
uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
);
let serialized = record.serialize().unwrap();
let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
assert_eq!(record, deserialized);
}
#[test]
fn test_serialize_deserialize_update() {
let script = Script::builder()
.name("test".to_string())
.description("test".to_string())
.shebang("test".to_string())
.tags(vec!["test".to_string()])
.script("test".to_string())
.build();
let record = ScriptRecord::Update(script);
let serialized = record.serialize().unwrap();
let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap();
assert_eq!(record, deserialized);
}
}

View File

@ -0,0 +1,151 @@
use atuin_common::record::DecryptedData;
use eyre::{Result, bail, ensure};
use uuid::Uuid;
use rmp::{
decode::{self, Bytes},
encode,
};
use typed_builder::TypedBuilder;
pub const SCRIPT_VERSION: &str = "v0";
pub const SCRIPT_TAG: &str = "script";
pub const SCRIPT_LEN: usize = 20000; // 20kb max total len
#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)]
/// A script is a set of commands that can be run, with the specified shebang
pub struct Script {
/// The id of the script
#[builder(default = uuid::Uuid::new_v4())]
pub id: Uuid,
/// The name of the script
pub name: String,
/// The description of the script
#[builder(default = String::new())]
pub description: String,
/// The interpreter of the script
#[builder(default = String::new())]
pub shebang: String,
/// The tags of the script
#[builder(default = Vec::new())]
pub tags: Vec<String>,
/// The script content
pub script: String,
}
impl Script {
pub fn serialize(&self) -> Result<DecryptedData> {
// sort the tags first, to ensure consistent ordering
let mut tags = self.tags.clone();
tags.sort();
let mut output = vec![];
encode::write_array_len(&mut output, 6)?;
encode::write_str(&mut output, &self.id.to_string())?;
encode::write_str(&mut output, &self.name)?;
encode::write_str(&mut output, &self.description)?;
encode::write_str(&mut output, &self.shebang)?;
encode::write_array_len(&mut output, self.tags.len() as u32)?;
for tag in &tags {
encode::write_str(&mut output, tag)?;
}
encode::write_str(&mut output, &self.script)?;
Ok(DecryptedData(output))
}
pub fn deserialize(bytes: &[u8]) -> Result<Self> {
let mut bytes = decode::Bytes::new(bytes);
let nfields = decode::read_array_len(&mut bytes).unwrap();
ensure!(nfields == 6, "too many entries in v0 script record");
let bytes = bytes.remaining_slice();
let (id, bytes) = decode::read_str_from_slice(bytes).unwrap();
let (name, bytes) = decode::read_str_from_slice(bytes).unwrap();
let (description, bytes) = decode::read_str_from_slice(bytes).unwrap();
let (shebang, bytes) = decode::read_str_from_slice(bytes).unwrap();
let mut bytes = Bytes::new(bytes);
let tags_len = decode::read_array_len(&mut bytes).unwrap();
let mut bytes = bytes.remaining_slice();
let mut tags = Vec::new();
for _ in 0..tags_len {
let (tag, remaining) = decode::read_str_from_slice(bytes).unwrap();
tags.push(tag.to_owned());
bytes = remaining;
}
let (script, bytes) = decode::read_str_from_slice(bytes).unwrap();
if !bytes.is_empty() {
bail!("trailing bytes in encoded script record. malformed")
}
Ok(Script {
id: Uuid::parse_str(id).unwrap(),
name: name.to_owned(),
description: description.to_owned(),
shebang: shebang.to_owned(),
tags,
script: script.to_owned(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize() {
let script = Script {
id: uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(),
name: "test".to_string(),
description: "test".to_string(),
shebang: "test".to_string(),
tags: vec!["test".to_string()],
script: "test".to_string(),
};
let serialized = script.serialize().unwrap();
assert_eq!(
serialized.0,
vec![
150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, 56,
50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54, 164,
116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 145, 164,
116, 101, 115, 116, 164, 116, 101, 115, 116
]
);
}
#[test]
fn test_serialize_deserialize() {
let script = Script {
id: uuid::Uuid::new_v4(),
name: "test".to_string(),
description: "test".to_string(),
shebang: "test".to_string(),
tags: vec!["test".to_string()],
script: "test".to_string(),
};
let serialized = script.serialize().unwrap();
let deserialized = Script::deserialize(&serialized.0).unwrap();
assert_eq!(script, deserialized);
}
}