refactor(lsp): flat_map with mutable accumulator (#15567)

# Description

Mainly performance improvement of lsp operations involving flat_map on
AST nodes.
Previous flat_map traversing is functional, which is a nice property to
have, but the heavy cost of vector collection on each tree node makes it
undesirable.

This PR mitigates the problem with a mutable accumulator.

# User-Facing Changes

Should be none.

# Tests + Formatting

# After Submitting
This commit is contained in:
zc he 2025-04-15 20:21:23 +08:00 committed by GitHub
parent 8c4d3eaa7e
commit e5f589ccdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 188 additions and 194 deletions

View File

@ -46,6 +46,10 @@ fn command_name_span_from_call_head(
head_span: Span, head_span: Span,
) -> Span { ) -> Span {
let name = working_set.get_decl(decl_id).name(); let name = working_set.get_decl(decl_id).name();
// shortcut for most cases
if name.len() == head_span.end.saturating_sub(head_span.start) {
return head_span;
}
let head_content = working_set.get_span_contents(head_span); let head_content = working_set.get_span_contents(head_span);
let mut head_words = head_content.split(|c| *c == b' ').collect::<Vec<_>>(); let mut head_words = head_content.split(|c| *c == b' ').collect::<Vec<_>>();
let mut name_words = name.split(' ').collect::<Vec<_>>(); let mut name_words = name.split(' ').collect::<Vec<_>>();
@ -500,38 +504,34 @@ fn find_reference_by_id_in_expr(
expr: &Expression, expr: &Expression,
working_set: &StateWorkingSet, working_set: &StateWorkingSet,
id: &Id, id: &Id,
) -> Option<Vec<Span>> { ) -> Vec<Span> {
let closure = |e| find_reference_by_id_in_expr(e, working_set, id);
match (&expr.expr, id) { match (&expr.expr, id) {
(Expr::Var(vid1), Id::Variable(vid2, _)) if *vid1 == *vid2 => Some(vec![Span::new( (Expr::Var(vid1), Id::Variable(vid2, _)) if *vid1 == *vid2 => vec![Span::new(
// we want to exclude the `$` sign for renaming // we want to exclude the `$` sign for renaming
expr.span.start.saturating_add(1), expr.span.start.saturating_add(1),
expr.span.end, expr.span.end,
)]), )],
(Expr::VarDecl(vid1), Id::Variable(vid2, _)) if *vid1 == *vid2 => Some(vec![expr.span]), (Expr::VarDecl(vid1), Id::Variable(vid2, _)) if *vid1 == *vid2 => vec![expr.span],
// also interested in `var_id` in call.arguments of `use` command // also interested in `var_id` in call.arguments of `use` command
// and `module_id` in `module` command // and `module_id` in `module` command
(Expr::Call(call), _) => { (Expr::Call(call), _) => {
let mut occurs: Vec<Span> = call
.arguments
.iter()
.filter_map(|arg| arg.expr())
.flat_map(|e| e.flat_map(working_set, &closure))
.collect();
if matches!(id, Id::Declaration(decl_id) if call.decl_id == *decl_id) { if matches!(id, Id::Declaration(decl_id) if call.decl_id == *decl_id) {
occurs.push(command_name_span_from_call_head( vec![command_name_span_from_call_head(
working_set, working_set,
call.decl_id, call.decl_id,
call.head, call.head,
)); )]
return Some(occurs);
} }
if let Some((_, span_found)) = try_find_id_in_misc(call, working_set, None, Some(id)) { // Check for misc matches (use, module, etc.)
occurs.push(span_found); else if let Some((_, span_found)) =
try_find_id_in_misc(call, working_set, None, Some(id))
{
vec![span_found]
} else {
vec![]
} }
Some(occurs)
} }
_ => None, _ => vec![],
} }
} }
@ -540,7 +540,8 @@ pub(crate) fn find_reference_by_id(
working_set: &StateWorkingSet, working_set: &StateWorkingSet,
id: &Id, id: &Id,
) -> Vec<Span> { ) -> Vec<Span> {
ast.flat_map(working_set, &|e| { let mut results = Vec::new();
find_reference_by_id_in_expr(e, working_set, id) let closure = |e| find_reference_by_id_in_expr(e, working_set, id);
}) ast.flat_map(working_set, &closure, &mut results);
results
} }

View File

