fix #19500; remove find optimization [backport: 1.6] (#19714)

* remove find optimization

close #19500

* save find to std

* add simple tests

* Apply suggestions from code review

Co-authored-by: konsumlamm <44230978+konsumlamm@users.noreply.github.com>

Co-authored-by: sandytypical <43030857+xflywind@users.noreply.github.com>
Co-authored-by: konsumlamm <44230978+konsumlamm@users.noreply.github.com>
(cherry picked from commit 65c2518d5c)
This commit is contained in:
ringabout
2022-09-29 04:05:41 +08:00
committed by narimiran
parent 5abf259908
commit ce63020110
3 changed files with 53 additions and 25 deletions

View File

@@ -1810,6 +1810,15 @@ func initSkipTable*(a: var SkipTable, sub: string) {.rtl,
for i in 0 ..< m - 1:
a[sub[i]] = m - 1 - i
func initSkipTable*(sub: string): SkipTable {.noinit, rtl,
extern: "nsuInitNewSkipTable".} =
## Returns a new table initialized for `sub`.
##
## See also:
## * `initSkipTable func<#initSkipTable,SkipTable,string>`_
## * `find func<#find,SkipTable,string,string,Natural,int>`_
initSkipTable(result, sub)
func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = 0): int {.
rtl, extern: "nsuFindStrA".} =
## Searches for `sub` in `s` inside range `start..last` using preprocessed
@@ -1842,9 +1851,6 @@ func find*(a: SkipTable, s, sub: string, start: Natural = 0, last = 0): int {.
when not (defined(js) or defined(nimdoc) or defined(nimscript)):
func c_memchr(cstr: pointer, c: char, n: csize_t): pointer {.
importc: "memchr", header: "<string.h>".}
func c_strstr(haystack, needle: cstring): cstring {.
importc: "strstr", header: "<string.h>".}
const hasCStringBuiltin = true
else:
const hasCStringBuiltin = false
@@ -1909,28 +1915,7 @@ func find*(s, sub: string, start: Natural = 0, last = 0): int {.rtl,
if sub.len > s.len - start: return -1
if sub.len == 1: return find(s, sub[0], start, last)
template useSkipTable {.dirty.} =
var a {.noinit.}: SkipTable
initSkipTable(a, sub)
result = find(a, s, sub, start, last)
when not hasCStringBuiltin:
useSkipTable()
else:
when nimvm:
useSkipTable()
else:
when hasCStringBuiltin:
if last == 0 and s.len > start:
let found = c_strstr(s[start].unsafeAddr, sub)
if not found.isNil:
result = cast[ByteAddress](found) -% cast[ByteAddress](s.cstring)
else:
result = -1
else:
useSkipTable()
else:
useSkipTable()
result = find(initSkipTable(sub), s, sub, start, last)
func rfind*(s: string, sub: char, start: Natural = 0, last = -1): int {.rtl,
extern: "nsuRFindChar".} =

View File

@@ -74,3 +74,40 @@ template endsWithImpl*[T: string | cstring](s, suffix: T) =
func cmpNimIdentifier*[T: string | cstring](a, b: T): int =
cmpIgnoreStyleImpl(a, b, true)
func c_memchr(cstr: pointer, c: char, n: csize_t): pointer {.
importc: "memchr", header: "<string.h>".}
func c_strstr(haystack, needle: cstring): cstring {.
importc: "strstr", header: "<string.h>".}
func find*(s: cstring, sub: char, start: Natural = 0, last = 0): int =
## Searches for `sub` in `s` inside the range `start..last` (both ends included).
## If `last` is unspecified, it defaults to `s.high` (the last element).
##
## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
## Otherwise the index returned is relative to `s[0]`, not `start`.
## Use `s[start..last].rfind` for a `start`-origin index.
let last = if last == 0: s.high else: last
let L = last-start+1
if L > 0:
let found = c_memchr(s[start].unsafeAddr, sub, cast[csize_t](L))
if not found.isNil:
return cast[ByteAddress](found) -% cast[ByteAddress](s)
return -1
func find*(s, sub: cstring, start: Natural = 0, last = 0): int =
## Searches for `sub` in `s` inside the range `start..last` (both ends included).
## If `last` is unspecified, it defaults to `s.high` (the last element).
##
## Searching is case-sensitive. If `sub` is not in `s`, -1 is returned.
## Otherwise the index returned is relative to `s[0]`, not `start`.
## Use `s[start..last].find` for a `start`-origin index.
if sub.len > s.len - start: return -1
if sub.len == 1: return find(s, sub[0], start, last)
if last == 0 and s.len > start:
let found = c_strstr(s[start].unsafeAddr, sub)
if not found.isNil:
result = cast[ByteAddress](found) -% cast[ByteAddress](s)
else:
result = -1

View File

@@ -0,0 +1,6 @@
import std/private/strimpl
doAssert find(cstring"Hello Nim", cstring"Nim") == 6
doAssert find(cstring"Hello Nim", cstring"N") == 6
doAssert find(cstring"Hello Nim", cstring"I") == -1
doAssert find(cstring"Hello Nim", cstring"O") == -1