Skip to content
2 changes: 1 addition & 1 deletion gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
]

[project.optional-dependencies]
sglang = ["sglang-router==0.2.1"]
sglang = ["sglang-router==0.3.2"]

[tool.setuptools.package-data]
"dstack.gateway" = [
Expand Down
10 changes: 5 additions & 5 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import RouterType
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Volume,
Expand Down Expand Up @@ -924,7 +924,7 @@ def get_run_shim_script(
]


def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
def get_gateway_user_data(authorized_key: str, router: Optional[RouterType] = None) -> str:
return get_cloud_config(
package_update=True,
packages=[
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def get_latest_runner_build() -> Optional[str]:
return None


def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str:
def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> str:
channel = "release" if settings.DSTACK_RELEASE else "stgn"
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
if build == "latest":
Expand All @@ -1045,11 +1045,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non
wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
# Build package spec with extras if router is specified
if router:
return f"dstack-gateway[{router.type}] @ {wheel}"
return f"dstack-gateway[{router.value}] @ {wheel}"
return f"dstack-gateway @ {wheel}"


def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
def get_dstack_gateway_commands(router: Optional[RouterType] = None) -> List[str]:
build = get_dstack_runner_version() or "latest"
gateway_package = get_dstack_gateway_wheel(build, router)
return [
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)
from dstack._internal.core.models.placement import PlacementGroup
from dstack._internal.core.models.resources import CPUSpec, GPUSpec
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import RouterType
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.utils.common import get_or_error
Expand Down Expand Up @@ -864,7 +864,7 @@ def _wait_for_load_balancer_address(


def _get_gateway_commands(
authorized_keys: List[str], router: Optional[AnyRouterConfig] = None
authorized_keys: List[str], router: Optional[RouterType] = None
) -> List[str]:
authorized_keys_content = "\n".join(authorized_keys).strip()
gateway_commands = " && ".join(get_dstack_gateway_commands(router=router))
Expand Down
6 changes: 2 additions & 4 deletions src/dstack/_internal/core/compatibility/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def _get_gateway_configuration_excludes(
) -> IncludeExcludeDictType:
configuration_excludes: IncludeExcludeDictType = {}

# Add excludes like this:
#
# if configuration.tags is None:
# configuration_excludes["tags"] = True
if configuration.router is None:
configuration_excludes["router"] = True

return configuration_excludes
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
# Servers prior to 0.20.8 do not support probes=None
configuration_excludes["probes"] = True

router = run_spec.configuration.router
if router is None:
configuration_excludes["router"] = True
elif router.pd_disaggregation is False:
configuration_excludes["router"] = {"pd_disaggregation": True}

if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
if profile_excludes:
Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
parse_off_duration,
)
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
Expand Down Expand Up @@ -888,6 +889,14 @@ class ServiceConfigurationParams(CoreModel):
)
),
] = None
router: Annotated[
Optional[AnyRouterConfig],
Field(
description=(
"Router configuration for the service. Requires a gateway with matching router enabled. "
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit)

Suggested change
"Router configuration for the service. Requires a gateway with matching router enabled. "
"Router configuration for the service. Requires a gateway with matching router enabled"

),
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
Expand Down
8 changes: 4 additions & 4 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import RouterType
from dstack._internal.utils.tags import tags_validator


Expand Down Expand Up @@ -63,8 +63,8 @@ class GatewayConfiguration(CoreModel):
),
] = None
router: Annotated[
Optional[AnyRouterConfig],
Field(description="The router configuration"),
Optional[RouterType],
Comment thread
jvstme marked this conversation as resolved.
Outdated
Field(description="The router type enabled on this gateway. E.g. 'sglang'."),
] = None
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
Expand Down Expand Up @@ -134,7 +134,7 @@ class GatewayComputeConfiguration(CoreModel):
ssh_key_pub: str
certificate: Optional[AnyGatewayCertificate] = None
tags: Optional[Dict[str, str]] = None
router: Optional[AnyRouterConfig] = None
router: Optional[RouterType] = None


class GatewayProvisioningData(CoreModel):
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class SGLangRouterConfig(CoreModel):
description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
),
] = "cache_aware"
pd_disaggregation: Annotated[
Comment thread
jvstme marked this conversation as resolved.
Comment thread
jvstme marked this conversation as resolved.
bool,
Field(description="Enable PD disaggregation mode for the SGLang router"),
] = False


AnyRouterConfig = SGLangRouterConfig
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def register_replica(
ssh_proxy=body.ssh_proxy,
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
internal_ip=body.internal_ip,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel):
ssh_proxy: Optional[SSHConnectionParams]
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]
internal_ip: Optional[str] = None


class RegisterEntrypointRequest(BaseModel):
Expand Down
76 changes: 66 additions & 10 deletions src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import subprocess
import sys
import time
import urllib.parse
from typing import List, Optional

import httpx
Expand Down Expand Up @@ -68,6 +67,8 @@ def start(self) -> None:
"--policy",
self.config.policy,
]
if self.config.pd_disaggregation:
cmd.append("--pd-disaggregation")

subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

Expand Down Expand Up @@ -174,7 +175,7 @@ def update_replicas(self, replica_urls: List[str]) -> None:

# Add workers
for worker_url in sorted(workers_to_add):
success = self._add_worker_to_router(worker_url)
success = self._register_worker(worker_url)
if not success:
logger.warning("Failed to add worker %s, continuing with others", worker_url)

Expand All @@ -197,9 +198,16 @@ def _get_router_workers(self) -> List[dict]:
logger.exception("Error getting sglang router workers")
return []

def _add_worker_to_router(self, worker_url: str) -> bool:
def _add_worker_to_router(
self,
url: str,
worker_type: str = "regular",
bootstrap_port: Optional[int] = None,
) -> bool:
try:
payload = {"url": worker_url, "worker_type": "regular"}
payload: dict = {"url": url, "worker_type": worker_type}
if bootstrap_port is not None:
payload["bootstrap_port"] = bootstrap_port
with httpx.Client(timeout=5.0) as client:
response = client.post(
f"http://{self.context.host}:{self.context.port}/workers",
Expand All @@ -209,8 +217,9 @@ def _add_worker_to_router(self, worker_url: str) -> bool:
response_data = response.json()
if response_data.get("status") == "accepted":
logger.info(
"Worker %s accepted by sglang router on port %s",
worker_url,
"Worker %s (type=%s) accepted by sglang router on port %s",
url,
worker_type,
self.context.port,
)
return True
Expand All @@ -224,21 +233,68 @@ def _add_worker_to_router(self, worker_url: str) -> bool:
else:
logger.error(
"Failed to add worker %s: status %d, %s",
worker_url,
url,
response.status_code,
response.text,
)
return False
except Exception:
logger.exception("Error adding worker %s", worker_url)
logger.exception("Error adding worker %s", url)
return False

def _register_worker(self, url: str) -> bool:
    """Register a replica with the SGLang router, detecting its PD role if needed.

    When PD disaggregation is disabled, the replica is added as a ``regular``
    worker without being probed. When it is enabled, the replica's
    ``/server_info`` endpoint is queried and the worker type is derived from
    the reported ``disaggregation_mode``:

    * ``"prefill"`` -> ``prefill`` worker, registered together with the
      reported ``disaggregation_bootstrap_port`` (may be ``None`` if absent);
    * ``"decode"`` -> ``decode`` worker, no bootstrap port;
    * anything else (including a missing key) -> ``regular`` worker.

    Args:
        url: Base URL of the replica; ``/server_info`` is appended to it.

    Returns:
        The result of ``_add_worker_to_router`` for the detected worker type,
        or ``False`` if the replica did not answer with HTTP 200, does not
        report ``status == "ready"``, or any exception was raised.
    """
    if not self.config.pd_disaggregation:
        return self._add_worker_to_router(url, "regular", None)

    server_info_url = f"{url}/server_info"
    try:
        # The response body is fully read by `get`, so the client can be
        # closed before the payload is inspected and the router is called.
        with httpx.Client(timeout=10) as client:
            resp = client.get(server_info_url)

        if resp.status_code != 200:
            # Log the reason; otherwise the caller only sees a bare `False`.
            logger.warning(
                "Cannot register worker %s: %s returned status %d",
                url,
                server_info_url,
                resp.status_code,
            )
            return False

        data = resp.json()
        status = data.get("status")
        if status != "ready":
            logger.info(
                "Worker %s is not ready yet (status=%s), skipping registration",
                url,
                status,
            )
            return False

        disaggregation_mode = data.get("disaggregation_mode", "")
        # Only the two PD roles are passed through; any other mode is treated
        # as a regular worker.
        worker_type = (
            disaggregation_mode if disaggregation_mode in ("prefill", "decode") else "regular"
        )
        # Only prefill workers are registered with a bootstrap port.
        bootstrap_port = (
            data.get("disaggregation_bootstrap_port") if worker_type == "prefill" else None
        )
        logger.info(
            "Registering worker %s (type=%s)",
            url,
            worker_type,
        )
        return self._add_worker_to_router(
            url,
            worker_type,
            bootstrap_port,
        )
    except Exception:
        # Best-effort: a failing replica must not break registration of others.
        logger.exception("Error registering worker %s", url)
        return False

def _remove_worker_from_router(self, worker_url: str) -> bool:
try:
encoded_url = urllib.parse.quote(worker_url, safe="")
current_workers = self._get_router_workers()
worker_id = None
for worker in current_workers:
url = worker.get("url")
if url and isinstance(url, str) and url == worker_url:
worker_id = worker.get("id")
if worker_id and isinstance(worker_id, str):
break
if not worker_id:
logger.error("No worker id found for url %s", worker_url)
return False
with httpx.Client(timeout=5.0) as client:
response = client.delete(
f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}"
f"http://{self.context.host}:{self.context.port}/workers/{worker_id}"
Comment thread
jvstme marked this conversation as resolved.
)
if response.status_code == 202:
response_data = response.json()
Expand Down
Loading