+import bisect
 import os
 import sentencepiece as spm
 import tiktoken
@@ -21,6 +22,15 @@ def bos_id(self):
     def eos_id(self):
         raise NotImplementedError("This method should be overridden by subclasses.")
 
+    def id_to_piece(self, token_id):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def piece_to_id(self, token_str):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def is_special_token(self, token_id):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
 class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
@@ -38,6 +48,17 @@ def bos_id(self):
     def eos_id(self):
         return self.processor.eos_id()
 
+    def id_to_piece(self, token_id):
+        return self.processor.id_to_piece(token_id).replace("▁", " ")
+
+    def piece_to_id(self, token_str):
+        return self.processor.piece_to_id(token_str.replace(" ", "▁"))
+
+    def is_special_token(self, token_id):
+        return self.processor.IsControl(token_id) \
+            or self.processor.IsUnknown(token_id) \
+            or self.processor.IsUnused(token_id)
+
 class TiktokenWrapper(TokenizerInterface):
     """
     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
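For the SentencePiece side, `id_to_piece` swaps the word-boundary marker "▁" for a plain space and `piece_to_id` reverses that substitution, so the two calls round-trip for ordinary pieces, while `is_special_token` delegates to SentencePiece's own control/unknown/unused predicates. A minimal usage sketch, not part of the commit; the model path is a placeholder and the asserted behavior assumes a standard Llama-style SentencePiece model:

    # Hypothetical sketch -- assumes a SentencePiece file "tokenizer.model" on disk.
    sp = SentencePieceWrapper("tokenizer.model")
    piece = sp.id_to_piece(100)          # e.g. " the" for a word-initial piece
    assert sp.piece_to_id(piece) == 100  # the "▁" <-> " " substitution round-trips
    assert sp.is_special_token(sp.bos_id())   # <s> is a control token
    assert not sp.is_special_token(100)       # ordinary vocabulary piece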
@@ -53,7 +74,7 @@ def __init__(self, model_path):
         super().__init__(model_path)
         assert os.path.isfile(model_path), str(model_path)
         mergeable_ranks = load_tiktoken_bpe(str(model_path))
-        num_base_tokens = len(mergeable_ranks)
+        self.num_base_tokens = len(mergeable_ranks)
         special_tokens = [
             "<|begin_of_text|>",
             "<|end_of_text|>",
@@ -70,7 +91,7 @@ def __init__(self, model_path):
             for i in range(5, self.num_reserved_special_tokens - 5)
         ]
         self.special_tokens = {
-            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+            token: self.num_base_tokens + i for i, token in enumerate(special_tokens)
         }
         self.model = tiktoken.Encoding(
             name=Path(model_path).name,
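Promoting `num_base_tokens` to an instance attribute is what enables the range check in `is_special_token` below: the special tokens are assigned ids contiguously after the mergeable BPE ranks. Illustratively (numbers hypothetical):

    # Hypothetical numbers for illustration only.
    # With 128000 mergeable ranks:
    #   "<|begin_of_text|>" -> 128000
    #   "<|end_of_text|>"   -> 128001
    # so any id in [128000, 128000 + len(special_tokens)) is a special token.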
@@ -94,6 +115,15 @@ def bos_id(self):
     def eos_id(self):
         return self._eos_id
 
+    def id_to_piece(self, token_id):
+        return self.model.decode([token_id])
+
+    def piece_to_id(self, token_str):
+        return self.model.encode_single_token(token_str)
+
+    def is_special_token(self, token_id):
+        return token_id >= self.num_base_tokens and token_id < self.num_base_tokens + len(self.special_tokens)
+
 def get_tokenizer(tokenizer_model_path, model_name):
     """
     Factory function to get the appropriate tokenizer based on the model name.
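A matching sketch for the Tiktoken wrapper, again not part of the commit; the path and ids are placeholders, and the round-trip assumes the chosen id decodes to valid UTF-8 on its own:

    # Hypothetical sketch -- assumes a tiktoken BPE file at the placeholder path.
    tok = TiktokenWrapper("tokenizer.model")
    piece = tok.id_to_piece(1000)          # single-token decode
    assert tok.piece_to_id(piece) == 1000  # encode_single_token looks up the exact piece
    assert tok.is_special_token(tok.bos_id())  # special ids sit right after the base vocab
    assert not tok.is_special_token(0)         # base BPE rank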