mirror of
				https://github.com/neovim/neovim.git
				synced 2025-11-04 01:34:25 +00:00 
			
		
		
		
	fix(treesitter): correctly calculate bytes for text sources (#23655)
Fixes #20419
This commit is contained in:
		@@ -143,6 +143,29 @@ function M.contains(r1, r2)
 | 
				
			|||||||
  return true
 | 
					  return true
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					--- @param source integer|string
 | 
				
			||||||
 | 
					--- @param index integer
 | 
				
			||||||
 | 
					--- @return integer
 | 
				
			||||||
 | 
					local function get_offset(source, index)
 | 
				
			||||||
 | 
					  if index == 0 then
 | 
				
			||||||
 | 
					    return 0
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if type(source) == 'number' then
 | 
				
			||||||
 | 
					    return api.nvim_buf_get_offset(source, index)
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  local byte = 0
 | 
				
			||||||
 | 
					  local next_offset = source:gmatch('()\n')
 | 
				
			||||||
 | 
					  local line = 1
 | 
				
			||||||
 | 
					  while line <= index do
 | 
				
			||||||
 | 
					    byte = next_offset() --[[@as integer]]
 | 
				
			||||||
 | 
					    line = line + 1
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return byte
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@private
 | 
					---@private
 | 
				
			||||||
---@param source integer|string
 | 
					---@param source integer|string
 | 
				
			||||||
---@param range Range
 | 
					---@param range Range
 | 
				
			||||||
@@ -152,19 +175,10 @@ function M.add_bytes(source, range)
 | 
				
			|||||||
    return range --[[@as Range6]]
 | 
					    return range --[[@as Range6]]
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
 | 
					  local start_row, start_col, end_row, end_col = M.unpack4(range)
 | 
				
			||||||
  local start_byte = 0
 | 
					 | 
				
			||||||
  local end_byte = 0
 | 
					 | 
				
			||||||
  -- TODO(vigoux): proper byte computation here, and account for EOL ?
 | 
					  -- TODO(vigoux): proper byte computation here, and account for EOL ?
 | 
				
			||||||
  if type(source) == 'number' then
 | 
					  local start_byte = get_offset(source, start_row) + start_col
 | 
				
			||||||
    -- Easy case, this is a buffer parser
 | 
					  local end_byte = get_offset(source, end_row) + end_col
 | 
				
			||||||
    start_byte = api.nvim_buf_get_offset(source, start_row) + start_col
 | 
					 | 
				
			||||||
    end_byte = api.nvim_buf_get_offset(source, end_row) + end_col
 | 
					 | 
				
			||||||
  elseif type(source) == 'string' then
 | 
					 | 
				
			||||||
    -- string parser, single `\n` delimited string
 | 
					 | 
				
			||||||
    start_byte = vim.fn.byteidx(source, start_col)
 | 
					 | 
				
			||||||
    end_byte = vim.fn.byteidx(source, end_col)
 | 
					 | 
				
			||||||
  end
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return { start_row, start_col, start_byte, end_row, end_col, end_byte }
 | 
					  return { start_row, start_col, start_byte, end_row, end_col, end_byte }
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -486,7 +486,6 @@ end]]
 | 
				
			|||||||
    eq({ 'any-of?', 'contains?', 'eq?', 'has-ancestor?', 'has-parent?', 'is-main?', 'lua-match?', 'match?', 'vim-match?' }, res_list)
 | 
					    eq({ 'any-of?', 'contains?', 'eq?', 'has-ancestor?', 'has-parent?', 'is-main?', 'lua-match?', 'match?', 'vim-match?' }, res_list)
 | 
				
			||||||
  end)
 | 
					  end)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
  it('allows to set simple ranges', function()
 | 
					  it('allows to set simple ranges', function()
 | 
				
			||||||
    insert(test_text)
 | 
					    insert(test_text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -528,6 +527,7 @@ end]]
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    eq(range_tbl, { { { 0, 0, 0, 17, 1, 508 } } })
 | 
					    eq(range_tbl, { { { 0, 0, 0, 17, 1, 508 } } })
 | 
				
			||||||
  end)
 | 
					  end)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it("allows to set complex ranges", function()
 | 
					  it("allows to set complex ranges", function()
 | 
				
			||||||
    insert(test_text)
 | 
					    insert(test_text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -992,4 +992,58 @@ int x = INT_MAX;
 | 
				
			|||||||
    }, run_query())
 | 
					    }, run_query())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  end)
 | 
					  end)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  it('handles ranges when source is a multiline string (#20419)', function()
 | 
				
			||||||
 | 
					    local source = [==[
 | 
				
			||||||
 | 
					      vim.cmd[[
 | 
				
			||||||
 | 
					        set number
 | 
				
			||||||
 | 
					        set cmdheight=2
 | 
				
			||||||
 | 
					        set lastsatus=2
 | 
				
			||||||
 | 
					      ]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      set query = [[;; query
 | 
				
			||||||
 | 
					        ((function_call
 | 
				
			||||||
 | 
					          name: [
 | 
				
			||||||
 | 
					            (identifier) @_cdef_identifier
 | 
				
			||||||
 | 
					            (_ _ (identifier) @_cdef_identifier)
 | 
				
			||||||
 | 
					          ]
 | 
				
			||||||
 | 
					          arguments: (arguments (string content: _ @injection.content)))
 | 
				
			||||||
 | 
					          (#set! injection.language "c")
 | 
				
			||||||
 | 
					          (#eq? @_cdef_identifier "cdef"))
 | 
				
			||||||
 | 
					      ]]
 | 
				
			||||||
 | 
					    ]==]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    local r = exec_lua([[
 | 
				
			||||||
 | 
					      local parser = vim.treesitter.get_string_parser(..., 'lua')
 | 
				
			||||||
 | 
					      parser:parse()
 | 
				
			||||||
 | 
					      local ranges = {}
 | 
				
			||||||
 | 
					      parser:for_each_tree(function(tstree, tree)
 | 
				
			||||||
 | 
					        ranges[tree:lang()] = { tstree:root():range(true) }
 | 
				
			||||||
 | 
					      end)
 | 
				
			||||||
 | 
					      return ranges
 | 
				
			||||||
 | 
					    ]], source)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    eq({
 | 
				
			||||||
 | 
					      lua = { 0, 6, 6, 16, 4, 438 },
 | 
				
			||||||
 | 
					      query = { 6, 20, 113, 15, 6, 431 },
 | 
				
			||||||
 | 
					      vim = { 1, 0, 16, 4, 6, 89 }
 | 
				
			||||||
 | 
					    }, r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    -- The above ranges are provided directly from treesitter, however query directives may mutate
 | 
				
			||||||
 | 
					    -- the ranges but only provide a Range4. Strip the byte entries from the ranges and make sure
 | 
				
			||||||
 | 
					    -- add_bytes() produces the same result.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    local rb = exec_lua([[
 | 
				
			||||||
 | 
					      local r, source = ...
 | 
				
			||||||
 | 
					      local add_bytes = require('vim.treesitter._range').add_bytes
 | 
				
			||||||
 | 
					      for lang, range in pairs(r) do
 | 
				
			||||||
 | 
					        r[lang] = {range[1], range[2], range[4], range[5]}
 | 
				
			||||||
 | 
					        r[lang] = add_bytes(source, r[lang])
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					      return r
 | 
				
			||||||
 | 
					    ]], r, source)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    eq(rb, r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  end)
 | 
				
			||||||
end)
 | 
					end)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user