-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathstring_utils.mojo
146 lines (124 loc) · 4.4 KB
/
string_utils.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from algorithm.functional import vectorize
from sys.info import simdwidthof
from sys.intrinsics import compressed_store
# from math import iota, reduce_bit_count, any_true
from math import iota
from memory import stack_allocation
from time import now
from collections.vector import InlinedFixedVector
# Number of int8 lanes in one native SIMD vector on the compilation target;
# used as the chunk width by the scanning helpers in this file.
alias simd_width_i8: Int = simdwidthof[DType.int8]()
fn vectorize_and_exit[
    simd_width: Int, workgroup_function: fn[i: Int] (Int) capturing -> Bool
](size: Int):
    """Applies `workgroup_function` over `[0, size)` in chunks, stopping early.

    Full `simd_width` chunks are processed first. The remainder is covered by
    successively halved power-of-two widths (64, 32, ..., 2), and whatever is
    still left is handled one element at a time. As soon as any call returns
    True, no further calls are made.

    Parameters:
        simd_width: Width of the main chunks.
        workgroup_function: Invoked as `workgroup_function[w](offset)` for the
            elements `[offset, offset + w)`; return True to stop iterating.

    Args:
        size: Number of elements to cover.
    """
    var loops = size // simd_width
    for i in range(loops):
        if workgroup_function[simd_width](i * simd_width):
            return
    # `%` instead of a `simd_width - 1` bit mask: same result for powers of
    # two, and still correct otherwise.
    var rest = size % simd_width

    @parameter
    if simd_width >= 128:
        if rest >= 64:
            if workgroup_function[64](size - rest):
                return
            rest -= 64

    @parameter
    if simd_width >= 64:
        if rest >= 32:
            if workgroup_function[32](size - rest):
                return
            rest -= 32

    @parameter
    if simd_width >= 32:
        if rest >= 16:
            if workgroup_function[16](size - rest):
                return
            rest -= 16

    @parameter
    if simd_width >= 16:
        if rest >= 8:
            if workgroup_function[8](size - rest):
                return
            rest -= 8

    @parameter
    if simd_width >= 8:
        if rest >= 4:
            if workgroup_function[4](size - rest):
                return
            rest -= 4

    @parameter
    if simd_width >= 4:
        if rest >= 2:
            if workgroup_function[2](size - rest):
                return
            rest -= 2

    # Scalar tail. For simd_width <= 128 at most one element remains here.
    # Wider (or non-power-of-two) widths can leave more than the ladder above
    # covers, and those elements must not be silently skipped.
    while rest > 0:
        if workgroup_function[1](size - rest):
            return
        rest -= 1
fn find_indices(s: String, c: String) -> List[UInt64]:
    """Returns the byte offsets in `s` of every byte equal to `ord(c)`.

    Args:
        s: The string to scan, compared byte by byte.
        c: A one-character string whose `ord` value is the byte to find.

    Returns:
        The matching offsets, in ascending order.
    """
    var size = len(s)
    var result = List[UInt64]()
    var char = UInt8(ord(c))
    var p = DTypePointer(s.unsafe_ptr())

    @parameter
    fn find[simd_width: Int](offset: Int):
        @parameter
        if simd_width == 1:
            if p.offset(offset).load() == char:
                result.append(offset)
        else:
            var chunk = p.load[width=simd_width](offset)
            var occurrence = chunk == char
            var occurrence_count = occurrence.reduce_bit_count()
            if occurrence_count == 0:
                return
            var current_len = len(result)
            var new_len = current_len + occurrence_count
            # Grow capacity geometrically. Reserving exactly `new_len` on every
            # matching chunk reallocates and copies the whole list each time,
            # which is quadratic for inputs with many matches.
            if new_len > result.capacity:
                result.reserve(max(new_len, 2 * result.capacity))
            result.resize(new_len, 0)
            var offsets = iota[DType.uint64, simd_width]() + offset
            # Pack only the matching lanes' offsets into the slots just added.
            compressed_store(
                offsets,
                DTypePointer[DType.uint64](result.data).offset(current_len),
                occurrence,
            )

    vectorize[find, simd_width_i8](size)
    return result