@ -26,14 +26,9 @@ fn extract_inlay_hints_from_expression(
working_set: &StateWorkingSet, working_set: &StateWorkingSet,
offset: &usize, offset: &usize,
file: &FullTextDocument, file: &FullTextDocument,
) -> Option<Vec<InlayHint>> { ) -> Vec<InlayHint> {
let closure = |e| extract_inlay_hints_from_expression(e, working_set, offset, file);
match &expr.expr { match &expr.expr {
Expr::BinaryOp(lhs, op, rhs) => { Expr::BinaryOp(lhs, op, rhs) => {
let mut hints: Vec<InlayHint> = [lhs, op, rhs]
.into_iter()
.flat_map(|e| e.flat_map(working_set, &closure))
.collect();
if let Expr::Operator(Operator::Assignment(_)) = op.expr { if let Expr::Operator(Operator::Assignment(_)) = op.expr {
let position = span_to_range(&lhs.span, file, *offset).end; let position = span_to_range(&lhs.span, file, *offset).end;
let type_rhs = type_short_name(&rhs.ty); let type_rhs = type_short_name(&rhs.ty);
@ -43,7 +38,7 @@ fn extract_inlay_hints_from_expression(
(_, "any") => type_lhs, (_, "any") => type_lhs,
_ => type_lhs, _ => type_lhs,
}; };
hints.push(InlayHint { vec![InlayHint {
kind: Some(InlayHintKind::TYPE), kind: Some(InlayHintKind::TYPE),
label: InlayHintLabel::String(format!(": {}", type_string)), label: InlayHintLabel::String(format!(": {}", type_string)),
position, position,
@ -52,9 +47,10 @@ fn extract_inlay_hints_from_expression(
data: None, data: None,
padding_left: None, padding_left: None,
padding_right: None, padding_right: None,
}) }]
} else {
vec![]
} }
Some(hints)
} }
Expr::VarDecl(var_id) => { Expr::VarDecl(var_id) => {
let position = span_to_range(&expr.span, file, *offset).end; let position = span_to_range(&expr.span, file, *offset).end;
@ -69,27 +65,30 @@ fn extract_inlay_hints_from_expression(
})) }))
.contains(':') .contains(':')
{ {
return Some(Vec::new()); return vec![];
} }
let var = working_set.get_variable(*var_id); let var = working_set.get_variable(*var_id);
let type_string = type_short_name(&var.ty); let type_string = type_short_name(&var.ty);
Some(vec![ vec![InlayHint {
(InlayHint { kind: Some(InlayHintKind::TYPE),
kind: Some(InlayHintKind::TYPE), label: InlayHintLabel::String(format!(": {}", type_string)),
label: InlayHintLabel::String(format!(": {}", type_string)), position,
position, text_edits: None,
text_edits: None, tooltip: None,
tooltip: None, data: None,
data: None, padding_left: None,
padding_left: None, padding_right: None,
padding_right: None, }]
}),
])
} }
Expr::Call(call) => { Expr::Call(call) => {
let decl = working_set.get_decl(call.decl_id); let decl = working_set.get_decl(call.decl_id);
// skip those defined outside of the project // skip those defined outside of the project
working_set.get_block(decl.block_id()?).span?; let Some(block_id) = decl.block_id() else {
return vec![];
};
if working_set.get_block(block_id).span.is_none() {
return vec![];
};
let signatures = decl.signature(); let signatures = decl.signature();
let signatures = [ let signatures = [
signatures.required_positional, signatures.required_positional,
@ -102,17 +101,11 @@ fn extract_inlay_hints_from_expression(
for arg in arguments { for arg in arguments {
match arg { match arg {
// skip the rest when spread/unknown arguments encountered // skip the rest when spread/unknown arguments encountered
Argument::Spread(expr) | Argument::Unknown(expr) => { Argument::Spread(_) | Argument::Unknown(_) => {
hints.extend(expr.flat_map(working_set, &closure));
sig_idx = signatures.len(); sig_idx = signatures.len();
continue; continue;
} }
// skip current for flags Argument::Positional(_) => {
Argument::Named((_, _, Some(expr))) => {
hints.extend(expr.flat_map(working_set, &closure));
continue;
}
Argument::Positional(expr) => {
if let Some(sig) = signatures.get(sig_idx) { if let Some(sig) = signatures.get(sig_idx) {
sig_idx += 1; sig_idx += 1;
let position = span_to_range(&arg.span(), file, *offset).start; let position = span_to_range(&arg.span(), file, *offset).start;
@ -130,16 +123,16 @@ fn extract_inlay_hints_from_expression(
padding_right: None, padding_right: None,
}); });
} }
hints.extend(expr.flat_map(working_set, &closure));
} }
// skip current for flags
_ => { _ => {
continue; continue;
} }
} }
} }
Some(hints) hints
} }
_ => None, _ => vec![],
} }
} }
@ -154,9 +147,10 @@ impl LanguageServer {
offset: usize, offset: usize,
file: &FullTextDocument, file: &FullTextDocument,
) -> Vec<InlayHint> { ) -> Vec<InlayHint> {
block.flat_map(working_set, &|e| { let closure = |e| extract_inlay_hints_from_expression(e, working_set, &offset, file);
extract_inlay_hints_from_expression(e, working_set, &offset, file) let mut results = Vec::new();
}) block.flat_map(working_set, &closure, &mut results);
results
} }
} }

