mirror of
https://github.com/neovim/neovim.git
synced 2026-03-28 11:22:03 +00:00
Problem: treesitter select over-complicates visual selection. Solution: make it use the same visual selection logic as LSP.
562 lines
16 KiB
Lua
562 lines
16 KiB
Lua
local Range = require('vim.treesitter._range')
|
|
|
|
--- Remembers which child the user came from when doing `select_parent`, so
--- that a later `select_child` on that parent returns to the same child
--- instead of the parent's first child. This is currently its only use.
---
--- The array part is a stack of previously selected children. The hash part
--- records which buffer and `changedtick` the stack belongs to, plus the id
--- (see `node_id`) of the node that is currently selected.
---
--- @class vim.treesitter.select.history
--- @field [integer] vim.treesitter.select.node
--- @field bufnr integer?
--- @field changedtick integer?
--- @field current_node_id string?

--- @type vim.treesitter.select.history
local history = {}
|
|
|
|
--- Wrapper around a `TSNode` that also records which region of its tree the
--- node belongs to. The wrapper is needed because a `TSTree` may be disjoint.
--- `region` is one range from `TSTree:included_ranges(false)`. For injected
--- trees, ranges that touch each other are merged first
--- (e.g. {{0,0,1,1},{1,1,2,2}} -> {{0,0,2,2}}).
---
--- @class vim.treesitter.select.node
--- @field node TSNode
--- @field top vim.treesitter.select.node.top Top (root) wrapper of the region this node lies in.

--- Root wrapper of one region of a (possibly injected) tree. `node` is the
--- tree's root node and `top` points back at the wrapper itself.
---
--- @class vim.treesitter.select.node.top: vim.treesitter.select.node
--- @field ltree vim.treesitter.LanguageTree Language tree that owns the tree.
--- @field region Range4 Region of the tree this wrapper stands for.

local M = {}
|
|
|
|
--- Build a string key that identifies a wrapped node. The key is the
--- coordinates of the node's region joined by `:`, followed by the
--- `TSNode` id. The same `TSNode` seen through two different regions
--- therefore gets two different keys.
---
--- @param node vim.treesitter.select.node
--- @return string
local function node_id(node)
  local region_key = table.concat(node.top.region, ':')
  return ('%s:%s'):format(region_key, node.node:id())
end
|
|
|
|
--- Range of a wrapped node, clipped to the region of the tree it belongs to.
--- Falls back to `{ 0, 0, 0, 0 }` when `Range.intersection` yields nothing
--- (presumably when node and region do not overlap).
---
--- @param node vim.treesitter.select.node
--- @return Range4
local function node_range(node)
  local clipped = Range.intersection(node.top.region, { node.node:range() })
  if clipped then
    return clipped
  end
  return { 0, 0, 0, 0 }
end
|
|
|
|
--- Whether two wrapped nodes cover exactly the same (region-clipped) range.
---
--- @param node1 vim.treesitter.select.node
--- @param node2 vim.treesitter.select.node
--- @return boolean
local function node_is_same_range(node1, node2)
  local range1 = node_range(node1)
  local range2 = node_range(node2)
  return Range.equal(range1, range2)
end
|
|
|
|
--- Whether a wrapped node's (region-clipped) range is empty, i.e. it starts
--- and ends at the same position.
---
--- @param node vim.treesitter.select.node
--- @return boolean
local function node_is_size_0(node)
  local start_row, start_col, end_row, end_col = Range.unpack4(node_range(node))
  return end_row == start_row and end_col == start_col
end
|
|
|
|
--- Wrap `tsnode` so that it shares the region (top node) of `relative`.
--- Asserts that `tsnode` belongs to the same tree as `relative`.
---
--- @param tsnode TSNode
--- @param relative vim.treesitter.select.node
--- @return vim.treesitter.select.node
local function create_node(tsnode, relative)
  local top = relative.top
  -- A node from another tree would be paired with the wrong region.
  assert(tsnode:tree():root():equal(top.node))

  --- @type vim.treesitter.select.node
  local wrapped = { top = top, node = tsnode }
  return wrapped
