@@ -21,6 +21,15 @@ def bos_id(self):
21
21
def eos_id(self):
    """Return the end-of-sequence token id.

    Abstract stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
24
def id_to_piece(self, token_id):
    """Map a single token id to its surface string.

    Abstract stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
27
def piece_to_id(self, token_str):
    """Map a single piece string to its token id.

    Abstract stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
30
def is_special_token(self, token_id):
    """Report whether token_id denotes a special (non-content) token.

    Abstract stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
24
33
class SentencePieceWrapper (TokenizerInterface ):
25
34
def __init__ (self , model_path ):
26
35
super ().__init__ (model_path )
@@ -38,6 +47,17 @@ def bos_id(self):
38
47
def eos_id(self):
    """Return the end-of-sequence id reported by the SentencePiece processor."""
    processor = self.processor
    return processor.eos_id()
50
def id_to_piece(self, token_id):
    """Decode one token id to text, mapping the SentencePiece word-boundary
    marker back to a plain space."""
    piece = self.processor.id_to_piece(token_id)
    return piece.replace("▁", " ")
53
def piece_to_id(self, token_str):
    """Encode one piece string to its id, mapping plain spaces to the
    SentencePiece word-boundary marker first."""
    normalized = token_str.replace(" ", "▁")
    return self.processor.piece_to_id(normalized)
56
def is_special_token(self, token_id):
    """True when the SentencePiece model marks token_id as control, unknown,
    or unused; checks short-circuit in that order."""
    processor = self.processor
    return (processor.IsControl(token_id)
            or processor.IsUnknown(token_id)
            or processor.IsUnused(token_id))
41
61
class TiktokenWrapper (TokenizerInterface ):
42
62
"""
43
63
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -53,7 +73,7 @@ def __init__(self, model_path):
53
73
super ().__init__ (model_path )
54
74
assert os .path .isfile (model_path ), str (model_path )
55
75
mergeable_ranks = load_tiktoken_bpe (str (model_path ))
56
- num_base_tokens = len (mergeable_ranks )
76
+ self . num_base_tokens = len (mergeable_ranks )
57
77
special_tokens = [
58
78
"<|begin_of_text|>" ,
59
79
"<|end_of_text|>" ,
@@ -70,7 +90,7 @@ def __init__(self, model_path):
70
90
for i in range (5 , self .num_reserved_special_tokens - 5 )
71
91
]
72
92
self .special_tokens = {
73
- token : num_base_tokens + i for i , token in enumerate (special_tokens )
93
+ token : self . num_base_tokens + i for i , token in enumerate (special_tokens )
74
94
}
75
95
self .model = tiktoken .Encoding (
76
96
name = Path (model_path ).name ,
@@ -94,6 +114,15 @@ def bos_id(self):
94
114
def eos_id(self):
    """Return the stored end-of-sequence token id."""
    eos = self._eos_id
    return eos
117
def id_to_piece(self, token_id):
    """Decode a single token id to its text via the tiktoken encoding."""
    single = [token_id]
    return self.model.decode(single)
120
def piece_to_id(self, token_str):
    """Look up the id of one piece via tiktoken's single-token encoder."""
    model = self.model
    return model.encode_single_token(token_str)
123
def is_special_token(self, token_id):
    """True when token_id lies in the special-token id range.

    Special tokens are assigned ids immediately after the base vocabulary,
    i.e. the half-open interval
    [num_base_tokens, num_base_tokens + len(special_tokens)).
    Uses Python's chained comparison instead of the original
    `x >= a and x < b` form — identical behavior, clearer intent.
    """
    upper = self.num_base_tokens + len(self.special_tokens)
    return self.num_base_tokens <= token_id < upper
97
126
def get_tokenizer (tokenizer_model_path , model_name ):
98
127
"""
99
128
Factory function to get the appropriate tokenizer based on the model name.
0 commit comments