View File

@ -19,32 +19,21 @@ use crate::{span_to_range, LanguageServer};
fn extract_semantic_tokens_from_expression( fn extract_semantic_tokens_from_expression(
expr: &Expression, expr: &Expression,
working_set: &StateWorkingSet, working_set: &StateWorkingSet,
) -> Option<Vec<Span>> { ) -> Vec<Span> {
let closure = |e| extract_semantic_tokens_from_expression(e, working_set);
match &expr.expr { match &expr.expr {
Expr::Call(call) => { Expr::Call(call) => {
let command_name_bytes = working_set.get_span_contents(call.head); let command_name = working_set.get_decl(call.decl_id).name();
let head_span = if command_name_bytes.contains(&b' ') if command_name.contains(' ')
// Some keywords that are already highlighted properly, e.g. by tree-sitter-nu // Some keywords that are already highlighted properly, e.g. by tree-sitter-nu
&& !command_name_bytes.starts_with(b"export") && !command_name.starts_with("export")
&& !command_name_bytes.starts_with(b"overlay") && !command_name.starts_with("overlay")
{ {
vec![call.head] vec![call.head]
} else { } else {
vec![] vec![]
}; }
let spans = head_span
.into_iter()
.chain(
call.arguments
.iter()
.filter_map(|arg| arg.expr())
.flat_map(|e| e.flat_map(working_set, &closure)),
)
.collect();
Some(spans)
} }
_ => None, _ => vec![],
} }
} }
@ -67,14 +56,14 @@ impl LanguageServer {
offset: usize, offset: usize,
file: &FullTextDocument, file: &FullTextDocument,
) -> Vec<SemanticToken> { ) -> Vec<SemanticToken> {
let spans = block.flat_map(working_set, &|e| { let mut results = Vec::new();
extract_semantic_tokens_from_expression(e, working_set) let closure = |e| extract_semantic_tokens_from_expression(e, working_set);
}); block.flat_map(working_set, &closure, &mut results);
let mut last_token_line = 0; let mut last_token_line = 0;
let mut last_token_char = 0; let mut last_token_char = 0;
let mut last_span = Span::unknown(); let mut last_span = Span::unknown();
let mut tokens = vec![]; let mut tokens = vec![];
for sp in spans { for sp in results {
let range = span_to_range(&sp, file, offset); let range = span_to_range(&sp, file, offset);
// shouldn't happen // shouldn't happen
if sp < last_span { if sp < last_span {

View File

@ -15,17 +15,19 @@ pub enum FindMapResult<T> {
/// Trait for traversing the AST /// Trait for traversing the AST
pub trait Traverse { pub trait Traverse {
/// Generic function that do flat_map on an AST node /// Generic function that do flat_map on an AST node.
/// concatenates all recursive results on sub-expressions /// Concatenates all recursive results on sub-expressions
/// into the `results` accumulator.
/// ///
/// # Arguments /// # Arguments
/// * `f` - function that overrides the default behavior /// * `f` - function that generates leaf elements
fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T> /// * `results` - accumulator
fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F, results: &mut Vec<T>)
where where
F: Fn(&'a Expression) -> Option<Vec<T>>; F: Fn(&'a Expression) -> Vec<T>;
/// Generic function that do find_map on an AST node /// Generic function that do find_map on an AST node.
/// return the first Some /// Return the first result found by applying `f` on sub-expressions.
/// ///
/// # Arguments /// # Arguments
/// * `f` - function that overrides the default behavior /// * `f` - function that overrides the default behavior
@ -35,24 +37,18 @@ pub trait Traverse {
} }
impl Traverse for Block { impl Traverse for Block {
fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T> fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F, results: &mut Vec<T>)
where where
F: Fn(&'a Expression) -> Option<Vec<T>>, F: Fn(&'a Expression) -> Vec<T>,
{ {
self.pipelines for pipeline in self.pipelines.iter() {
.iter() for element in pipeline.elements.iter() {
.flat_map(|pipeline| { element.expr.flat_map(working_set, f, results);
pipeline.elements.iter().flat_map(|element| { if let Some(redir) = &element.redirection {
element.expr.flat_map(working_set, f).into_iter().chain( redir.flat_map(working_set, f, results);
element };
.redirection }
.as_ref() }
.map(|redir| redir.flat_map(working_set, f))
.unwrap_or_default(),
)
})
})
.collect()
} }
fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T> fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
@ -71,21 +67,19 @@ impl Traverse for Block {
} }
impl Traverse for PipelineRedirection { impl Traverse for PipelineRedirection {
fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T> fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F, results: &mut Vec<T>)
where where
F: Fn(&'a Expression) -> Option<Vec<T>>, F: Fn(&'a Expression) -> Vec<T>,
{ {
let recur = |expr: &'a Expression| expr.flat_map(working_set, f); let mut recur = |expr: &'a Expression| expr.flat_map(working_set, f, results);
match self { match self {
PipelineRedirection::Single { target, .. } => { PipelineRedirection::Single { target, .. } => target.expr().map(recur),
target.expr().map(recur).unwrap_or_default() PipelineRedirection::Separate { out, err } => {
out.expr().map(&mut recur);
err.expr().map(&mut recur)
} }
PipelineRedirection::Separate { out, err } => [out, err] };
.iter()
.filter_map(|t| t.expr())
.flat_map(recur)
.collect(),
}
} }
fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T> fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
@ -94,9 +88,7 @@ impl Traverse for PipelineRedirection {
{ {
let recur = |expr: &'a Expression| expr.find_map(working_set, f); let recur = |expr: &'a Expression| expr.find_map(working_set, f);
match self { match self {
PipelineRedirection::Single { target, .. } => { PipelineRedirection::Single { target, .. } => target.expr().and_then(recur),
target.expr().map(recur).unwrap_or_default()
}
PipelineRedirection::Separate { out, err } => { PipelineRedirection::Separate { out, err } => {
[out, err].iter().filter_map(|t| t.expr()).find_map(recur) [out, err].iter().filter_map(|t| t.expr()).find_map(recur)
} }
@ -105,87 +97,97 @@ impl Traverse for PipelineRedirection {
} }
impl Traverse for Expression { impl Traverse for Expression {
fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T> fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F, results: &mut Vec<T>)
where where
F: Fn(&'a Expression) -> Option<Vec<T>>, F: Fn(&'a Expression) -> Vec<T>,
{ {
// behavior overridden by f // leaf elements generated by `f` for this expression
if let Some(vec) = f(self) { results.extend(f(self));
return vec; let mut recur = |expr: &'a Expression| expr.flat_map(working_set, f, results);
}
let recur = |expr: &'a Expression| expr.flat_map(working_set, f);
match &self.expr { match &self.expr {
Expr::RowCondition(block_id) Expr::RowCondition(block_id)
| Expr::Subexpression(block_id) | Expr::Subexpression(block_id)
| Expr::Block(block_id) | Expr::Block(block_id)
| Expr::Closure(block_id) => { | Expr::Closure(block_id) => {
let block = working_set.get_block(block_id.to_owned()); let block = working_set.get_block(block_id.to_owned());
block.flat_map(working_set, f) block.flat_map(working_set, f, results)
}
Expr::Range(range) => {
for sub_expr in [&range.from, &range.next, &range.to].into_iter().flatten() {
recur(sub_expr);
}
}
Expr::Call(call) => {
for arg in &call.arguments {
if let Some(sub_expr) = arg.expr() {
recur(sub_expr);
}
}
}
Expr::ExternalCall(head, args) => {
recur(head.as_ref());
for arg in args {
recur(arg.expr());
}
} }
Expr::Range(range) => [&range.from, &range.next, &range.to]
.iter()
.filter_map(|e| e.as_ref())
.flat_map(recur)
.collect(),
Expr::Call(call) => call
.arguments
.iter()
.filter_map(|arg| arg.expr())
.flat_map(recur)
.collect(),
Expr::ExternalCall(head, args) => recur(head.as_ref())
.into_iter()
.chain(args.iter().flat_map(|arg| recur(arg.expr())))
.collect(),
Expr::UnaryNot(expr) | Expr::Collect(_, expr) => recur(expr.as_ref()), Expr::UnaryNot(expr) | Expr::Collect(_, expr) => recur(expr.as_ref()),
Expr::BinaryOp(lhs, op, rhs) => recur(lhs) Expr::BinaryOp(lhs, op, rhs) => {
.into_iter() recur(lhs);
.chain(recur(op)) recur(op);
.chain(recur(rhs)) recur(rhs);
.collect(), }
Expr::MatchBlock(matches) => matches Expr::MatchBlock(matches) => {
.iter() for (pattern, expr) in matches {
.flat_map(|(pattern, expr)| { pattern.flat_map(working_set, f, results);
pattern expr.flat_map(working_set, f, results);
.flat_map(working_set, f) }
.into_iter() }
.chain(recur(expr)) Expr::List(items) => {
}) for item in items {
.collect(), match item {
Expr::List(items) => items ListItem::Item(expr) | ListItem::Spread(_, expr) => recur(expr),
.iter() }
.flat_map(|item| match item { }
ListItem::Item(expr) | ListItem::Spread(_, expr) => recur(expr), }
}) Expr::Record(items) => {
.collect(), for item in items {
Expr::Record(items) => items match item {
.iter() RecordItem::Spread(_, expr) => recur(expr),
.flat_map(|item| match item { RecordItem::Pair(key, val) => {
RecordItem::Spread(_, expr) => recur(expr), recur(key);
RecordItem::Pair(key, val) => [key, val].into_iter().flat_map(recur).collect(), recur(val);
}) }
.collect(), }
Expr::Table(table) => table }
.columns }
.iter() Expr::Table(table) => {
.flat_map(recur) for column in &table.columns {
.chain(table.rows.iter().flat_map(|row| row.iter().flat_map(recur))) recur(column);
.collect(), }
for row in &table.rows {
for item in row {
recur(item);
}
}
}
Expr::ValueWithUnit(vu) => recur(&vu.expr), Expr::ValueWithUnit(vu) => recur(&vu.expr),
Expr::FullCellPath(fcp) => recur(&fcp.head), Expr::FullCellPath(fcp) => recur(&fcp.head),
Expr::Keyword(kw) => recur(&kw.expr), Expr::Keyword(kw) => recur(&kw.expr),
Expr::StringInterpolation(vec) | Expr::GlobInterpolation(vec, _) => { Expr::StringInterpolation(vec) | Expr::GlobInterpolation(vec, _) => {
vec.iter().flat_map(recur).collect() for item in vec {
recur(item);
}
}
Expr::AttributeBlock(ab) => {
for attr in &ab.attributes {
recur(&attr.expr);
}
recur(&ab.item);
} }
Expr::AttributeBlock(ab) => ab
.attributes
.iter()
.flat_map(|attr| recur(&attr.expr))
.chain(recur(&ab.item))
.collect(),
_ => Vec::new(), _ => (),
} };
} }
fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T> fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
@ -203,7 +205,9 @@ impl Traverse for Expression {
| Expr::Subexpression(block_id) | Expr::Subexpression(block_id)
| Expr::Block(block_id) | Expr::Block(block_id)
| Expr::Closure(block_id) => { | Expr::Closure(block_id) => {
let block = working_set.get_block(block_id.to_owned()); // Clone the block_id to create an owned value
let block_id = block_id.to_owned();
let block = working_set.get_block(block_id);
block.find_map(working_set, f) block.find_map(working_set, f)
} }
Expr::Range(range) => [&range.from, &range.next, &range.to] Expr::Range(range) => [&range.from, &range.next, &range.to]
@ -253,25 +257,31 @@ impl Traverse for Expression {
} }
impl Traverse for MatchPattern { impl Traverse for MatchPattern {
fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T> fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F, results: &mut Vec<T>)
where where
F: Fn(&'a Expression) -> Option<Vec<T>>, F: Fn(&'a Expression) -> Vec<T>,
{ {
let recur = |expr: &'a Expression| expr.flat_map(working_set, f); let mut recur_pattern =
let recur_pattern = |pattern: &'a MatchPattern| pattern.flat_map(working_set, f); |pattern: &'a MatchPattern| pattern.flat_map(working_set, f, results);
match &self.pattern { match &self.pattern {
Pattern::Expression(expr) => recur(expr), Pattern::Expression(expr) => expr.flat_map(working_set, f, results),
Pattern::List(patterns) | Pattern::Or(patterns) => { Pattern::List(patterns) | Pattern::Or(patterns) => {
patterns.iter().flat_map(recur_pattern).collect() for pattern in patterns {
recur_pattern(pattern);
}
} }
Pattern::Record(entries) => { Pattern::Record(entries) => {
entries.iter().flat_map(|(_, p)| recur_pattern(p)).collect() for (_, p) in entries {
recur_pattern(p);
}
} }
_ => Vec::new(), _ => (),
};
if let Some(g) = self.guard.as_ref() {
g.flat_map(working_set, f, results);
} }
.into_iter()
.chain(self.guard.as_ref().map(|g| recur(g)).unwrap_or_default())
.collect()
} }
fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T> fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>