fn occurrence_count(s: String, *c: String) -> Int:
    """Counts the byte positions in `s` equal to any of the given characters.

    A position that matches several of the `c` values is counted only once.

    Args:
        s: The string to scan, byte by byte.
        c: One-character strings; each `ord` value is a byte to count.

    Returns:
        The number of matching byte positions.
    """
    var needle_count = len(c)
    var needles = InlinedFixedVector[UInt8](needle_count)
    for k in range(needle_count):
        needles.append(UInt8(ord(c[k])))

    var total = 0
    var base = DTypePointer(s.unsafe_ptr())

    @parameter
    fn count_chunk[width: Int](start: Int):
        @parameter
        if width == 1:
            var byte = base.offset(start).load()
            for k in range(len(needles)):
                if byte == needles[k]:
                    total += 1
                    return
        else:
            var lanes = base.load[width=width](start)
            # Merge per-needle matches so each lane is counted at most once.
            var hits = SIMD[DType.bool, width](False)
            for k in range(len(needles)):
                hits |= lanes == needles[k]
            total += hits.reduce_bit_count()

    vectorize[count_chunk, simd_width_i8](len(s))
    return total
fn contains_any_of(s: String, *c: String) -> Bool:
    """Reports whether `s` contains a byte equal to any of the given characters.

    Args:
        s: The string to scan, byte by byte.
        c: One-character strings; each `ord` value is a byte to look for.

    Returns:
        True if any of the bytes occurs in `s`, otherwise False. Scanning
        stops at the first chunk that contains a match.
    """
    var chars = InlinedFixedVector[UInt8](len(c))
    for i in range(len(c)):
        chars.append(UInt8(ord(c[i])))
    var p = DTypePointer(s.unsafe_ptr())
    var found = False

    @parameter
    fn find[simd_width: Int](offset: Int) -> Bool:
        # Load at the offset handed in by vectorize_and_exit rather than
        # advancing a shared pointer, so correctness does not hinge on the
        # exact order and widths of the calls.
        var chunk = p.load[width=simd_width](offset)
        for j in range(len(chars)):
            if (chunk == chars[j]).reduce_or():
                found = True
                return True
        return False

    vectorize_and_exit[simd_width_i8, find](len(s))
    return found
@always_inline
fn string_from_pointer(p: DTypePointer[DType.uint8], length: Int) -> String:
    """Builds a String from a byte buffer whose last byte becomes the terminator.

    Since Mojo 0.5.0 the pointer needs to provide a 0-terminated byte string,
    so this overwrites `p[length - 1]` with 0 before constructing the String.
    `length` therefore counts the terminator slot, and the original last byte
    is lost.

    NOTE(review): `String(p, length)` presumably takes ownership of the buffer;
    confirm before freeing `p` at the call site.

    Args:
        p: Start of a writable buffer of at least `length` bytes.
        length: Buffer length including the terminator slot; must be >= 1.

    Returns:
        The String built from the buffer.
    """
    # With length <= 0, the store below would write before the buffer start.
    debug_assert(
        length > 0, "string_from_pointer: length must include the terminator"
    )
    p.store(length - 1, 0)
    return String(p, length)
fn print_v(v: List[UInt64]):
    """Prints `v` on one line as `(len)[e0, e1, ...]` followed by a newline.

    Args:
        v: The values to print.
    """
    var count = len(v)
    # Keep the header on the same line as the elements.
    print("(" + str(count) + ")[", end="")
    if count == 0:
        # The element loop never runs for an empty list, so close it here.
        print("]")
        return
    for i in range(count):
        var end = ", " if i < count - 1 else "]\n"
        print(v[i], end=end)