end
|
|
|
|
--- Included ranges of `tree`, with consecutive ranges merged when one ends
--- exactly where the next begins.
--- E.g. {{0,0,1,1},{1,1,2,2}} -> {{0,0,2,2}}.
---
--- @param tree TSTree
--- @return Range4[]
local function tree_get_ranges(tree)
  --- @type Range4[]
  local merged = {}
  for _, r in ipairs(tree:included_ranges(false)) do
    local last = merged[#merged]
    local touches = last ~= nil and last[3] == r[1] and last[4] == r[2]
    if touches then
      merged[#merged] = { last[1], last[2], r[3], r[4] }
    else
      merged[#merged + 1] = r
    end
  end
  return merged
end
|
|
|
|
--- Create the root wrapper for one region of `tree`. The wrapper's `top`
--- field refers to the wrapper itself.
---
--- @param tree TSTree
--- @param region Range4
--- @param ltree vim.treesitter.LanguageTree
--- @return vim.treesitter.select.node.top
local function create_top_node(tree, region, ltree)
  --- @type vim.treesitter.select.node.top
  local top = {} --[[@as any]]
  top.node = tree:root()
  top.ltree = ltree
  top.region = region
  top.top = top
  return top
end
|
|
|
|
--- Ordering between two top nodes. Tie-breakers are applied in order:
--- start position, end position, language name, node id. By default the
--- result is true when `node1` starts later, then ends later, then has the
--- smaller language name, then the smaller id.
--- `M.TEST_SWITCH_PRIORITY` inverts every tie-breaker (test hook).
---
--- @param node1 vim.treesitter.select.node.top
--- @param node2 vim.treesitter.select.node.top
--- @return boolean
local function top_node_is_higher_priority(node1, node2)
  local a_srow, a_scol, a_erow, a_ecol = Range.unpack4(node_range(node1))
  local b_srow, b_scol, b_erow, b_ecol = Range.unpack4(node_range(node2))

  local inverted = M.TEST_SWITCH_PRIORITY
  local pos_wins = inverted and Range.cmp_pos.lt or Range.cmp_pos.gt

  if Range.cmp_pos.ne(a_srow, a_scol, b_srow, b_scol) then
    return pos_wins(a_srow, a_scol, b_srow, b_scol)
  end
  if Range.cmp_pos.ne(a_erow, a_ecol, b_erow, b_ecol) then
    return pos_wins(a_erow, a_ecol, b_erow, b_ecol)
  end

  local lang1, lang2 = node1.ltree:lang(), node2.ltree:lang()
  if lang1 ~= lang2 then
    if inverted then
      return lang1 > lang2
    end
    return lang1 < lang2
  end

  local id1, id2 = node1.node:id(), node2.node:id()
  if inverted then
    return id1 > id2
  end
  return id1 < id2
end
|
|
|
|
--- Find the smallest named node covering `range`, descending into injected
--- language trees when one of them covers `range`.
---
--- Without `top_node`, the current buffer's parser is used. The tree is
--- parsed over `range`, and the first tree's first included range becomes the
--- root region.
---
--- Each time the search descends into an injected tree, the smallest named
--- host node covering that injection is appended to `parent_chain`, so
--- callers can walk back up out of the injection.
---
--- @param range Range4
--- @param top_node vim.treesitter.select.node.top? Tree/region to search in; defaults to the buffer's root tree.
--- @param parent_chain vim.treesitter.select.node[]? Accumulator for injection parents; created when nil.
--- @return vim.treesitter.select.node|false|nil nil: no parser, false: outside of root-node
--- @return vim.treesitter.select.node[] either `parent_chain` or `alternative_nodes`
local function get_node(range, top_node, parent_chain)
  parent_chain = parent_chain or {}

  if not top_node then
    local parser = vim.treesitter.get_parser(nil, nil, { error = false })
    if not parser then
      return nil, {}
    end

    local tree = assert(parser:parse(range))[1]
    top_node = create_top_node(tree, assert(tree:included_ranges(false)[1]), parser)

    if not Range.contains(node_range(top_node), range) then
      -- `range` is not inside the root: hand the root itself back as the alternative.
      return false, { top_node } --[[alternative_nodes]]
    end
  end

  assert(Range.contains(node_range(top_node), range))

  --- Resolve conflicts between `node` and the injected trees directly below
  --- `top_node.ltree` (one level of `children()` only; deeper injections are
  --- handled by the recursive `get_node` call further down).
  ---
  --- - If an injection region contains `range` and either `node` is a plain
  ---   node or `top_node_is_higher_priority(node, injection)` holds, continue
  ---   with that injection's top node.
  --- - Else, if an injection region intersects `node`'s range, take the
  ---   smallest named host node covering the injection. Return it when it is
  ---   an ancestor of a plain `node`, or when `node` is a top node and the
  ---   priority check holds.
  --- - Otherwise return `node` unchanged.
  ---
  --- `node.ltree` is only set on top nodes (see `create_top_node`), which is
  --- how plain and top nodes are told apart here.
  ---
  --- @param node vim.treesitter.select.node|vim.treesitter.select.node.top
  --- @return vim.treesitter.select.node|vim.treesitter.select.node.top
  local function node_ignore_overlapped_handle_injection(node)
    for _, child in pairs(top_node.ltree:children()) do
      for _, child_tree in ipairs(child:trees()) do
        for _, child_region in ipairs(tree_get_ranges(child_tree)) do
          local child_root_node_range = { child_tree:root():range() }
          local child_range = Range.intersection(child_region, child_root_node_range)

          local child_top_node = create_top_node(child_tree, child_region, child)
          if
            child_range
            and Range.contains(child_range, range)
            and (
              not node.ltree
              or top_node_is_higher_priority(
                node --[[@as vim.treesitter.select.node.top]],
                child_top_node
              )
            )
          then
            return node_ignore_overlapped_handle_injection(child_top_node)
          elseif child_range and Range.intercepts(node_range(node), child_range) then
            -- Smallest named node of the host tree that covers the injection.
            local child_parent_tsnode =
              assert(top_node.node:named_descendant_for_range(unpack(child_range)))

            if
              (not node.ltree and vim.treesitter.is_ancestor(child_parent_tsnode, node.node))
              or (
                node.ltree
                and top_node_is_higher_priority(
                  node --[[@as vim.treesitter.select.node.top]],
                  child_top_node
                )
              )
            then
              return create_node(child_parent_tsnode, top_node)
            end
          end
        end
      end
    end

    return node
  end

  local tsnode = assert(top_node.node:named_descendant_for_range(unpack(range)))
  local node = create_node(tsnode, top_node)

  node = node_ignore_overlapped_handle_injection(node)
  if node.ltree then
    -- `node` is the top node of an injected tree: remember the host node that
    -- covers its (clipped) range, then keep searching inside the injection.
    local root_node_range = { node.node:range() }
    local tree_range = node.top.region
    local actual_range = assert(Range.intersection(tree_range, root_node_range))
    local parent_tsnode = assert(top_node.node:named_descendant_for_range(unpack(actual_range)))
    table.insert(parent_chain, create_node(parent_tsnode, top_node))

    -- Only the first result of the recursive call is kept (it is not the last
    -- expression); the second one would be this same `parent_chain` anyway.
    --- @cast node vim.treesitter.select.node.top
    return get_node(range, node, parent_chain), parent_chain
  end
  --- @cast node vim.treesitter.select.node

  return node, parent_chain
end
|
|
|
|
--- Parent of `node` without skipping ancestors that have the same range.
--- Inside a tree this is the `TSNode` parent. At the root of a tree, the last
--- entry of `parent_chain` (the host node of an injection) is popped and
--- returned instead. The result is nil when that chain is empty.
---
--- @param node vim.treesitter.select.node
--- @param parent_chain vim.treesitter.select.node[]
--- @nodiscard
--- @return vim.treesitter.select.node?
local function node_get_parent_no_normalize(node, parent_chain)
  local ts_parent = node.node:parent()
  if not ts_parent then
    -- Root of this tree reached: continue in the enclosing (host) tree, if any.
    return table.remove(parent_chain)
  end
  return create_node(ts_parent, node)
end
|
|
|
|
--- Walk up from `node` while the parent covers exactly the same range, and
--- return the highest node reached. Crossing out of an injected tree
--- consumes the corresponding entry of `parent_chain`.
---
--- When the walk stops, a parent that was popped from `parent_chain` is
--- pushed back, so the chain still describes the returned node's injection
--- ancestry. A `TSNode` parent from inside the tree never belonged to the
--- chain and is therefore not pushed.
---
--- @param node vim.treesitter.select.node
--- @param parent_chain vim.treesitter.select.node[]
--- @return vim.treesitter.select.node
local function node_normalize_up(node, parent_chain)
  while true do
    -- `node_get_parent_no_normalize` pops `parent_chain` only at a tree root.
    local popped_from_chain = node.node:parent() == nil
    local parent = node_get_parent_no_normalize(node, parent_chain)
    if parent and node_is_same_range(parent, node) then
      node = parent
    else
      if parent and popped_from_chain then
        table.insert(parent_chain, parent)
      end

      return node
    end
  end
  --- @diagnostic disable-next-line: missing-return
end
|
|
|
|
--- Insert the injected top node `node` into the list `nodes`, before the
--- first entry it ends before (presumably `nodes` is ordered by position,
--- since it starts out as a node's named children).
---
--- Entries whose range intersects `node` are resolved one at a time:
--- - A plain node, or a top node for which
---   `top_node_is_higher_priority(entry, node)` holds, is removed.
--- - Any other overlapping top node wins, and `node` is not inserted at all.
---
--- @param nodes vim.treesitter.select.node[]
--- @param node vim.treesitter.select.node.top
local function insert_remove_overlapped(nodes, node)
  local i = 1
  while nodes[i] ~= nil do
    local existing = nodes[i]
    if not Range.intercepts(node_range(existing), node_range(node)) then
      local existing_srow, existing_scol = Range.unpack4(node_range(existing))
      local _, _, new_erow, new_ecol = Range.unpack4(node_range(node))
      if Range.cmp_pos.le(new_erow, new_ecol, existing_srow, existing_scol) then
        table.insert(nodes, i, node)
        return
      end
      i = i + 1
    elseif
      (existing --[[@as any]]).ltree
      and not top_node_is_higher_priority(existing --[[@as vim.treesitter.select.node.top]], node)
    then
      -- An overlapping injection with precedence already occupies this spot.
      return
    else
      -- Do not advance `i`: the next entry has shifted into this slot.
      table.remove(nodes, i)
    end
  end

  nodes[#nodes + 1] = node
end
|
|
|
|
--- Children of `node`, without skipping same-range descendants. These are its
--- named `TSNode` children plus the top nodes of injected trees whose smallest
--- enclosing named host node is `node` itself. Overlaps are resolved by
--- `insert_remove_overlapped`.
---
--- Side effect: parses `node.top.ltree` over `node`'s range before looking at
--- the injected child trees.
---
--- @param node vim.treesitter.select.node
--- @return vim.treesitter.select.node[]
local function node_get_children_no_normalize(node)
  --- @type vim.treesitter.select.node[]
  local children = {}
  for i, ts_child in ipairs(node.node:named_children()) do
    children[i] = create_node(ts_child, node)
  end

  node.top.ltree:parse(node_range(node))

  for _, child_ltree in pairs(node.top.ltree:children()) do
    for _, child_tree in ipairs(child_ltree:trees()) do
      for _, region in ipairs(tree_get_ranges(child_tree)) do
        local injected_range = Range.intersection(region, { child_tree:root():range() })
        if injected_range and Range.contains(node_range(node), injected_range) then
          local host = assert(node.top.node:named_descendant_for_range(unpack(injected_range)))
          -- Only injections that sit directly under `node` count as its children.
          if node.node:equal(host) then
            insert_remove_overlapped(children, create_top_node(child_tree, region, child_ltree))
          end
        end
      end
    end
  end

  return children
end
|
|
|
|
--- Depth-first search through `node`'s children (in order) for the first
--- non-empty node lying entirely inside `range`. Children that only
--- partially overlap `range` are descended into.
---
--- @param range Range4
--- @param node vim.treesitter.select.node
--- @return vim.treesitter.select.node?
local function get_node_contained_in_range(range, node)
  for _, candidate in ipairs(node_get_children_no_normalize(node)) do
    if not node_is_size_0(candidate) then
      local candidate_range = node_range(candidate)
      if Range.contains(range, candidate_range) then
        return candidate
      elseif Range.intercepts(range, candidate_range) then
        local nested = get_node_contained_in_range(range, candidate)
        if nested then
          return nested
        end
      end
    end
  end
end
|
|
|
|
--- Walk down from `node` while some child covers exactly the same range,
--- always taking the first such child, and return the deepest node reached.
---
--- @param node vim.treesitter.select.node
--- @return vim.treesitter.select.node
local function node_normalize_down(node)
  local current = node
  repeat
    local same_range_child = nil
    for _, child in ipairs(node_get_children_no_normalize(current)) do
      if node_is_same_range(current, child) then
        same_range_child = child
        break
      end
    end
    if same_range_child then
      current = same_range_child
    end
  until not same_range_child
  return current
end
|
|
|
|
--- Visually select `range` (0-based, end-exclusive) in the current window by
--- setting the `'<`/`'>` marks and running `gv`. If the cursor was on the
--- start side of the previous selection, `o` keeps it on the start side.
---
--- @param range Range4
local function visual_select(range)
  assert(type(range) == 'table')
  local start_row, start_col, end_row, end_col = Range.unpack4(range)

  local anchor_line, anchor_col = vim.fn.line('v'), vim.fn.col('v')
  local cursor_line, cursor_col = vim.fn.line('.'), vim.fn.col('.')
  -- The cursor is on the start side when the other end ('v') lies after it.
  local cursor_at_start = anchor_line > cursor_line
    or (anchor_line == cursor_line and anchor_col > cursor_col)

  if end_col == 0 then
    -- An end of (row, 0) means "through the newline of the previous line".
    end_row = end_row - 1
    end_col = #vim.fn.getline(end_row + 1) + 1
  end

  vim.fn.setpos("'<", { 0, start_row + 1, start_col + 1, 0 })
  vim.fn.setpos("'>", { 0, end_row + 1, end_col, 0 })
  vim.cmd.normal({ cursor_at_start and 'gvo' or 'gv', bang = true })
end
|
|
|
|
--- Current visual selection (or the cursor position outside visual mode) as
--- a 0-based, end-exclusive Range4 with byte columns.
---
--- The end column covers every byte of the last selected character, so a
--- selection ending in a multibyte character matches node ranges exactly.
--- A selection that extends past the end of a line (i.e. includes its
--- newline) is expressed as ending at column 0 of the next line, like
--- treesitter node ranges.
---
--- @return Range4
local function get_selection()
  local pos1 = vim.fn.getpos('v')
  local pos2 = vim.fn.getpos('.')
  if pos1[2] > pos2[2] or (pos1[2] == pos2[2] and pos1[3] > pos2[3]) then
    pos1, pos2 = pos2, pos1
  end

  local end_line = vim.fn.getline(pos2[2])
  -- `pos2[3]` is the 1-based byte index of the first byte of the last
  -- selected character. Used directly as an exclusive 0-based end, it would
  -- cut a multibyte character, so skip to its last byte.
  local end_col = pos2[3]
  if end_col >= 1 and end_col <= #end_line then
    end_col = end_col + vim.str_utf_end(end_line, end_col)
  end

  local range = { pos1[2] - 1, pos1[3] - 1, pos2[2] - 1, end_col }

  if end_col == #end_line + 1 then
    range[3] = range[3] + 1
    range[4] = 0
  end

  return range
end
|
|
|
|
--- Node to select when growing the selection `range` outwards.
---
--- If `range` is not exactly a node yet, the smallest node around it is
--- returned. Otherwise the parent is returned, skipping ancestors with the
--- same range. The node being left is pushed onto `history`, so that
--- `select_child` can return to it.
---
--- @param range Range4
--- @return vim.treesitter.select.node?
local function get_parent_from_range(range)
  local node, chain = get_node(range)
  if node == false then
    return (assert(chain[1]))
  elseif not node then
    return
  elseif not Range.equal(range, node_range(node)) then
    return node
  end

  node = node_normalize_up(node, chain)
  local parent = node_get_parent_no_normalize(node, chain)
  if not parent then
    return
  end

  local buf = vim.api.nvim_get_current_buf()
  local continues_history = history.bufnr == buf
    and history.changedtick == vim.b.changedtick
    and history.current_node_id == node_id(node)
  if not continues_history then
    -- Start a fresh stack tied to this buffer state.
    history = {
      bufnr = buf,
      changedtick = vim.b.changedtick,
    }
  end
  history[#history + 1] = node
  history.current_node_id = node_id(parent)

  return parent
end
|
|
|
|
--- Node to select when shrinking the selection `range` inwards.
---
--- If `range` is not exactly a node, the first non-empty node inside it is
--- returned, falling back to the enclosing node. If the current node was
--- reached through `select_parent` (same buffer, same `changedtick`, same
--- node), the child recorded in `history` is returned. Otherwise the first
--- non-empty child is returned, or the node itself when it has none.
---
--- @param range Range4
--- @return vim.treesitter.select.node?
local function get_child_from_range(range)
  local node, alternatives = get_node(range)
  if node == false then
    return (assert(alternatives[1]))
  elseif not node then
    return
  end

  node = node_normalize_down(node)

  if not Range.equal(range, node_range(node)) then
    history = {}
    return get_node_contained_in_range(range, node) or node
  end

  local history_applies = history.bufnr == vim.api.nvim_get_current_buf()
    and history.changedtick == vim.b.changedtick
    and history.current_node_id == node_id(node)
  if history_applies then
    --- @type vim.treesitter.select.node
    local remembered = table.remove(history)
    if remembered then
      history.current_node_id = node_id(remembered)
      return remembered
    end
  end

  history = {}

  for _, child in ipairs(node_get_children_no_normalize(node)) do
    if not node_is_size_0(child) then
      return child
    end
  end

  return node
end
|
|
|
|
--- Neighbouring non-empty sibling of the node that `range` resolves to
--- (after skipping same-range ancestors). Returns nothing when there is no
--- parent or no such sibling.
---
--- @param range Range4
--- @param prev boolean true for the previous sibling, false for the next one
--- @return vim.treesitter.select.node?
local function get_sibling_from_range(range, prev)
  local node, chain = get_node(range)
  if not node then
    return
  end

  node = node_normalize_up(node, chain)
  local parent = node_get_parent_no_normalize(node, chain)
  if not parent then
    return
  end

  local step = prev and -1 or 1
  local siblings = node_get_children_no_normalize(parent)
  local own_id = node_id(node)

  --- @type integer?
  local idx
  for i = 1, #siblings do
    if node_id(siblings[i]) == own_id then
      idx = i + step
      break
    end
  end
  assert(idx)

  -- Empty siblings cannot be selected meaningfully; skip over them.
  while siblings[idx] and node_is_size_0(siblings[idx]) do
    idx = idx + step
  end

  local sibling = siblings[idx]
  if sibling then
    return sibling
  end
end
|
|
|
|
--- Next non-empty sibling of the node `range` resolves to.
--- @param range Range4
--- @return vim.treesitter.select.node?
local function get_next_from_range(range)
  return get_sibling_from_range(range, false --[[prev]])
end

--- Previous non-empty sibling of the node `range` resolves to.
--- @param range Range4
--- @return vim.treesitter.select.node?
local function get_prev_from_range(range)
  return get_sibling_from_range(range, true --[[prev]])
end
|
|
|
|
--- Apply `fn` to the current selection `count` times (once when `count` is
--- nil). Each call is fed the range of the previous result, and the loop
--- stops early when `fn` returns nothing. The final range is then visually
--- selected, except when `count` is 0.
---
--- @param count integer
--- @param fn fun(range: Range4): vim.treesitter.select.node?
local function repeate_apply_range(count, fn)
  local range = get_selection()

  for _ = 1, count or 1 do
    local found = fn(range)
    if not found then
      break
    end
    range = node_range(found)
  end

  if count == 0 or not range then
    return
  end
  visual_select(range)
end
|
|
|
|
--- Grow the visual selection to the enclosing node, `count` times.
--- @param count integer
M.select_parent = function(count)
  repeate_apply_range(count, get_parent_from_range)
end

--- Shrink the visual selection to a child node, `count` times. This prefers
--- the child previously left via `select_parent`.
--- @param count integer
M.select_child = function(count)
  repeate_apply_range(count, get_child_from_range)
end

--- Move the visual selection to the next sibling node, `count` times.
--- @param count integer
M.select_next = function(count)
  repeate_apply_range(count, get_next_from_range)
end

--- Move the visual selection to the previous sibling node, `count` times.
--- @param count integer
M.select_prev = function(count)
  repeate_apply_range(count, get_prev_from_range)
end

return M
|