Commit 55170e8

Added EncodeBatch interface (#33)
1 parent 07376b9 commit 55170e8

File tree

4 files changed: +115 −25 lines changed

include/tokenizers_c.h (+11 −3)

@@ -16,22 +16,30 @@ extern "C" {
 
 typedef void* TokenizerHandle;
 
+typedef struct {
+  int* token_ids;
+  size_t len;
+} TokenizerEncodeResult;
+
 TokenizerHandle tokenizers_new_from_str(const char* json, size_t len);
 
 TokenizerHandle byte_level_bpe_tokenizers_new_from_str(const char* vocab, size_t vocab_len,
                                                        const char* merges, size_t merges_len,
                                                        const char* added_tokens,
                                                        size_t added_tokens_len);
 
-void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token);
+void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token, TokenizerEncodeResult* result);
+
+void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len, size_t num_seqs,
+                             int add_special_token, TokenizerEncodeResult* results);
+
+void tokenizers_free_encode_results(TokenizerEncodeResult* results, size_t num_seqs);
 
 void tokenizers_decode(TokenizerHandle handle, const uint32_t* data, size_t len,
                        int skip_special_token);
 
 void tokenizers_get_decode_str(TokenizerHandle handle, const char** data, size_t* len);
 
-void tokenizers_get_encode_ids(TokenizerHandle handle, const uint32_t** id_data, size_t* len);
-
 void tokenizers_get_vocab_size(TokenizerHandle handle, size_t* size);
 
 void tokenizers_id_to_token(TokenizerHandle handle, uint32_t id, const char** data, size_t* len);
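
With this change, the C API hands ownership of the encoded ids to the caller through TokenizerEncodeResult instead of exposing a pointer into tokenizer-owned state via the removed tokenizers_get_encode_ids. The caller copies the ids out and releases the buffers with tokenizers_free_encode_results. A minimal caller-side sketch of the single-string path (the EncodeOnce helper is illustrative, not part of the commit; error handling omitted):

#include <cstdint>
#include <string>
#include <vector>

#include "tokenizers_c.h"

// Illustrative helper: encode one string and take ownership of the ids.
// Assumes `handle` came from tokenizers_new_from_str().
std::vector<int32_t> EncodeOnce(TokenizerHandle handle, const std::string& text) {
  TokenizerEncodeResult result;
  tokenizers_encode(handle, text.data(), text.length(), /*add_special_token=*/1, &result);
  // Copy out before releasing the Rust-allocated buffer.
  std::vector<int32_t> ids(result.token_ids, result.token_ids + result.len);
  tokenizers_free_encode_results(&result, /*num_seqs=*/1);
  return ids;
}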

include/tokenizers_cpp.h (+15 −0)

@@ -29,6 +29,21 @@ class Tokenizer {
    */
   virtual std::vector<int32_t> Encode(const std::string& text) = 0;
 
+  /*!
+   * \brief Encode a batch of texts into ids.
+   * \param texts The input texts.
+   * \returns The encoded token ids.
+   */
+  virtual std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) {
+    // Fall back when the derived class does not implement this function.
+    std::vector<std::vector<int32_t>> ret;
+    ret.reserve(texts.size());
+    for (const auto& text : texts) {
+      ret.push_back(Encode(text));
+    }
+    return ret;
+  }
+
   /*!
    * \brief Decode token ids into text.
    * \param text The token ids.
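
Defining EncodeBatch as a plain virtual with a default body, rather than a pure virtual, keeps existing backends source-compatible: a subclass that only implements Encode still gets a working, if sequential, batch path, while HFTokenizer below overrides it with a truly batched call. Whichever path runs, the result should agree with per-text Encode; a small illustrative check (not part of the commit):

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#include "tokenizers_cpp.h"

// Illustrative invariant check: the default EncodeBatch is defined as a
// per-text Encode loop, and an overriding backend should preserve that.
bool BatchMatchesSingle(tokenizers::Tokenizer* tok, const std::vector<std::string>& texts) {
  std::vector<std::vector<int32_t>> batch = tok->EncodeBatch(texts);
  for (size_t i = 0; i < texts.size(); ++i) {
    if (batch[i] != tok->Encode(texts[i])) return false;
  }
  return true;
}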

rust/src/lib.rs (+53 −12)

@@ -7,19 +7,23 @@ use tokenizers::tokenizer::Tokenizer;
 
 pub struct TokenizerWrapper {
     tokenizer: Tokenizer,
-    encode_ids: Vec<u32>,
     decode_str: String,
     id_to_token_result: String,
 }
 
 pub type Vocab = HashMap<String, u32>;
 pub type Merges = Vec<(String, String)>;
 
+#[repr(C)]
+pub struct TokenizerEncodeResult {
+    token_ids: *mut u32,
+    len: usize,
+}
+
 impl TokenizerWrapper {
     pub fn from_str(json: &str) -> TokenizerWrapper {
         TokenizerWrapper {
             tokenizer: Tokenizer::from_str(json).unwrap().into(),
-            encode_ids: Vec::new(),
             decode_str: String::new(),
             id_to_token_result: String::new(),
         }

@@ -77,16 +81,22 @@ impl TokenizerWrapper {
             .with_decoder(byte_level);
         TokenizerWrapper {
             tokenizer: tokenizer,
-            encode_ids: Vec::new(),
             decode_str: String::new(),
             id_to_token_result: String::new(),
         }
     }
 
-    pub fn encode(&mut self, text: &str, add_special_tokens: bool) {
+    pub fn encode(&mut self, text: &str, add_special_tokens: bool) -> Vec<u32> {
         let encoded = self.tokenizer.encode(text, add_special_tokens).unwrap();
-        self.encode_ids.resize(encoded.len(), 0);
-        self.encode_ids.copy_from_slice(encoded.get_ids());
+        return encoded.get_ids().to_vec();
+    }
+
+    pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
+        let results = self.tokenizer.encode_batch(texts, add_special_tokens).unwrap()
+            .into_iter()
+            .map(|encoded| encoded.get_ids().to_vec())
+            .collect::<Vec<Vec<u32>>>();
+        return results;
     }
 
     pub fn decode(&mut self, ids: &[u32], skip_special_tokens: bool) {

@@ -135,22 +145,53 @@ extern "C" fn tokenizers_encode(
     input_cstr: *const u8,
     len: usize,
     add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
 ) {
     unsafe {
         let input_data = std::str::from_utf8(std::slice::from_raw_parts(input_cstr, len)).unwrap();
-        (*handle).encode(input_data, add_special_tokens != 0);
+        let encoded = (*handle).encode(input_data, add_special_tokens != 0);
+        let len = encoded.len();
+        *out_result = TokenizerEncodeResult {
+            token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+            len: len,
+        };
     }
 }
 
 #[no_mangle]
-extern "C" fn tokenizers_get_encode_ids(
+extern "C" fn tokenizers_encode_batch(
     handle: *mut TokenizerWrapper,
-    out_data: *mut *mut u32,
-    out_len: *mut usize,
+    input_cstr: *const *const u8,
+    input_len: *const usize,
+    num_seqs: usize,
+    add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
 ) {
     unsafe {
-        *out_data = (*handle).encode_ids.as_mut_ptr();
-        *out_len = (*handle).encode_ids.len()
+        let input_data = (0..num_seqs)
+            .map(|i| {
+                std::str::from_utf8(std::slice::from_raw_parts(*input_cstr.offset(i as isize), *input_len.offset(i as isize))).unwrap()
+            })
+            .collect::<Vec<&str>>();
+        let encoded_batch = (*handle).encode_batch(input_data, add_special_tokens != 0);
+        for (i, encoded) in encoded_batch.into_iter().enumerate() {
+            let len = encoded.len();
+            let result = TokenizerEncodeResult {
+                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+                len: len,
+            };
+            *out_result.offset(i as isize) = result;
+        }
+    }
+}
+
+#[no_mangle]
+extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult, num_seqs: usize) {
+    unsafe {
+        let slice = std::slice::from_raw_parts_mut(results, num_seqs);
+        for result in &mut *slice {
+            drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
+        }
     }
 }
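
On the Rust side, each id buffer is leaked across the FFI boundary with Box::into_raw(encoded.into_boxed_slice()), and tokenizers_free_encode_results rebuilds and drops the boxed slices, so every result must be freed exactly once, and only after it has been filled in. A caller-side RAII sketch (the EncodeResultGuard helper is hypothetical, not part of this commit):

#include <cstddef>
#include <cstdint>
#include <vector>

#include "tokenizers_c.h"

// Hypothetical RAII wrapper: frees the Rust-owned id buffers even if the
// copy-out below throws. Assumes tokenizers_encode_batch() fills every slot
// before the guard goes out of scope.
class EncodeResultGuard {
 public:
  explicit EncodeResultGuard(size_t num_seqs) : results_(num_seqs) {}
  ~EncodeResultGuard() { tokenizers_free_encode_results(results_.data(), results_.size()); }
  TokenizerEncodeResult* data() { return results_.data(); }
  const TokenizerEncodeResult& operator[](size_t i) const { return results_[i]; }

 private:
  std::vector<TokenizerEncodeResult> results_;
};

// Illustrative batch call mirroring HFTokenizer::EncodeBatch below.
std::vector<std::vector<int32_t>> EncodeBatchGuarded(TokenizerHandle handle,
                                                     std::vector<const char*> texts,
                                                     std::vector<size_t> lens) {
  EncodeResultGuard guard(texts.size());
  tokenizers_encode_batch(handle, texts.data(), lens.data(), texts.size(),
                          /*add_special_token=*/0, guard.data());
  std::vector<std::vector<int32_t>> out;
  out.reserve(texts.size());
  for (size_t i = 0; i < texts.size(); ++i) {
    out.emplace_back(guard[i].token_ids, guard[i].token_ids + guard[i].len);
  }
  return out;
}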

src/huggingface_tokenizer.cc (+36 −10)

@@ -28,18 +28,44 @@ class HFTokenizer : public Tokenizer {
 
   // use i32 to be consistent with sentencepiece
   std::vector<int32_t> Encode(const std::string& text, bool add_special_tokens) {
-    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens));
-    const uint32_t* data;
-    size_t len;
-    tokenizers_get_encode_ids(handle_, &data, &len);
-    const int32_t* data_i32 = reinterpret_cast<const int32_t*>(data);
-    auto res = std::vector<int32_t>(data_i32, data_i32 + len);
-    return res;
+    TokenizerEncodeResult result;
+    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens),
+                      &result);
+    std::vector<int32_t> ret(result.token_ids, result.token_ids + result.len);
+    tokenizers_free_encode_results(&result, 1);
+    return ret;
   }
 
-  // use i32 to be consistent with sentencepiece
-  std::vector<int32_t> Encode(const std::string& text) final {
-    return Encode(text, false);
+  // use i32 to be consistent with sentencepiece
+  std::vector<int32_t> Encode(const std::string& text) final {
+    return Encode(text, false);
+  }
+
+  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts, bool add_special_tokens) {
+    std::vector<const char*> texts_raw;
+    std::vector<size_t> seq_lens;
+    size_t num_seqs = texts.size();
+    texts_raw.reserve(num_seqs);
+    seq_lens.reserve(num_seqs);
+    for (const auto& text : texts) {
+      texts_raw.push_back(text.data());
+      seq_lens.push_back(text.length());
+    }
+    std::vector<TokenizerEncodeResult> results(num_seqs);
+    tokenizers_encode_batch(handle_, texts_raw.data(), seq_lens.data(), texts.size(),
+                            static_cast<int>(add_special_tokens), results.data());
+    std::vector<std::vector<int32_t>> ret;
+    ret.reserve(texts.size());
+    for (size_t i = 0; i < texts.size(); ++i) {
+      ret.push_back(
+          std::vector<int32_t>(results[i].token_ids, results[i].token_ids + results[i].len));
+    }
+    tokenizers_free_encode_results(results.data(), texts.size());
+    return ret;
+  }
+
+  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) final {
+    return EncodeBatch(texts, false);
   }
 
   // use i32 to be consistent with sentencepiece
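
End to end, callers stay on the C++ Tokenizer interface and never touch the C structs; HFTokenizer::EncodeBatch forwards the whole batch to Rust in a single FFI call and copies the ids out. A usage sketch, assuming a HuggingFace tokenizer.json on disk and the repo's existing Tokenizer::FromBlobJSON factory:

#include <cstdint>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "tokenizers_cpp.h"

int main() {
  // Load a tokenizer.json blob produced by HuggingFace tokenizers.
  std::ifstream fs("tokenizer.json");
  std::stringstream blob;
  blob << fs.rdbuf();

  // FromBlobJSON constructs the HFTokenizer backend shown above.
  auto tok = tokenizers::Tokenizer::FromBlobJSON(blob.str());

  // One FFI round trip encodes the whole batch, preserving input order.
  std::vector<std::vector<int32_t>> ids =
      tok->EncodeBatch({"Hello, world!", "EncodeBatch in a single call."});
  for (const auto& seq : ids) {
    std::cout << seq.size() << " tokens" << std::endl;
  }
  return 0;
}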
