Commit 3733d15

Make TokenizerInterface more usable from user code.
Add id_to_piece, piece_to_id, and is_special_token to TokenizerInterface and to the corresponding implementations.
1 parent 30d69b3

File tree

1 file changed: tokenizer.py (+32 −2)
@@ -1,3 +1,4 @@
+import bisect
 import os
 import sentencepiece as spm
 import tiktoken
@@ -21,6 +22,15 @@ def bos_id(self):
     def eos_id(self):
         raise NotImplementedError("This method should be overridden by subclasses.")
 
+    def id_to_piece(self, token_id):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def piece_to_id(self, token_str):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def is_special_token(self, token_id):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
 class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
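
These base-class stubs mean downstream code can be written against TokenizerInterface without knowing which backend is active. A minimal sketch of such user code (describe_token is a hypothetical helper, not part of this commit):

    def describe_token(tokenizer, token_id):
        # Works with any TokenizerInterface implementation that overrides
        # id_to_piece and is_special_token (both wrappers below do).
        piece = tokenizer.id_to_piece(token_id)
        special = tokenizer.is_special_token(token_id)
        return f"{token_id} -> {piece!r} (special={special})"
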
@@ -38,6 +48,17 @@ def bos_id(self):
     def eos_id(self):
         return self.processor.eos_id()
 
+    def id_to_piece(self, token_id):
+        return self.processor.id_to_piece(token_id).replace("▁", " ")
+
+    def piece_to_id(self, token_str):
+        return self.processor.piece_to_id(token_str.replace(" ", "▁"))
+
+    def is_special_token(self, token_id):
+        return self.processor.IsControl(token_id) \
+            or self.processor.IsUnknown(token_id) \
+            or self.processor.IsUnused(token_id)
+
 class TiktokenWrapper(TokenizerInterface):
     """
     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
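
The replace() calls above exist because SentencePiece marks word boundaries with the metaspace character U+2581 ("▁") rather than a plain space; the wrapper converts in both directions so callers deal only in ordinary text. An illustrative round trip, assuming tok is a loaded SentencePieceWrapper whose vocabulary contains the piece "▁hello":

    tid = tok.piece_to_id(" hello")          # looked up internally as "▁hello"
    assert tok.id_to_piece(tid) == " hello"  # "▁" mapped back to a space
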
@@ -53,7 +74,7 @@ def __init__(self, model_path):
         super().__init__(model_path)
         assert os.path.isfile(model_path), str(model_path)
         mergeable_ranks = load_tiktoken_bpe(str(model_path))
-        num_base_tokens = len(mergeable_ranks)
+        self.num_base_tokens = len(mergeable_ranks)
         special_tokens = [
             "<|begin_of_text|>",
             "<|end_of_text|>",
@@ -70,7 +91,7 @@ def __init__(self, model_path):
             for i in range(5, self.num_reserved_special_tokens - 5)
         ]
         self.special_tokens = {
-            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+            token: self.num_base_tokens + i for i, token in enumerate(special_tokens)
         }
         self.model = tiktoken.Encoding(
             name=Path(model_path).name,
@@ -94,6 +115,15 @@ def bos_id(self):
     def eos_id(self):
         return self._eos_id
 
+    def id_to_piece(self, token_id):
+        return self.model.decode([token_id])
+
+    def piece_to_id(self, token_str):
+        return self.model.encode_single_token(token_str)
+
+    def is_special_token(self, token_id):
+        return token_id >= self.num_base_tokens and token_id < self.num_base_tokens + len(self.special_tokens)
+
 def get_tokenizer(tokenizer_model_path, model_name):
     """
     Factory function to get the appropriate tokenizer based on the model name.
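
Since the constructor registers each special token at self.num_base_tokens + i, special IDs occupy the half-open range [num_base_tokens, num_base_tokens + len(special_tokens)), which is exactly what is_special_token tests. A usage sketch under the assumption that a Llama-3-style tokenizer.model exists at the given (hypothetical) path:

    tok = TiktokenWrapper("tokenizer.model")    # hypothetical path
    bos = tok.piece_to_id("<|begin_of_text|>")  # encode_single_token also resolves special tokens
    assert tok.is_special_token(bos)
    assert not tok.is_special_token(0)          # ID 0 is a base BPE token
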
