@@ -7,19 +7,23 @@ use tokenizers::tokenizer::Tokenizer;
7
7
8
8
pub struct TokenizerWrapper {
9
9
tokenizer : Tokenizer ,
10
- encode_ids : Vec < u32 > ,
11
10
decode_str : String ,
12
11
id_to_token_result : String ,
13
12
}
14
13
15
14
/// Mapping from token string to its numeric id.
pub type Vocab = HashMap<String, u32>;

/// Merge rules as ordered (left, right) string pairs.
pub type Merges = Vec<(String, String)>;
17
16
17
/// C-compatible result of one encode call: a heap-allocated token-id
/// buffer plus its element count.
///
/// Ownership of `token_ids` is transferred to the C caller, which must
/// release it with `tokenizers_free_encode_results`.
#[repr(C)]
pub struct TokenizerEncodeResult {
    // Pointer to the first of `len` token ids (from Box::into_raw).
    token_ids: *mut u32,
    // Number of u32 elements behind `token_ids`.
    len: usize,
}
18
23
impl TokenizerWrapper {
19
24
pub fn from_str ( json : & str ) -> TokenizerWrapper {
20
25
TokenizerWrapper {
21
26
tokenizer : Tokenizer :: from_str ( json) . unwrap ( ) . into ( ) ,
22
- encode_ids : Vec :: new ( ) ,
23
27
decode_str : String :: new ( ) ,
24
28
id_to_token_result : String :: new ( ) ,
25
29
}
@@ -77,16 +81,22 @@ impl TokenizerWrapper {
77
81
. with_decoder ( byte_level) ;
78
82
TokenizerWrapper {
79
83
tokenizer : tokenizer,
80
- encode_ids : Vec :: new ( ) ,
81
84
decode_str : String :: new ( ) ,
82
85
id_to_token_result : String :: new ( ) ,
83
86
}
84
87
}
85
88
86
- pub fn encode ( & mut self , text : & str , add_special_tokens : bool ) {
89
+ pub fn encode ( & mut self , text : & str , add_special_tokens : bool ) -> Vec < u32 > {
87
90
let encoded = self . tokenizer . encode ( text, add_special_tokens) . unwrap ( ) ;
88
- self . encode_ids . resize ( encoded. len ( ) , 0 ) ;
89
- self . encode_ids . copy_from_slice ( encoded. get_ids ( ) ) ;
91
+ return encoded. get_ids ( ) . to_vec ( ) ;
92
+ }
93
+
94
+ pub fn encode_batch ( & mut self , texts : Vec < & str > , add_special_tokens : bool ) -> Vec < Vec < u32 > > {
95
+ let results = self . tokenizer . encode_batch ( texts, add_special_tokens) . unwrap ( )
96
+ . into_iter ( )
97
+ . map ( |encoded| encoded. get_ids ( ) . to_vec ( ) )
98
+ . collect :: < Vec < Vec < u32 > > > ( ) ;
99
+ return results;
90
100
}
91
101
92
102
pub fn decode ( & mut self , ids : & [ u32 ] , skip_special_tokens : bool ) {
@@ -135,22 +145,53 @@ extern "C" fn tokenizers_encode(
135
145
input_cstr : * const u8 ,
136
146
len : usize ,
137
147
add_special_tokens : i32 ,
148
+ out_result : * mut TokenizerEncodeResult ,
138
149
) {
139
150
unsafe {
140
151
let input_data = std:: str:: from_utf8 ( std:: slice:: from_raw_parts ( input_cstr, len) ) . unwrap ( ) ;
141
- ( * handle) . encode ( input_data, add_special_tokens != 0 ) ;
152
+ let encoded = ( * handle) . encode ( input_data, add_special_tokens != 0 ) ;
153
+ let len = encoded. len ( ) ;
154
+ * out_result = TokenizerEncodeResult {
155
+ token_ids : Box :: into_raw ( encoded. into_boxed_slice ( ) ) as * mut u32 ,
156
+ len : len,
157
+ } ;
142
158
}
143
159
}
144
160
145
161
#[ no_mangle]
146
- extern "C" fn tokenizers_get_encode_ids (
162
+ extern "C" fn tokenizers_encode_batch (
147
163
handle : * mut TokenizerWrapper ,
148
- out_data : * mut * mut u32 ,
149
- out_len : * mut usize ,
164
+ input_cstr : * const * const u8 ,
165
+ input_len : * const usize ,
166
+ num_seqs : usize ,
167
+ add_special_tokens : i32 ,
168
+ out_result : * mut TokenizerEncodeResult ,
150
169
) {
151
170
unsafe {
152
- * out_data = ( * handle) . encode_ids . as_mut_ptr ( ) ;
153
- * out_len = ( * handle) . encode_ids . len ( )
171
+ let input_data = ( 0 ..num_seqs)
172
+ . map ( |i| {
173
+ std:: str:: from_utf8 ( std:: slice:: from_raw_parts ( * input_cstr. offset ( i as isize ) , * input_len. offset ( i as isize ) ) ) . unwrap ( )
174
+ } )
175
+ . collect :: < Vec < & str > > ( ) ;
176
+ let encoded_batch = ( * handle) . encode_batch ( input_data, add_special_tokens != 0 ) ;
177
+ for ( i, encoded) in encoded_batch. into_iter ( ) . enumerate ( ) {
178
+ let len = encoded. len ( ) ;
179
+ let result = TokenizerEncodeResult {
180
+ token_ids : Box :: into_raw ( encoded. into_boxed_slice ( ) ) as * mut u32 ,
181
+ len : len,
182
+ } ;
183
+ * out_result. offset ( i as isize ) = result;
184
+ }
185
+ }
186
+ }
187
+
188
+ #[ no_mangle]
189
+ extern "C" fn tokenizers_free_encode_results ( results : * mut TokenizerEncodeResult , num_seqs : usize ) {
190
+ unsafe {
191
+ let slice = std:: slice:: from_raw_parts_mut ( results, num_seqs) ;
192
+ for result in & mut * slice {
193
+ drop ( Box :: from_raw ( std:: slice:: from_raw_parts_mut ( result. token_ids , result. len ) ) ) ;
194
+ }
154
195
}
155
196
}
156
197
0 commit comments