Skip to content

Commit ea1ad22

Browse files
committed
Making TokenizerInterface more usable from user code.
Adding id_to_piece, piece_to_id, and is_special_token functionality to TokenizerInterface and the corresponding implementations.
1 parent 30d69b3 commit ea1ad22

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

tokenizer.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ def bos_id(self):
2121
def eos_id(self):
2222
raise NotImplementedError("This method should be overridden by subclasses.")
2323

24+
def id_to_piece(self, token_id):
25+
raise NotImplementedError("This method should be overridden by subclasses.")
26+
27+
def piece_to_id(self, token_str):
28+
raise NotImplementedError("This method should be overridden by subclasses.")
29+
30+
def is_special_token(self, token_id):
31+
raise NotImplementedError("This method should be overridden by subclasses.")
32+
2433
class SentencePieceWrapper(TokenizerInterface):
2534
def __init__(self, model_path):
2635
super().__init__(model_path)
@@ -38,6 +47,17 @@ def bos_id(self):
3847
def eos_id(self):
3948
return self.processor.eos_id()
4049

50+
def id_to_piece(self, token_id):
51+
return self.processor.id_to_piece(token_id).replace("▁", " ")
52+
53+
def piece_to_id(self, token_str):
54+
return self.processor.piece_to_id(token_str.replace(" ", "▁"))
55+
56+
def is_special_token(self, token_id):
57+
return self.processor.IsControl(token_id) \
58+
or self.processor.IsUnknown(token_id) \
59+
or self.processor.IsUnused(token_id)
60+
4161
class TiktokenWrapper(TokenizerInterface):
4262
"""
4363
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -53,7 +73,7 @@ def __init__(self, model_path):
5373
super().__init__(model_path)
5474
assert os.path.isfile(model_path), str(model_path)
5575
mergeable_ranks = load_tiktoken_bpe(str(model_path))
56-
num_base_tokens = len(mergeable_ranks)
76+
self.num_base_tokens = len(mergeable_ranks)
5777
special_tokens = [
5878
"<|begin_of_text|>",
5979
"<|end_of_text|>",
@@ -70,7 +90,7 @@ def __init__(self, model_path):
7090
for i in range(5, self.num_reserved_special_tokens - 5)
7191
]
7292
self.special_tokens = {
73-
token: num_base_tokens + i for i, token in enumerate(special_tokens)
93+
token: self.num_base_tokens + i for i, token in enumerate(special_tokens)
7494
}
7595
self.model = tiktoken.Encoding(
7696
name=Path(model_path).name,
@@ -94,6 +114,15 @@ def bos_id(self):
94114
def eos_id(self):
95115
return self._eos_id
96116

117+
def id_to_piece(self, token_id):
118+
return self.model.decode([token_id])
119+
120+
def piece_to_id(self, token_str):
121+
return self.model.encode_single_token(token_str)
122+
123+
def is_special_token(self, token_id):
124+
return token_id >= self.num_base_tokens and token_id < self.num_base_tokens + len(self.special_tokens)
125+
97126
def get_tokenizer(tokenizer_model_path, model_name):
98127
"""
99128
Factory function to get the appropriate tokenizer based on the model name.

0 commit comments

Comments
 (0)