Binary search improvements

Modified the algorithm so that the returned index is either the location of
the element, if it is found, or the index at which the element should be
inserted to maintain sorted order.

Also added some tests to verify the above claim.
This commit is contained in:
Hector
2023-11-25 13:42:28 +00:00
parent cabaac5a68
commit 1db5e1250f
4 changed files with 168 additions and 37 deletions

View File

@@ -117,46 +117,95 @@ linear_search_proc :: proc(array: $A/[]$T, f: proc(T) -> bool) -> (index: int, f
return -1, false
}
/*
Binary search searches the given slice for the given element.
If the slice is not sorted, the returned index is unspecified and meaningless.
If the value is found then the returned int is the index of the matching element.
If there are multiple matches, then any one of the matches could be returned.
If the value is not found then the returned int is the index where a matching
element could be inserted while maintaining sorted order.
# Examples
Looks up a series of four elements. The first is found, with a
uniquely determined position; the second and third are not
found; the fourth could match any position in `[1, 4]`.
```
index: int
found: bool
s := []i32{0, 1, 1, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55}
index, found = slice.binary_search(s, 13)
assert(index == 9 && found == true)
index, found = slice.binary_search(s, 4)
assert(index == 7 && found == false)
index, found = slice.binary_search(s, 100)
assert(index == 13 && found == false)
index, found = slice.binary_search(s, 1)
assert(index >= 1 && index <= 4 && found == true)
```
For slices of more complex types see: binary_search_by
*/
@(require_results)
binary_search :: proc(array: $A/[]$T, key: T) -> (index: int, found: bool)
where intrinsics.type_is_ordered(T) #no_bounds_check {
// NOTE(review): this listing appears to interleave two versions of the
// procedure body (the `where` clause repeats further down) - confirm
// against the actual file before relying on statement order here.
// Slices of length 0 or 1 are answered directly, without entering the loop.
n := len(array)
switch n {
case 0:
return -1, false
case 1:
if array[0] == key {
return 0, true
}
return -1, false
}
lo, hi := 0, n-1
// Iterate only while the two endpoint values differ and `key` lies within
// [array[lo], array[hi]]; the endpoint check also keeps the divisor in the
// interpolation step below from being zero.
for array[hi] != array[lo] && key >= array[lo] && key <= array[hi] {
when intrinsics.type_is_ordered_numeric(T) {
// NOTE(bill): This is technically interpolation search
m := lo + int((key - array[lo]) * T(hi - lo) / (array[hi] - array[lo]))
} else {
// Ordered but non-numeric element types: probe the plain midpoint.
m := lo + (hi - lo)/2
}
where intrinsics.type_is_ordered(T) #no_bounds_check
{
// I would like to use binary_search_by(array, key, cmp) here, but it doesn't like it:
// Cannot assign value 'cmp' of type 'proc($E, $E) -> Ordering' to 'proc(i32, i32) -> Ordering' in argument
// The comparator handed to binary_search_by classifies `element` relative
// to `key`: .Less when element < key, .Greater when element > key, and
// .Equal otherwise.
return binary_search_by(array, key, proc(key: T, element: T) -> Ordering {
switch {
case array[m] < key:
lo = m + 1
case key < array[m]:
hi = m - 1
case:
return m, true
case element < key: return .Less
case element > key: return .Greater
case: return .Equal
}
}
if key == array[lo] {
return lo, true
}
return -1, false
})
}
/*
	Binary search over a sorted slice using a caller-supplied comparator.

	`f` is invoked as `f(key, element)` and reports how `element` orders
	relative to `key`:
	- `.Less`:    `element` sorts before `key`, so the search moves right
	- `.Greater`: `element` sorts after `key`, so the search moves left
	- `.Equal`:   `element` matches `key`

	If the slice is not sorted consistently with `f`, the returned index is
	unspecified and meaningless.

	If a match is found, returns its index and `true`; with several matches,
	any one of them may be returned. Otherwise returns the index at which
	`key` could be inserted while maintaining sorted order, and `false`.

	The element type does not need to be ordered itself: every comparison
	goes through `f`.
*/
@(require_results)
binary_search_by :: proc(array: $A/[]$T, key: T, f: proc(T, T) -> Ordering) -> (index: int, found: bool) #no_bounds_check {
	// INVARIANTS:
	// - 0 <= left <= (left + size = right) <= len(array)
	// - f returns .Less for everything in array[:left]
	// - f returns .Greater for everything in array[right:]
	size  := len(array)
	left  := 0
	right := size

	for left < right {
		mid := left + size/2

		// Steps to verify this is in-bounds:
		// 1. We note that `size` is strictly positive due to the loop condition
		// 2. Therefore `size/2 < size`
		// 3. Adding `left` to both sides yields `(left + size/2) < (left + size)`
		// 4. We know from the invariant that `left + size <= len(array)`
		// 5. Therefore `left + size/2 < len(array)`
		switch f(key, array[mid]) {
		case .Equal:   return mid, true
		case .Less:    left  = mid + 1
		case .Greater: right = mid
		}

		size = right - left
	}

	// Not found: `left` is the insertion point that keeps the slice sorted.
	return left, false
}
@(require_results)
equal :: proc(a, b: $T/[]$E) -> bool where intrinsics.type_is_comparable(E) {

View File

@@ -1,9 +1,26 @@
ODIN=../../odin
PYTHON=$(shell which python3)
# Aggregate target and the first target in this file, so a bare `make`
# selects it. Its prerequisites are the individual test targets.
all: download_test_assets image_test compress_test strings_test hash_test crypto_test noise_test encoding_test \
math_test linalg_glsl_math_test filepath_test reflect_test os_exit_test i18n_test match_test c_libc_test net_test \
fmt_test thread_test
all: c_libc_test \
compress_test \
crypto_test \
download_test_assets \
encoding_test \
filepath_test \
fmt_test \
hash_test \
i18n_test \
image_test \
linalg_glsl_math_test \
match_test \
math_test \
net_test \
noise_test \
os_exit_test \
reflect_test \
slice_test \
strings_test \
thread_test
# Runs download_assets.py with the python3 interpreter located via `which`.
download_test_assets:
$(PYTHON) download_assets.py
@@ -44,6 +61,9 @@ filepath_test:
# Runs the core:reflect test file; `-collection:tests=..` maps the `tests`
# collection to the parent directory.
reflect_test:
$(ODIN) run reflect/test_core_reflect.odin -file -collection:tests=.. -out:test_core_reflect
# Runs the core:slice test file as a single-file program (`-file`).
slice_test:
$(ODIN) run slice/test_core_slice.odin -file -out:test_core_slice
# The exit status is deliberately inverted: the target succeeds only when the
# test program exits with a non-zero status.
os_exit_test:
$(ODIN) run os/test_core_os_exit.odin -file -out:test_core_os_exit && exit 1 || exit 0

View File

@@ -66,6 +66,11 @@ echo Running core:reflect tests
echo ---
%PATH_TO_ODIN% run reflect %COMMON% %COLLECTION% -out:test_core_reflect.exe || exit /b
echo ---
rem core:slice tests. `|| exit /b` stops the script as soon as the run fails.
echo Running core:slice tests
echo ---
%PATH_TO_ODIN% run slice %COMMON% -out:test_core_slice.exe || exit /b
echo ---
echo Running core:text/i18n tests
echo ---

View File

@@ -30,6 +30,7 @@ when ODIN_TEST {
main :: proc() {
t := testing.T{}
test_sort_with_indices(&t)
test_binary_search(&t)
fmt.printf("%v/%v tests successful.\n", TEST_count - TEST_fail, TEST_count)
if TEST_fail > 0 {
@@ -180,3 +181,59 @@ test_sort_by_indices :: proc(t: ^testing.T) {
}
}
}
// Exercises slice.binary_search on a sorted slice with duplicates, on an empty
// slice, and on a single-element slice. Each case checks both the returned
// index (match position or insertion point) and the `found` flag.
@test
test_binary_search :: proc(t: ^testing.T) {
	// Performs one lookup, echoing the query and its outcome to stdout.
	lookup :: proc(haystack: []i32, needle: i32) -> (int, bool) {
		fmt.printf("Searching for %v in %v\n", needle, haystack)
		pos, hit := slice.binary_search(haystack, needle)
		fmt.printf("index: %v\nfound: %v\n", pos, hit)
		return pos, hit
	}

	sorted := []i32{0, 1, 1, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55}

	{ // Present exactly once.
		pos, hit := lookup(sorted, 13)
		assert(pos == 9, "Expected index to be 9.")
		assert(hit, "Expected found to be true.")
	}
	{ // Absent, insertion point in the interior.
		pos, hit := lookup(sorted, 4)
		assert(pos == 7, "Expected index to be 7.")
		assert(!hit, "Expected found to be false.")
	}
	{ // Absent, larger than every element.
		pos, hit := lookup(sorted, 100)
		assert(pos == 13, "Expected index to be 13.")
		assert(!hit, "Expected found to be false.")
	}
	{ // Present several times: any matching position is acceptable.
		pos, hit := lookup(sorted, 1)
		assert(1 <= pos && pos <= 4, "Expected index to be 1, 2, 3, or 4.")
		assert(hit, "Expected found to be true.")
	}
	{ // Absent, smaller than every element.
		pos, hit := lookup(sorted, -1)
		assert(pos == 0, "Expected index to be 0.")
		assert(!hit, "Expected found to be false.")
	}

	empty := []i32{}
	{
		pos, hit := lookup(empty, 13)
		assert(pos == 0, "Expected index to be 0.")
		assert(!hit, "Expected found to be false.")
	}

	single := []i32{1}
	{ // Key above the only element.
		pos, hit := lookup(single, 13)
		assert(pos == 1, "Expected index to be 1.")
		assert(!hit, "Expected found to be false.")
	}
	{ // Key equal to the only element.
		pos, hit := lookup(single, 1)
		assert(pos == 0, "Expected index to be 0.")
		assert(hit, "Expected found to be true.")
	}
	{ // Key below the only element.
		pos, hit := lookup(single, 0)
		assert(pos == 0, "Expected index to be 0.")
		assert(!hit, "Expected found to be false.")
	}
}