Skip to content

Commit d428e74

Browse files
[Feature] Support Tenstorrent's Wormhole accelerators #2573 (#2574)
1 parent c5f207d commit d428e74

File tree

5 files changed

+129
-6
lines changed

5 files changed

+129
-6
lines changed

runner/internal/shim/docker.go

+22
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error {
151151
gpuIDs = append(gpuIDs, device.PathOnHost)
152152
}
153153
}
154+
case host.GpuVendorTenstorrent:
155+
for _, device := range containerFull.HostConfig.Resources.Devices {
156+
if strings.HasPrefix(device.PathOnHost, "/dev/tenstorrent/") {
157+
// Extract the device ID from the path
158+
deviceID := strings.TrimPrefix(device.PathOnHost, "/dev/tenstorrent/")
159+
gpuIDs = append(gpuIDs, deviceID)
160+
}
161+
}
154162
case host.GpuVendorIntel:
155163
for _, envVar := range containerFull.Config.Env {
156164
if indices, found := strings.CutPrefix(envVar, "HABANA_VISIBLE_DEVICES="); found {
@@ -1009,6 +1017,7 @@ func configureGpuDevices(hostConfig *container.HostConfig, gpuDevices []GPUDevic
10091017
func configureGpus(config *container.Config, hostConfig *container.HostConfig, vendor host.GpuVendor, ids []string) {
10101018
// NVIDIA: ids are identifiers reported by nvidia-smi, GPU-<UUID> strings
10111019
// AMD: ids are DRI render node paths, e.g., /dev/dri/renderD128
1020+
// Tenstorrent: ids are device indices to be used with /dev/tenstorrent/<id>
10121021
switch vendor {
10131022
case host.GpuVendorNvidia:
10141023
hostConfig.Resources.DeviceRequests = append(
@@ -1051,6 +1060,19 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v
10511060
// --security-opt=seccomp=unconfined
10521061
hostConfig.SecurityOpt = append(hostConfig.SecurityOpt, "seccomp=unconfined")
10531062
// TODO: in addition, for non-root user, --group-add=video, and possibly --group-add=render, are required.
1063+
case host.GpuVendorTenstorrent:
1064+
// For Tenstorrent, simply add each device
1065+
for _, id := range ids {
1066+
devicePath := fmt.Sprintf("/dev/tenstorrent/%s", id)
1067+
hostConfig.Resources.Devices = append(
1068+
hostConfig.Resources.Devices,
1069+
container.DeviceMapping{
1070+
PathOnHost: devicePath,
1071+
PathInContainer: devicePath,
1072+
CgroupPermissions: "rwm",
1073+
},
1074+
)
1075+
}
10541076
case host.GpuVendorIntel:
10551077
// All options are listed here:
10561078
// https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html

runner/internal/shim/host/gpu.go

+91-4
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@ import (
1818

1919
const amdSmiImage = "un1def/amd-smi:6.2.2-0"
2020

21+
const ttSmiImage = "dstackai/tt-smi:latest"
22+
2123
type GpuVendor string
2224

2325
const (
24-
GpuVendorNone GpuVendor = "none"
25-
GpuVendorNvidia GpuVendor = "nvidia"
26-
GpuVendorAmd GpuVendor = "amd"
27-
GpuVendorIntel GpuVendor = "intel"
26+
GpuVendorNone GpuVendor = "none"
27+
GpuVendorNvidia GpuVendor = "nvidia"
28+
GpuVendorAmd GpuVendor = "amd"
29+
GpuVendorIntel GpuVendor = "intel"
30+
GpuVendorTenstorrent GpuVendor = "tenstorrent"
2831
)
2932

3033
type GpuInfo struct {
@@ -57,6 +60,9 @@ func GetGpuVendor() GpuVendor {
5760
if _, err := os.Stat("/dev/accel"); !errors.Is(err, os.ErrNotExist) {
5861
return GpuVendorIntel
5962
}
63+
if _, err := os.Stat("/dev/tenstorrent"); !errors.Is(err, os.ErrNotExist) {
64+
return GpuVendorTenstorrent
65+
}
6066
return GpuVendorNone
6167
}
6268

@@ -68,6 +74,8 @@ func GetGpuInfo(ctx context.Context) []GpuInfo {
6874
return getAmdGpuInfo(ctx)
6975
case GpuVendorIntel:
7076
return getIntelGpuInfo(ctx)
77+
case GpuVendorTenstorrent:
78+
return getTenstorrentGpuInfo(ctx)
7179
case GpuVendorNone:
7280
return []GpuInfo{}
7381
}
@@ -195,6 +203,85 @@ func getAmdGpuInfo(ctx context.Context) []GpuInfo {
195203
return gpus
196204
}
197205

206+
// Minimal view of the JSON document that getTenstorrentGpuInfo decodes from
// tt-smi's stdout. Only the fields that function reads are declared; any other
// keys in the document are ignored by encoding/json.
type (
	// ttSmiSnapshot is the top-level object; its "device_info" array is
	// iterated to build the GPU list.
	ttSmiSnapshot struct {
		DeviceInfo []ttDeviceInfo `json:"device_info"`
	}

	// ttDeviceInfo is a single element of "device_info".
	ttDeviceInfo struct {
		BoardInfo ttBoardInfo `json:"board_info"`
	}

	// ttBoardInfo carries the board identification strings of one device.
	ttBoardInfo struct {
		// BoardType is the board model string; it may end in " L" or " R".
		BoardType string `json:"board_type"`
		// BusID is used as the GPU ID.
		// NOTE(review): presumably a PCI bus address; confirm against tt-smi output.
		BusID string `json:"bus_id"`
	}
)
218+
219+
func getTenstorrentGpuInfo(ctx context.Context) []GpuInfo {
220+
gpus := []GpuInfo{}
221+
222+
cmd := execute.ExecTask{
223+
Command: "docker",
224+
Args: []string{
225+
"run",
226+
"--rm",
227+
"--device", "/dev/tenstorrent",
228+
ttSmiImage,
229+
"-s",
230+
},
231+
StreamStdio: false,
232+
}
233+
res, err := cmd.Execute(ctx)
234+
if err != nil {
235+
log.Error(ctx, "failed to execute tt-smi", "err", err)
236+
return gpus
237+
}
238+
if res.ExitCode != 0 {
239+
log.Error(
240+
ctx, "failed to execute tt-smi",
241+
"exitcode", res.ExitCode, "stdout", res.Stdout, "stderr", res.Stderr,
242+
)
243+
return gpus
244+
}
245+
246+
var ttSmiSnapshot ttSmiSnapshot
247+
if err := json.Unmarshal([]byte(res.Stdout), &ttSmiSnapshot); err != nil {
248+
log.Error(ctx, "cannot read tt-smi json", "err", err)
249+
log.Debug(ctx, "tt-smi output", "stdout", res.Stdout)
250+
return gpus
251+
}
252+
253+
for i, device := range ttSmiSnapshot.DeviceInfo {
254+
// Extract board type without R/L suffix
255+
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
256+
name := boardType
257+
258+
// Remove " R" or " L" suffix if present
259+
if strings.HasSuffix(boardType, " R") {
260+
name = boardType[:len(boardType)-2]
261+
} else if strings.HasSuffix(boardType, " L") {
262+
name = boardType[:len(boardType)-2]
263+
}
264+
265+
// Determine VRAM based on board type
266+
vram := 0
267+
if strings.HasPrefix(name, "n150") {
268+
vram = 12 * 1024 // 12GB in MiB
269+
} else if strings.HasPrefix(name, "n300") {
270+
vram = 24 * 1024 // 24GB in MiB
271+
}
272+
273+
gpus = append(gpus, GpuInfo{
274+
Vendor: GpuVendorTenstorrent,
275+
Name: name,
276+
Vram: vram,
277+
ID: device.BoardInfo.BusID,
278+
Index: strconv.Itoa(i),
279+
})
280+
}
281+
282+
return gpus
283+
}
284+
198285
func getAmdRenderNodePath(bdf string) (string, error) {
199286
// amd-smi uses extended BDF Notation with domain: Domain:Bus:Device.Function, e.g., 0000:5f:00.0
200287
// udev creates /dev/dri/by-path/pci-<BDF>-render -> ../renderD<N> symlinks

runner/internal/shim/resources.go

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ func NewGpuLock(gpus []host.GpuInfo) (*GpuLock, error) {
4242
resourceID = gpu.ID
4343
case host.GpuVendorAmd:
4444
resourceID = gpu.RenderNodePath
45+
case host.GpuVendorTenstorrent:
46+
resourceID = gpu.Index
4547
case host.GpuVendorIntel:
4648
resourceID = gpu.Index
4749
case host.GpuVendorNone:

src/dstack/_internal/cli/services/configurators/run.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
_KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
5353
_KNOWN_NVIDIA_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_NVIDIA_GPUS}
5454
_KNOWN_TPU_VERSIONS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TPUS}
55-
55+
_KNOWN_TENSTORRENT_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TENSTORRENT_ACCELERATORS}
5656
_BIND_ADDRESS_ARG = "bind_address"
5757

5858
logger = get_logger(__name__)
@@ -350,6 +350,7 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
350350
if gpu_spec.count.max == 0:
351351
return
352352
has_amd_gpu: bool
353+
has_tt_gpu: bool
353354
vendor = gpu_spec.vendor
354355
if vendor is None:
355356
names = gpu_spec.name
@@ -362,6 +363,8 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
362363
vendors.add(gpuhunt.AcceleratorVendor.NVIDIA)
363364
elif name in _KNOWN_AMD_GPUS:
364365
vendors.add(gpuhunt.AcceleratorVendor.AMD)
366+
elif name in _KNOWN_TENSTORRENT_GPUS:
367+
vendors.add(gpuhunt.AcceleratorVendor.TENSTORRENT)
365368
else:
366369
maybe_tpu_version, _, maybe_tpu_cores = name.partition("-")
367370
if maybe_tpu_version in _KNOWN_TPU_VERSIONS and maybe_tpu_cores.isdigit():
@@ -380,15 +383,22 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
380383
# to execute a run on an instance with an AMD accelerator with a default
381384
# CUDA image, not a big deal.
382385
has_amd_gpu = gpuhunt.AcceleratorVendor.AMD in vendors
386+
has_tt_gpu = gpuhunt.AcceleratorVendor.TENSTORRENT in vendors
383387
else:
384388
# If neither gpu.vendor nor gpu.name is set, assume Nvidia.
385389
vendor = gpuhunt.AcceleratorVendor.NVIDIA
386390
has_amd_gpu = False
391+
has_tt_gpu = False
387392
gpu_spec.vendor = vendor
388393
else:
389394
has_amd_gpu = vendor == gpuhunt.AcceleratorVendor.AMD
395+
has_tt_gpu = vendor == gpuhunt.AcceleratorVendor.TENSTORRENT
390396
if has_amd_gpu and conf.image is None:
391-
raise ConfigurationError("`image` is required if `resources.gpu.vendor` is AMD.")
397+
raise ConfigurationError("`image` is required if `resources.gpu.vendor` is `amd`")
398+
if has_tt_gpu and conf.image is None:
399+
raise ConfigurationError(
400+
"`image` is required if `resources.gpu.vendor` is `tenstorrent`"
401+
)
392402

393403

394404
class RunWithPortsConfigurator(BaseRunConfigurator):

src/dstack/_internal/core/models/resources.py

+2
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ def _vendor_from_string(cls, v: str) -> gpuhunt.AcceleratorVendor:
246246
v = v.lower()
247247
if v == "tpu":
248248
return gpuhunt.AcceleratorVendor.GOOGLE
249+
if v == "tt":
250+
return gpuhunt.AcceleratorVendor.TENSTORRENT
249251
return gpuhunt.AcceleratorVendor.cast(v)
250252

251253

0 commit comments

Comments
 (0)