Commit 18e76c7

Authored by Jonathan Makunga (makungaj1)
Feat: Add TEI support for ModelBuilder (#4694)
* Add TEI Serving
* Add TEI Serving
* Add TEI Serving
* Add TEI Serving
* Add TEI Serving
* Add TEI Serving
* Notebook testing
* Notebook testing
* Notebook testing
* Refactoring
* Refactoring
* UT
* UT
* Refactoring
* Test coverage
* Refactoring
* Refactoring

---------

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 828cdc3 commit 18e76c7

File tree

13 files changed: +409 -58 lines

src/sagemaker/serve/builder/model_builder.py

+1 -1

```diff
@@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
         in order for model builder to build the artifacts correctly (according
         to the model server). Possible values for this argument are
         ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
-        ``TRITON``, and``TGI``.
+        ``TRITON``,``TGI``, and ``TEI``.
         model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata.
         Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for
         new models without task metadata in the Hub, adding unsupported task types will throw
```
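For orientation, a minimal usage sketch of the option being documented here, assuming the usual ModelBuilder flow; the model ID and the SchemaBuilder samples are illustrative assumptions, not part of this commit:

```python
# Hedged sketch: selecting the new TEI model server in ModelBuilder.
# ModelServer.TEI is the value added by this commit; payload shapes are assumed.
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer

schema_builder = SchemaBuilder(
    sample_input={"inputs": "What is a text embedding?"},  # assumed TEI request shape
    sample_output=[[0.01, -0.02, 0.03]],                   # assumed embedding response
)

model_builder = ModelBuilder(
    model="BAAI/bge-base-en-v1.5",   # example Hugging Face embedding model
    schema_builder=schema_builder,
    model_server=ModelServer.TEI,    # new enum value documented above
)
model = model_builder.build()
```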

src/sagemaker/serve/builder/tei_builder.py

+10 -8

```diff
@@ -25,7 +25,7 @@
     _get_nb_instance,
 )
 from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
-from sagemaker.serve.utils.predictors import TgiLocalModePredictor
+from sagemaker.serve.utils.predictors import TeiLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -74,16 +74,16 @@ def _prepare_for_mode(self):
     def _get_client_translators(self):
         """Placeholder docstring"""
 
-    def _set_to_tgi(self):
+    def _set_to_tei(self):
         """Placeholder docstring"""
-        if self.model_server != ModelServer.TGI:
+        if self.model_server != ModelServer.TEI:
             messaging = (
                 "HuggingFace Model ID support on model server: "
                 f"{self.model_server} is not currently supported. "
-                f"Defaulting to {ModelServer.TGI}"
+                f"Defaulting to {ModelServer.TEI}"
             )
             logger.warning(messaging)
-            self.model_server = ModelServer.TGI
+            self.model_server = ModelServer.TEI
 
     def _create_tei_model(self, **kwargs) -> Type[Model]:
         """Placeholder docstring"""
@@ -142,7 +142,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
         if self.mode == Mode.LOCAL_CONTAINER:
             timeout = kwargs.get("model_data_download_timeout")
 
-            predictor = TgiLocalModePredictor(
+            predictor = TeiLocalModePredictor(
                 self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
             )
 
@@ -180,7 +180,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
 
-        if not self.nb_instance_type and "instance_type" not in kwargs:
+        if self.nb_instance_type and "instance_type" not in kwargs:
+            kwargs.update({"instance_type": self.nb_instance_type})
+        elif not self.nb_instance_type and "instance_type" not in kwargs:
             raise ValueError(
                 "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
             )
@@ -216,7 +218,7 @@ def _build_for_tei(self):
         """Placeholder docstring"""
         self.secret_key = None
 
-        self._set_to_tgi()
+        self._set_to_tei()
 
         self.pysdk_model = self._build_for_hf_tei()
         return self.pysdk_model
```
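The `instance_type` hunk above is the behavioral fix in this file: a notebook instance type detected at build time is now propagated into the deploy call instead of being ignored. A minimal sketch of the two outcomes, continuing the ModelBuilder sketch earlier; the call shapes follow the usual `build()`/`deploy()` flow and the instance type value is illustrative:

```python
# Hedged sketch of the new fallback behavior (not part of the diff).
from sagemaker.serve.mode.function_pointers import Mode

model = model_builder.build(mode=Mode.SAGEMAKER_ENDPOINT)

# Inside a SageMaker notebook: nb_instance_type was detected, so it is now
# injected automatically when the caller omits instance_type.
predictor = model.deploy()

# Anywhere else, instance_type must still be passed explicitly; omitting it
# raises ValueError("Instance type must be provided when deploying to
# SageMaker Endpoint mode.")
predictor = model.deploy(instance_type="ml.g5.xlarge")  # example instance type
```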

src/sagemaker/serve/mode/local_container_mode.py

+15 -0

```diff
@@ -21,6 +21,7 @@
 from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
 from sagemaker.serve.model_server.triton.server import LocalTritonServer
 from sagemaker.serve.model_server.tgi.server import LocalTgiServing
+from sagemaker.serve.model_server.tei.server import LocalTeiServing
 from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
 from sagemaker.session import Session
 
@@ -69,6 +70,7 @@ def __init__(
         self.container = None
         self.secret_key = None
         self._ping_container = None
+        self._invoke_serving = None
 
     def load(self, model_path: str = None):
         """Placeholder docstring"""
@@ -156,6 +158,19 @@ def create_server(
                 env_vars=env_vars if env_vars else self.env_vars,
             )
             self._ping_container = self._tensorflow_serving_deep_ping
+        elif self.model_server == ModelServer.TEI:
+            tei_serving = LocalTeiServing()
+            tei_serving._start_tei_serving(
+                client=self.client,
+                image=image,
+                model_path=model_path if model_path else self.model_path,
+                secret_key=secret_key,
+                env_vars=env_vars if env_vars else self.env_vars,
+            )
+            tei_serving.schema_builder = self.schema_builder
+            self.container = tei_serving.container
+            self._ping_container = tei_serving._tei_deep_ping
+            self._invoke_serving = tei_serving._invoke_tei_serving
 
         # allow some time for container to be ready
         time.sleep(10)
```
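With the TEI branch wired in, `create_server()` now exposes both a deep-ping and an invoke function pointer on the mode object. A rough sketch of what that enables; the attribute names come from the diff, while the payload shape is an assumption:

```python
# Sketch only: "mode" is a LocalContainerMode instance after create_server()
# ran with ModelServer.TEI, and "predictor" is a local-mode predictor.
healthy, sample_response = mode._ping_container(predictor)  # -> _tei_deep_ping

raw = mode._invoke_serving(          # -> _invoke_tei_serving
    b'{"inputs": "hello world"}',    # request body; assumed TEI JSON shape
    "application/json",              # Content-Type header
    "application/json",              # Accept header
)
```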

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

+21 -6

```diff
@@ -6,6 +6,7 @@
 import logging
 from typing import Type
 
+from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
 from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
 from sagemaker.session import Session
 from sagemaker.serve.utils.types import ModelServer
@@ -37,6 +38,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe
         self.inference_spec = inference_spec
         self.model_server = model_server
 
+        self._tei_serving = SageMakerTeiServing()
+
     def load(self, model_path: str):
         """Placeholder docstring"""
         path = Path(model_path)
@@ -66,8 +69,9 @@ def prepare(
                 + "session to be created or supply `sagemaker_session` into @serve.invoke."
             ) from e
 
+        upload_artifacts = None
         if self.model_server == ModelServer.TORCHSERVE:
-            return self._upload_torchserve_artifacts(
+            upload_artifacts = self._upload_torchserve_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 secret_key=secret_key,
@@ -76,7 +80,7 @@ def prepare(
             )
 
         if self.model_server == ModelServer.TRITON:
-            return self._upload_triton_artifacts(
+            upload_artifacts = self._upload_triton_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 secret_key=secret_key,
@@ -85,15 +89,15 @@ def prepare(
             )
 
         if self.model_server == ModelServer.DJL_SERVING:
-            return self._upload_djl_artifacts(
+            upload_artifacts = self._upload_djl_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 s3_model_data_url=s3_model_data_url,
                 image=image,
             )
 
         if self.model_server == ModelServer.TGI:
-            return self._upload_tgi_artifacts(
+            upload_artifacts = self._upload_tgi_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 s3_model_data_url=s3_model_data_url,
@@ -102,20 +106,31 @@ def prepare(
             )
 
         if self.model_server == ModelServer.MMS:
-            return self._upload_server_artifacts(
+            upload_artifacts = self._upload_server_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 s3_model_data_url=s3_model_data_url,
                 image=image,
             )
 
         if self.model_server == ModelServer.TENSORFLOW_SERVING:
-            return self._upload_tensorflow_serving_artifacts(
+            upload_artifacts = self._upload_tensorflow_serving_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 secret_key=secret_key,
                 s3_model_data_url=s3_model_data_url,
                 image=image,
             )
 
+        if self.model_server == ModelServer.TEI:
+            upload_artifacts = self._tei_serving._upload_tei_artifacts(
+                model_path=model_path,
+                sagemaker_session=sagemaker_session,
+                s3_model_data_url=s3_model_data_url,
+                image=image,
+            )
+
+        if upload_artifacts:
+            return upload_artifacts
+
         raise ValueError("%s model server is not supported" % self.model_server)
```
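The refactor above replaces per-server early returns with a shared `upload_artifacts` variable, so the unsupported-server `ValueError` stays the single terminal path. For TEI, `_upload_tei_artifacts` (defined in the new server module below) returns a `(model_data, env_vars)` tuple pointing at an uncompressed S3 prefix, roughly of this shape; the bucket and prefix values are illustrative:

```python
# Shape of the tuple returned by _upload_tei_artifacts (values illustrative):
model_data = {
    "S3DataSource": {
        "CompressionType": "None",   # artifacts are uploaded uncompressed
        "S3DataType": "S3Prefix",
        "S3Uri": "s3://example-bucket/example-prefix/code/",
    }
}
env_vars = {
    "TRANSFORMERS_CACHE": "/opt/ml/model/",    # module defaults, see below
    "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
}
```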

src/sagemaker/serve/model_server/tei/__init__.py

Whitespace-only changes.
src/sagemaker/serve/model_server/tei/server.py (new file)

+160 -0
```python
"""Module for Local TEI Serving"""

from __future__ import absolute_import

import requests
import logging
from pathlib import Path
from docker.types import DeviceRequest
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
from sagemaker.base_predictor import PredictorBase
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
from sagemaker.s3 import S3Uploader
from sagemaker.local.utils import get_docker_host


MODE_DIR_BINDING = "/opt/ml/model/"
_SHM_SIZE = "2G"
_DEFAULT_ENV_VARS = {
    "TRANSFORMERS_CACHE": "/opt/ml/model/",
    "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
}

logger = logging.getLogger(__name__)


class LocalTeiServing:
    """LocalTeiServing class"""

    def _start_tei_serving(
        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
    ):
        """Starts a local tei serving container.

        Args:
            client: Docker client
            image: Image to use
            model_path: Path to the model
            secret_key: Secret key to use for authentication
            env_vars: Environment variables to set
        """
        if env_vars and secret_key:
            env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

        self.container = client.containers.run(
            image,
            shm_size=_SHM_SIZE,
            device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
            network_mode="host",
            detach=True,
            auto_remove=True,
            volumes={
                Path(model_path).joinpath("code"): {
                    "bind": MODE_DIR_BINDING,
                    "mode": "rw",
                },
            },
            environment=_update_env_vars(env_vars),
        )

    def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
        """Invokes a local tei serving container.

        Args:
            request: Request to send
            content_type: Content type to use
            accept: Accept to use
        """
        try:
            response = requests.post(
                f"http://{get_docker_host()}:8080/invocations",
                data=request,
                headers={"Content-Type": content_type, "Accept": accept},
                timeout=600,
            )
            response.raise_for_status()
            return response.content
        except Exception as e:
            raise Exception("Unable to send request to the local container server") from e

    def _tei_deep_ping(self, predictor: PredictorBase):
        """Checks if the local tei serving container is up and running.

        If the container is not up and running, it will raise an exception.
        """
        response = None
        try:
            response = predictor.predict(self.schema_builder.sample_input)
            return (True, response)
        # pylint: disable=broad-except
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                raise LocalModelInvocationException(str(e))
            return (False, response)

        return (True, response)


class SageMakerTeiServing:
    """SageMakerTeiServing class"""

    def _upload_tei_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        s3_model_data_url: str = None,
        image: str = None,
        env_vars: dict = None,
    ):
        """Uploads the model artifacts to S3.

        Args:
            model_path: Path to the model
            sagemaker_session: SageMaker session
            s3_model_data_url: S3 model data URL
            image: Image to use
            env_vars: Environment variables to set
        """
        if s3_model_data_url:
            bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
        else:
            bucket, key_prefix = None, None

        code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)

        bucket, code_key_prefix = determine_bucket_and_prefix(
            bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
        )

        code_dir = Path(model_path).joinpath("code")

        s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")

        logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)

        model_data_url = S3Uploader.upload(
            str(code_dir),
            s3_location,
            None,
            sagemaker_session,
        )

        model_data = {
            "S3DataSource": {
                "CompressionType": "None",
                "S3DataType": "S3Prefix",
                "S3Uri": model_data_url + "/",
            }
        }

        return (model_data, _update_env_vars(env_vars))


def _update_env_vars(env_vars: dict) -> dict:
    """Placeholder docstring"""
    updated_env_vars = {}
    updated_env_vars.update(_DEFAULT_ENV_VARS)
    if env_vars:
        updated_env_vars.update(env_vars)
    return updated_env_vars
```
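For reference, a hedged sketch of driving this module directly, which `LocalContainerMode.create_server` normally does on the caller's behalf; the image tag, model path, and `HF_MODEL_ID` variable are assumptions:

```python
import docker

tei = LocalTeiServing()
tei._start_tei_serving(
    client=docker.from_env(),
    image="tei-image:latest",          # placeholder image URI
    model_path="/home/user/my-model",  # container mounts <model_path>/code at /opt/ml/model/
    secret_key=None,
    env_vars={"HF_MODEL_ID": "BAAI/bge-base-en-v1.5"},  # assumed TEI env var
)

# _update_env_vars layers caller-supplied variables over the module defaults:
merged = _update_env_vars({"HF_MODEL_ID": "my-model"})
# {'TRANSFORMERS_CACHE': '/opt/ml/model/',
#  'HUGGINGFACE_HUB_CACHE': '/opt/ml/model/',
#  'HF_MODEL_ID': 'my-model'}
```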
