Draft. Changes from all commits (19 commits):
9bbce87  fix: warn when engine_args in OpenAIServerConfig are ignored (JRMeyer, Nov 29, 2025)
098a42d  fix: preserve _internal_config during Pydantic deserialization (JRMeyer, Nov 29, 2025)
e456b20  fix: strip logprobs from trajectories before sending to RULER judge (JRMeyer, Dec 1, 2025)
04857dc  feat: extract logprobs from dict messages for GRPO importance sampling (JRMeyer, Dec 1, 2025)
3d79de3  style: fix import sorting in ruler.py (JRMeyer, Dec 1, 2025)
61a1526  feat: add importance sampling observability metrics (JRMeyer, Dec 2, 2025)
0a64c7c  chore: remove design doc from PR (JRMeyer, Dec 2, 2025)
01edc64  fix: handle dict message logprobs with explicit token_ids (JRMeyer, Dec 5, 2025)
fc2569b  fix: Handle Qwen3 chat template continue_final_message incompatibility (JRMeyer, Dec 6, 2025)
127e1a8  Add debug logging to tokenize_trajectory for logprobs investigation (JRMeyer, Dec 7, 2025)
32d733e  fix: add missing server deps and fix sky.exec task setup (JRMeyer, Dec 9, 2025)
afadf6e  debug: add reward logging to _train_model (JRMeyer, Dec 9, 2025)
c208c91  fix: detect logprobs in dict messages for tokenization (JRMeyer, Dec 9, 2025)
d784654  debug: add extensive tokenization logging for GRPO troubleshooting (JRMeyer, Dec 10, 2025)
278062b  gitignore: exclude .art-backup/ checkpoint directory (JRMeyer, Dec 10, 2025)
669d6be  fix: preserve logprobs/token_ids in Pydantic models (JRMeyer, Dec 10, 2025)
2e35a9c  fix: don't create new History in tokenize_trajectory_groups (JRMeyer, Dec 10, 2025)
0273317  debug: trace logprobs through tokenize pipeline with object IDs (JRMeyer, Dec 10, 2025)
8145580  fix: copy dict in trajectory_logging to avoid mutating original (pres… (JRMeyer, Dec 11, 2025)
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
__pycache__/
.art/
.art-backup/
.env
.venv/
grpo_trainer_lora_model/
2 changes: 2 additions & 0 deletions .skyignore
@@ -1,5 +1,7 @@
__pycache__/
.art/
.art-backup/
*.safetensors
# .env
.venv/
grpo_trainer_lora_model/
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -9,6 +9,8 @@ dependencies = [
"typer>=0.15.2",
"litellm==1.74.1",
"weave>=0.51.51",
"uvicorn[standard]",
"fastapi",
]

[project.optional-dependencies]
18 changes: 18 additions & 0 deletions src/art/dev/openai_server.py
@@ -1,9 +1,17 @@
import warnings
from typing import Literal

from typing_extensions import TypedDict

from .engine import EngineArgs

ENGINE_INIT_ONLY_ARGS = {
"max_logprobs",
"gpu_memory_utilization",
"tensor_parallel_size",
"max_model_len",
}


def get_openai_server_config(
model_name: str,
@@ -35,6 +43,16 @@ def get_openai_server_config(
generation_config="vllm",
)
engine_args.update(config.get("engine_args", {}))
user_engine_args = config.get("engine_args", {})
ignored_args = set(user_engine_args.keys()) & ENGINE_INIT_ONLY_ARGS
if ignored_args:
warnings.warn(
f"OpenAIServerConfig.engine_args contains {ignored_args} which will be "
f"ignored. The vLLM engine is initialized by Unsloth before this config "
f"is applied. Use TrainableModel._internal_config.engine_args instead.",
UserWarning,
stacklevel=2,
)
return OpenAIServerConfig(
log_file=log_file, server_args=server_args, engine_args=engine_args
)
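For reference, a minimal standalone sketch of how the new warning behaves. check_engine_args below is a hypothetical helper that mirrors the check added to get_openai_server_config; it is not part of the library API, and the engine_args values are illustrative.

import warnings

# Mirrors the set added in src/art/dev/openai_server.py: these engine args only take
# effect at vLLM engine initialization, which Unsloth performs before this config is
# applied, so passing them here has no effect.
ENGINE_INIT_ONLY_ARGS = {
    "max_logprobs",
    "gpu_memory_utilization",
    "tensor_parallel_size",
    "max_model_len",
}

def check_engine_args(user_engine_args: dict) -> None:
    # Standalone version of the new check: warn about init-only keys.
    ignored = set(user_engine_args) & ENGINE_INIT_ONLY_ARGS
    if ignored:
        warnings.warn(
            f"engine_args contains {ignored}, which will be ignored; "
            "use TrainableModel._internal_config.engine_args instead.",
            UserWarning,
            stacklevel=2,
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_engine_args({"gpu_memory_utilization": 0.9, "enable_prefix_caching": True})
    assert any("ignored" in str(w.message) for w in caught)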
29 changes: 29 additions & 0 deletions src/art/local/backend.py
@@ -194,6 +194,15 @@ def _get_packed_tensors(
except Exception:
self._image_processors[model.base_model] = None
tokenizer = self._tokenizers[model.base_model]
print("[DEBUG _get_packed_tensors] Right before tokenize_trajectory_groups")
for tg_idx, tg in enumerate(trajectory_groups):
for traj_idx, traj in enumerate(tg.trajectories):
print(f"[DEBUG _get_packed_tensors] tg={tg_idx} traj={traj_idx} id={id(traj)}")
print(f"[DEBUG _get_packed_tensors] messages_and_choices id={id(traj.messages_and_choices)}")
for msg_idx, msg in enumerate(traj.messages_and_choices):
if isinstance(msg, dict) and msg.get("role") == "assistant":
print(f"[DEBUG _get_packed_tensors] msg {msg_idx} id={id(msg)} keys={list(msg.keys())}")
print(f"[DEBUG _get_packed_tensors] msg {msg_idx} has logprobs: {'logprobs' in msg and bool(msg.get('logprobs'))}")
tokenized_results = list(
tokenize_trajectory_groups(
tokenizer,
@@ -412,6 +421,26 @@ async def _train_model(
dev_config: dev.TrainConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
print("[DEBUG _train_model] Received trajectory_groups")
for tg_idx, tg in enumerate(trajectory_groups):
rewards = [t.reward for t in tg.trajectories]
print(f"[DEBUG _train_model] tg={tg_idx} rewards={rewards}")
for traj_idx, traj in enumerate(tg.trajectories):
for msg_idx, msg in enumerate(traj.messages_and_choices):
if isinstance(msg, dict) and msg.get("role") == "assistant":
print(f"[DEBUG _train_model] tg={tg_idx} traj={traj_idx} msg={msg_idx}")
print(f"[DEBUG _train_model] Assistant msg keys: {list(msg.keys())}")
print(f"[DEBUG _train_model] has logprobs: {'logprobs' in msg}")
if 'logprobs' in msg:
lp = msg['logprobs']
print(f"[DEBUG _train_model] logprobs type: {type(lp)}, truthy: {bool(lp)}")
if isinstance(lp, dict):
print(f"[DEBUG _train_model] logprobs keys: {list(lp.keys())}")
if 'values' in lp:
print(f"[DEBUG _train_model] logprobs['values'] len: {len(lp['values'])}")
print(f"[DEBUG _train_model] token_ids present: {'token_ids' in msg and msg.get('token_ids') is not None}")
if 'token_ids' in msg and msg.get('token_ids') is not None:
print(f"[DEBUG _train_model] token_ids len: {len(msg['token_ids'])}")
if verbose:
print("Starting _train_model")
service = await self._get_service(model)
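These debug prints check whether sampled logprobs and token ids survive all the way into _train_model. A hedged sketch of the assistant-message dict shape they probe for, with invented content and numbers; the logprobs layout with a "values" list follows the tokenize.py changes later in this PR.

# Illustrative assistant message; the role/content/token_ids/logprobs keys match what
# the debug block inspects, but every value here is made up for the example.
example_assistant_msg = {
    "role": "assistant",
    "content": "The answer is 4.",
    "token_ids": [791, 4320, 374, 220, 19, 13],  # hypothetical token ids
    "logprobs": {"values": [-0.01, -0.20, -0.05, -0.33, -0.02, -0.11]},
}

# Roughly the conditions the new prints in _train_model evaluate:
has_logprobs = "logprobs" in example_assistant_msg and bool(example_assistant_msg["logprobs"])
has_token_ids = example_assistant_msg.get("token_ids") is not None
print(f"has logprobs: {has_logprobs}, has token_ids: {has_token_ids}, "
      f"token_ids len: {len(example_assistant_msg['token_ids'])}")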
21 changes: 19 additions & 2 deletions src/art/loss.py
Expand Up @@ -16,6 +16,9 @@ class Loss(BaseModel):
mean_kl: torch.Tensor
mean_entropy: torch.Tensor | None
probs_corr: torch.Tensor
frac_old_logprobs_valid: float
mean_importance_ratio: torch.Tensor
clip_fraction: torch.Tensor


def loss_fn(
@@ -32,6 +35,9 @@ def loss_fn(
)
weights = shift_tensor(inputs["weights"], 0.0)
old_logprobs_mask = ~torch.isnan(old_logprobs)
frac_old_logprobs_valid = (
old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
probs_corr = torch.corrcoef(
torch.stack(
[
@@ -77,15 +83,23 @@
)
if tau := experimental_config.get("kimi_k2_tau", None):
advantages -= tau * logprob_diff.detach()
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
if experimental_config.get("ppo", True):
policy_loss = -torch.min(
prob_ratio * advantages,
torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
clipped_ratio * advantages,
)
else:
# Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
policy_loss = -(
torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
clipped_ratio.detach()
* advantages
* new_logprobs
)
@@ -123,6 +137,9 @@ def loss_fn(
mean_kl=mean_kl,
mean_entropy=mean_entropy,
probs_corr=probs_corr,
frac_old_logprobs_valid=frac_old_logprobs_valid,
mean_importance_ratio=mean_importance_ratio,
clip_fraction=clip_fraction,
)


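A self-contained sketch of the three new observability metrics, using the same formulas as the diff. The tensor values, epsilon bounds, and the simplified prob_ratio construction are assumptions for illustration; in loss_fn the ratio is built from the shifted new and old logprobs.

import torch

epsilon, epsilon_high = 0.2, 0.2
old_logprobs = torch.tensor([float("nan"), -0.5, -0.1, -2.0])  # NaN marks a missing sampled logprob
new_logprobs = torch.tensor([-0.3, -0.4, -0.2, -1.5])
assistant_mask = torch.tensor([0.0, 1.0, 1.0, 1.0])            # only assistant tokens count

# Fraction of positions where a sampled (old) logprob is actually available.
old_mask = ~torch.isnan(old_logprobs)
frac_old_logprobs_valid = (old_mask.float().sum() / (old_logprobs.numel() + 1e-6)).item()

# Simplified importance ratio; missing old logprobs treated as 0 here for the demo.
prob_ratio = torch.exp(new_logprobs - torch.nan_to_num(old_logprobs, nan=0.0))
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)

print(frac_old_logprobs_valid, mean_importance_ratio.item(), clip_fraction.item())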
26 changes: 24 additions & 2 deletions src/art/model.py
@@ -1,8 +1,17 @@
from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Optional,
TypeVar,
cast,
overload,
)

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from typing_extensions import Never

from . import dev
@@ -279,6 +288,19 @@ def __init__(
# Bypass BaseModel __setattr__ to allow setting private attr
object.__setattr__(self, "_internal_config", _internal_config)

@model_validator(mode="wrap")
@classmethod
def _preserve_internal_config(
cls, data: Any, handler: Any
) -> "TrainableModel[ModelConfig]":
internal_config = None
if isinstance(data, dict) and "_internal_config" in data:
internal_config = data.pop("_internal_config")
model = handler(data)
if internal_config is not None:
object.__setattr__(model, "_internal_config", internal_config)
return model

@overload
def __new__(
cls,
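A minimal sketch of why the wrap-mode validator is needed: without it, model_validate would drop the private _internal_config. ToyTrainable is a stand-in model, not the real TrainableModel, but the validator body matches the one added in this diff.

from typing import Any, Optional
from pydantic import BaseModel, PrivateAttr, model_validator

class ToyTrainable(BaseModel):
    # Stand-in for TrainableModel, just enough to show the round trip.
    name: str
    _internal_config: Optional[dict] = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler: Any) -> "ToyTrainable":
        internal_config = None
        if isinstance(data, dict) and "_internal_config" in data:
            internal_config = data.pop("_internal_config")
        model = handler(data)
        if internal_config is not None:
            object.__setattr__(model, "_internal_config", internal_config)
        return model

m = ToyTrainable.model_validate(
    {"name": "agent", "_internal_config": {"engine_args": {"max_model_len": 8192}}}
)
assert m._internal_config == {"engine_args": {"max_model_len": 8192}}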
140 changes: 106 additions & 34 deletions src/art/preprocessing/tokenize.py
@@ -52,32 +52,37 @@ def tokenize_trajectory_groups(
shuffle_group_trajectories: bool = True,
image_processor: BaseImageProcessor | None = None,
) -> Generator["TokenizedResult", None, None]:
for group in trajectory_groups:
print(f"[TOKENIZE_GROUPS] Starting with {len(trajectory_groups)} groups")
for group_idx, group in enumerate(trajectory_groups):
if not group:
continue
print(f"[TOKENIZE_GROUPS] Group {group_idx}: {len(group)} trajectories")
results: list[TokenizedResult] = []
# Calculate GRPO group mean and standard deviation
reward_mean = sum(trajectory.reward for trajectory in group) / len(group)
reward_std = math.sqrt(
sum((trajectory.reward - reward_mean) ** 2 for trajectory in group)
/ len(group)
)
for trajectory in group:
print(f"[TOKENIZE_GROUPS] Group {group_idx}: rewards={[t.reward for t in group]}, mean={reward_mean}, std={reward_std}")
for traj_idx, trajectory in enumerate(group):
# Calculate GRPO advantage for this trajectory
advantage = trajectory.reward - reward_mean
if scale_rewards:
advantage /= reward_std + 1e-6
print(f"[TOKENIZE_GROUPS] Group {group_idx} Traj {traj_idx}: raw_adv={trajectory.reward - reward_mean}, scaled_adv={advantage}")
# Skip trajectories with no advantage
if advantage == 0:
print(f"[TOKENIZE_GROUPS] Group {group_idx} Traj {traj_idx}: SKIPPED (advantage=0)")
continue
trajectory_results: list[TokenizedResult] = []
for history in [
History(
messages_and_choices=trajectory.messages_and_choices,
tools=trajectory.tools,
),
*trajectory.additional_histories,
]:
print(f"[TOKENIZE_GROUPS] About to iterate. trajectory type: {type(trajectory).__name__}")
print(f"[TOKENIZE_GROUPS] trajectory.messages_and_choices has {len(trajectory.messages_and_choices)} items")
for msg_idx, msg in enumerate(trajectory.messages_and_choices):
if isinstance(msg, dict) and msg.get("role") == "assistant":
print(f"[TOKENIZE_GROUPS] traj msg {msg_idx} keys: {list(msg.keys())}")
print(f"[TOKENIZE_GROUPS] traj msg {msg_idx} has logprobs: {'logprobs' in msg and bool(msg.get('logprobs'))}")
for history in [trajectory, *trajectory.additional_histories]:
if result := tokenize_trajectory(
tokenizer,
image_processor,
@@ -136,21 +141,39 @@ def tokenize_trajectory(
"""
Tokenizes a trajectory and returns a TokenizedResult.
"""
print(f"[TOKENIZE_TRAJ] Received history type: {type(history).__name__}")
print(f"[TOKENIZE_TRAJ] history id: {id(history)}")
print(f"[TOKENIZE_TRAJ] messages_and_choices id: {id(history.messages_and_choices)}")
if len(history.messages_and_choices) > 0:
for msg_idx, msg in enumerate(history.messages_and_choices):
if isinstance(msg, dict) and msg.get("role") == "assistant":
print(f"[TOKENIZE_TRAJ] msg {msg_idx} id: {id(msg)}")
print(f"[TOKENIZE_TRAJ] msg {msg_idx} keys: {list(msg.keys())}")
print(f"[TOKENIZE_TRAJ] msg {msg_idx} has logprobs: {'logprobs' in msg}")
if 'logprobs' in msg:
print(f"[TOKENIZE_TRAJ] msg {msg_idx} logprobs truthy: {bool(msg['logprobs'])}")
# Find the index of the last assistant message
last_assistant_index = -1
print(f"[TOKENIZE FIRST LOOP] Checking {len(history.messages_and_choices)} messages")
for i, message in enumerate(history.messages_and_choices):
if isinstance(message, dict):
print(f"[TOKENIZE FIRST LOOP] msg {i}: dict, role={message.get('role')}, has_logprobs={bool(message.get('logprobs'))}")
else:
print(f"[TOKENIZE FIRST LOOP] msg {i}: Choice obj, has_logprobs={bool(message.logprobs if hasattr(message, 'logprobs') else None)}")
if (
isinstance(message, dict)
and message["role"] == "assistant"
and allow_training_without_logprobs
and (message.get("logprobs") or allow_training_without_logprobs)
):
last_assistant_index = i
elif not isinstance(message, dict) and (
message.logprobs or allow_training_without_logprobs
):
last_assistant_index = i
print(f"[TOKENIZE FIRST LOOP] last_assistant_index={last_assistant_index}")
# If there are no trainable assistant messages, return None
if last_assistant_index == -1:
print("[TOKENIZE FIRST LOOP] -> Returning None (no trainable messages)")
return None
messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
messages = get_messages(messages_and_choices)
@@ -159,23 +182,28 @@
if history.tools is not None
else None
)
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
try:
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
except ValueError as e:
if "continue_final_message" in str(e):
return None
raise
sentinal_token_id = max(
set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
)
@@ -216,13 +244,57 @@
if isinstance(message, dict):
content = message.get("content")
assert isinstance(content, str)
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
msg_token_ids = message.get("token_ids")
dict_logprobs = message.get("logprobs")
print(f"[TOKENIZE DEBUG] Processing assistant dict message:")
print(f" message keys: {list(message.keys())}")
print(f" msg_token_ids is not None: {msg_token_ids is not None}")
print(f" dict_logprobs truthy: {bool(dict_logprobs)}")
print(f" dict_logprobs value: {repr(dict_logprobs)[:200] if dict_logprobs else repr(dict_logprobs)}")
if dict_logprobs:
print(f" dict_logprobs type: {type(dict_logprobs).__name__}")
print(f" dict_logprobs keys: {list(dict_logprobs.keys()) if isinstance(dict_logprobs, dict) else 'N/A'}")
print(f" 'values' in dict_logprobs: {'values' in dict_logprobs if isinstance(dict_logprobs, dict) else 'N/A'}")
if (
msg_token_ids is not None
and dict_logprobs
and "values" in dict_logprobs
):
print(f" -> Using provided token_ids ({len(msg_token_ids)}) and logprobs.values ({len(dict_logprobs['values'])})")
token_ids[start:end] = msg_token_ids
logprobs[start:end] = dict_logprobs["values"]
assistant_mask[start:end] = [1] * len(msg_token_ids)
elif (
dict_logprobs
and "content" in dict_logprobs
and dict_logprobs["content"]
):
token_logprobs = dict_logprobs["content"]
try:
token_ids[start:end] = [
int(lp["token"].split(":")[1]) for lp in token_logprobs
]
except (IndexError, ValueError, KeyError):
token_ids[start:end] = [
token_id if token_id is not None else tokenizer.eos_token_id
for token_id in tokenizer.convert_tokens_to_ids(
[
lp.get("token") or tokenizer.eos_token
for lp in token_logprobs
]
)
]
logprobs[start:end] = [lp["logprob"] for lp in token_logprobs]
assistant_mask[start:end] = [1] * len(token_logprobs)
else:
print(f" -> FALLBACK: re-tokenizing content, logprobs will be NaN")
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
else:
choice = message
assert choice.logprobs or allow_training_without_logprobs, (
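The rewritten dict-message branch accepts two logprob layouts. A hedged sketch of both follows, with invented tokens and numbers: the "token_id:<n>" string form is an assumption about how tokens may arrive from the sampling server, while the split(":") parsing and the tokenizer fallback order come from the diff.

# Shape A: explicit token_ids plus logprobs["values"] (the path added for GRPO
# importance sampling). Shape B: OpenAI-style logprobs["content"] entries.
shape_a = {
    "role": "assistant",
    "content": "Paris",
    "token_ids": [59604],
    "logprobs": {"values": [-0.02]},
}
shape_b = {
    "role": "assistant",
    "content": "Paris",
    "logprobs": {"content": [{"token": "token_id:59604", "logprob": -0.02}]},
}

# Invariant the shape A branch relies on: one logprob per provided token id.
assert len(shape_a["token_ids"]) == len(shape_a["logprobs"]["values"])

def extract_token_ids(token_logprobs: list) -> list:
    # Mirrors the parsing order in the diff: first try to split an integer id out of
    # the token string; the real code falls back to tokenizer.convert_tokens_to_ids.
    try:
        return [int(lp["token"].split(":")[1]) for lp in token_logprobs]
    except (IndexError, ValueError, KeyError):
        raise NotImplementedError("fallback requires a tokenizer instance")

print(extract_token_ids(shape_b["logprobs"]["content"]))  # -> [59604]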