diff --git a/.gitignore b/.gitignore
index 748d300f..58cd3413 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 __pycache__/
 .art/
+.art-backup/
 .env
 .venv/
 grpo_trainer_lora_model/
diff --git a/.skyignore b/.skyignore
index d7afeb0b..acbf0fae 100644
--- a/.skyignore
+++ b/.skyignore
@@ -1,5 +1,7 @@
 __pycache__/
 .art/
+.art-backup/
+*.safetensors
 # .env
 .venv/
 grpo_trainer_lora_model/
diff --git a/pyproject.toml b/pyproject.toml
index 077cb4ca..4e16d4e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,8 @@ dependencies = [
     "typer>=0.15.2",
     "litellm==1.74.1",
     "weave>=0.51.51",
+    "uvicorn[standard]",
+    "fastapi",
 ]
 
 [project.optional-dependencies]
diff --git a/src/art/dev/openai_server.py b/src/art/dev/openai_server.py
index f6639bda..1da49661 100644
--- a/src/art/dev/openai_server.py
+++ b/src/art/dev/openai_server.py
@@ -1,9 +1,17 @@
+import warnings
 from typing import Literal
 
 from typing_extensions import TypedDict
 
 from .engine import EngineArgs
 
+ENGINE_INIT_ONLY_ARGS = {
+    "max_logprobs",
+    "gpu_memory_utilization",
+    "tensor_parallel_size",
+    "max_model_len",
+}
+
 
 def get_openai_server_config(
     model_name: str,
@@ -35,6 +43,16 @@ def get_openai_server_config(
         generation_config="vllm",
     )
     engine_args.update(config.get("engine_args", {}))
+    user_engine_args = config.get("engine_args", {})
+    ignored_args = set(user_engine_args.keys()) & ENGINE_INIT_ONLY_ARGS
+    if ignored_args:
+        warnings.warn(
+            f"OpenAIServerConfig.engine_args contains {ignored_args} which will be "
+            f"ignored. The vLLM engine is initialized by Unsloth before this config "
+            f"is applied. Use TrainableModel._internal_config.engine_args instead.",
+            UserWarning,
+            stacklevel=2,
+        )
     return OpenAIServerConfig(
         log_file=log_file, server_args=server_args, engine_args=engine_args
     )
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 5a0bc619..bb5f2128 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -194,6 +194,15 @@ def _get_packed_tensors(
             except Exception:
                 self._image_processors[model.base_model] = None
         tokenizer = self._tokenizers[model.base_model]
+        print("[DEBUG _get_packed_tensors] Right before tokenize_trajectory_groups")
+        for tg_idx, tg in enumerate(trajectory_groups):
+            for traj_idx, traj in enumerate(tg.trajectories):
+                print(f"[DEBUG _get_packed_tensors] tg={tg_idx} traj={traj_idx} id={id(traj)}")
+                print(f"[DEBUG _get_packed_tensors] messages_and_choices id={id(traj.messages_and_choices)}")
+                for msg_idx, msg in enumerate(traj.messages_and_choices):
+                    if isinstance(msg, dict) and msg.get("role") == "assistant":
+                        print(f"[DEBUG _get_packed_tensors] msg {msg_idx} id={id(msg)} keys={list(msg.keys())}")
+                        print(f"[DEBUG _get_packed_tensors] msg {msg_idx} has logprobs: {'logprobs' in msg and bool(msg.get('logprobs'))}")
         tokenized_results = list(
             tokenize_trajectory_groups(
                 tokenizer,
@@ -412,6 +421,26 @@ async def _train_model(
         dev_config: dev.TrainConfig,
         verbose: bool = False,
     ) -> AsyncIterator[dict[str, float]]:
+        print("[DEBUG _train_model] Received trajectory_groups")
+        for tg_idx, tg in enumerate(trajectory_groups):
+            rewards = [t.reward for t in tg.trajectories]
+            print(f"[DEBUG _train_model] tg={tg_idx} rewards={rewards}")
+            for traj_idx, traj in enumerate(tg.trajectories):
+                for msg_idx, msg in enumerate(traj.messages_and_choices):
+                    if isinstance(msg, dict) and msg.get("role") == "assistant":
+                        print(f"[DEBUG _train_model] tg={tg_idx} traj={traj_idx} msg={msg_idx}")
+                        print(f"[DEBUG _train_model] Assistant msg keys: {list(msg.keys())}")
print(f"[DEBUG _train_model] has logprobs: {'logprobs' in msg}") + if 'logprobs' in msg: + lp = msg['logprobs'] + print(f"[DEBUG _train_model] logprobs type: {type(lp)}, truthy: {bool(lp)}") + if isinstance(lp, dict): + print(f"[DEBUG _train_model] logprobs keys: {list(lp.keys())}") + if 'values' in lp: + print(f"[DEBUG _train_model] logprobs['values'] len: {len(lp['values'])}") + print(f"[DEBUG _train_model] token_ids present: {'token_ids' in msg and msg.get('token_ids') is not None}") + if 'token_ids' in msg and msg.get('token_ids') is not None: + print(f"[DEBUG _train_model] token_ids len: {len(msg['token_ids'])}") if verbose: print("Starting _train_model") service = await self._get_service(model) diff --git a/src/art/loss.py b/src/art/loss.py index f209cc6e..246f0356 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -16,6 +16,9 @@ class Loss(BaseModel): mean_kl: torch.Tensor mean_entropy: torch.Tensor | None probs_corr: torch.Tensor + frac_old_logprobs_valid: float + mean_importance_ratio: torch.Tensor + clip_fraction: torch.Tensor def loss_fn( @@ -32,6 +35,9 @@ def loss_fn( ) weights = shift_tensor(inputs["weights"], 0.0) old_logprobs_mask = ~torch.isnan(old_logprobs) + frac_old_logprobs_valid = ( + old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6) + ).item() probs_corr = torch.corrcoef( torch.stack( [ @@ -77,15 +83,23 @@ def loss_fn( ) if tau := experimental_config.get("kimi_k2_tau", None): advantages -= tau * logprob_diff.detach() + clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) + is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high) + clip_fraction = (is_clipped.float() * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-6 + ) + mean_importance_ratio = (prob_ratio * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-6 + ) if experimental_config.get("ppo", True): policy_loss = -torch.min( prob_ratio * advantages, - torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages, + clipped_ratio * advantages, ) else: # Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO) policy_loss = -( - torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high) + clipped_ratio.detach() * advantages * new_logprobs ) @@ -123,6 +137,9 @@ def loss_fn( mean_kl=mean_kl, mean_entropy=mean_entropy, probs_corr=probs_corr, + frac_old_logprobs_valid=frac_old_logprobs_valid, + mean_importance_ratio=mean_importance_ratio, + clip_fraction=clip_fraction, ) diff --git a/src/art/model.py b/src/art/model.py index 43c519b2..040f157d 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -1,8 +1,17 @@ -from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Optional, + TypeVar, + cast, + overload, +) import httpx from openai import AsyncOpenAI, DefaultAsyncHttpxClient -from pydantic import BaseModel +from pydantic import BaseModel, model_validator from typing_extensions import Never from . 
@@ -279,6 +288,19 @@ def __init__(
         # Bypass BaseModel __setattr__ to allow setting private attr
         object.__setattr__(self, "_internal_config", _internal_config)
 
+    @model_validator(mode="wrap")
+    @classmethod
+    def _preserve_internal_config(
+        cls, data: Any, handler: Any
+    ) -> "TrainableModel[ModelConfig]":
+        internal_config = None
+        if isinstance(data, dict) and "_internal_config" in data:
+            internal_config = data.pop("_internal_config")
+        model = handler(data)
+        if internal_config is not None:
+            object.__setattr__(model, "_internal_config", internal_config)
+        return model
+
     @overload
     def __new__(
         cls,
diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 6ffc462f..52842868 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -52,9 +52,11 @@ def tokenize_trajectory_groups(
     shuffle_group_trajectories: bool = True,
     image_processor: BaseImageProcessor | None = None,
 ) -> Generator["TokenizedResult", None, None]:
-    for group in trajectory_groups:
+    print(f"[TOKENIZE_GROUPS] Starting with {len(trajectory_groups)} groups")
+    for group_idx, group in enumerate(trajectory_groups):
         if not group:
             continue
+        print(f"[TOKENIZE_GROUPS] Group {group_idx}: {len(group)} trajectories")
         results: list[TokenizedResult] = []
         # Calculate GRPO group mean and standard deviation
         reward_mean = sum(trajectory.reward for trajectory in group) / len(group)
@@ -62,22 +64,25 @@
             sum((trajectory.reward - reward_mean) ** 2 for trajectory in group)
             / len(group)
         )
-        for trajectory in group:
+        print(f"[TOKENIZE_GROUPS] Group {group_idx}: rewards={[t.reward for t in group]}, mean={reward_mean}, std={reward_std}")
+        for traj_idx, trajectory in enumerate(group):
             # Calculate GRPO advantage for this trajectory
             advantage = trajectory.reward - reward_mean
             if scale_rewards:
                 advantage /= reward_std + 1e-6
+            print(f"[TOKENIZE_GROUPS] Group {group_idx} Traj {traj_idx}: raw_adv={trajectory.reward - reward_mean}, scaled_adv={advantage}")
             # Skip trajectories with no advantage
             if advantage == 0:
+                print(f"[TOKENIZE_GROUPS] Group {group_idx} Traj {traj_idx}: SKIPPED (advantage=0)")
                 continue
             trajectory_results: list[TokenizedResult] = []
-            for history in [
-                History(
-                    messages_and_choices=trajectory.messages_and_choices,
-                    tools=trajectory.tools,
-                ),
-                *trajectory.additional_histories,
-            ]:
+            print(f"[TOKENIZE_GROUPS] About to iterate. trajectory type: {type(trajectory).__name__}")
+            print(f"[TOKENIZE_GROUPS] trajectory.messages_and_choices has {len(trajectory.messages_and_choices)} items")
+            for msg_idx, msg in enumerate(trajectory.messages_and_choices):
+                if isinstance(msg, dict) and msg.get("role") == "assistant":
+                    print(f"[TOKENIZE_GROUPS] traj msg {msg_idx} keys: {list(msg.keys())}")
+                    print(f"[TOKENIZE_GROUPS] traj msg {msg_idx} has logprobs: {'logprobs' in msg and bool(msg.get('logprobs'))}")
+            for history in [trajectory, *trajectory.additional_histories]:
                 if result := tokenize_trajectory(
                     tokenizer,
                     image_processor,
@@ -136,21 +141,39 @@
     """
     Tokenizes a trajectory and returns a TokenizedResult.
""" + print(f"[TOKENIZE_TRAJ] Received history type: {type(history).__name__}") + print(f"[TOKENIZE_TRAJ] history id: {id(history)}") + print(f"[TOKENIZE_TRAJ] messages_and_choices id: {id(history.messages_and_choices)}") + if len(history.messages_and_choices) > 0: + for msg_idx, msg in enumerate(history.messages_and_choices): + if isinstance(msg, dict) and msg.get("role") == "assistant": + print(f"[TOKENIZE_TRAJ] msg {msg_idx} id: {id(msg)}") + print(f"[TOKENIZE_TRAJ] msg {msg_idx} keys: {list(msg.keys())}") + print(f"[TOKENIZE_TRAJ] msg {msg_idx} has logprobs: {'logprobs' in msg}") + if 'logprobs' in msg: + print(f"[TOKENIZE_TRAJ] msg {msg_idx} logprobs truthy: {bool(msg['logprobs'])}") # Find the index of the last assistant message last_assistant_index = -1 + print(f"[TOKENIZE FIRST LOOP] Checking {len(history.messages_and_choices)} messages") for i, message in enumerate(history.messages_and_choices): + if isinstance(message, dict): + print(f"[TOKENIZE FIRST LOOP] msg {i}: dict, role={message.get('role')}, has_logprobs={bool(message.get('logprobs'))}") + else: + print(f"[TOKENIZE FIRST LOOP] msg {i}: Choice obj, has_logprobs={bool(message.logprobs if hasattr(message, 'logprobs') else None)}") if ( isinstance(message, dict) and message["role"] == "assistant" - and allow_training_without_logprobs + and (message.get("logprobs") or allow_training_without_logprobs) ): last_assistant_index = i elif not isinstance(message, dict) and ( message.logprobs or allow_training_without_logprobs ): last_assistant_index = i + print(f"[TOKENIZE FIRST LOOP] last_assistant_index={last_assistant_index}") # If there are no trainable assistant messages, return None if last_assistant_index == -1: + print("[TOKENIZE FIRST LOOP] -> Returning None (no trainable messages)") return None messages_and_choices = history.messages_and_choices[: last_assistant_index + 1] messages = get_messages(messages_and_choices) @@ -159,23 +182,28 @@ def tokenize_trajectory( if history.tools is not None else None ) - chat = cast( - str, - tokenizer.apply_chat_template( - cast(list[dict], messages), - tools=tools, # type: ignore - continue_final_message=True, - tokenize=False, - ), - ) - original_token_ids = cast( - list[int], - tokenizer.apply_chat_template( - cast(list[dict], messages), - tools=tools, # type: ignore - continue_final_message=True, - ), - ) + try: + chat = cast( + str, + tokenizer.apply_chat_template( + cast(list[dict], messages), + tools=tools, # type: ignore + continue_final_message=True, + tokenize=False, + ), + ) + original_token_ids = cast( + list[int], + tokenizer.apply_chat_template( + cast(list[dict], messages), + tools=tools, # type: ignore + continue_final_message=True, + ), + ) + except ValueError as e: + if "continue_final_message" in str(e): + return None + raise sentinal_token_id = max( set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids) ) @@ -216,13 +244,57 @@ def tokenize_trajectory( if isinstance(message, dict): content = message.get("content") assert isinstance(content, str) - content_token_ids = tokenizer.encode( - content, - add_special_tokens=False, - ) - token_ids[start:end] = content_token_ids - logprobs[start:end] = [float("nan")] * len(content_token_ids) - assistant_mask[start:end] = [1] * len(content_token_ids) + msg_token_ids = message.get("token_ids") + dict_logprobs = message.get("logprobs") + print(f"[TOKENIZE DEBUG] Processing assistant dict message:") + print(f" message keys: {list(message.keys())}") + print(f" msg_token_ids is not None: {msg_token_ids is not None}") + 
print(f" dict_logprobs truthy: {bool(dict_logprobs)}") + print(f" dict_logprobs value: {repr(dict_logprobs)[:200] if dict_logprobs else repr(dict_logprobs)}") + if dict_logprobs: + print(f" dict_logprobs type: {type(dict_logprobs).__name__}") + print(f" dict_logprobs keys: {list(dict_logprobs.keys()) if isinstance(dict_logprobs, dict) else 'N/A'}") + print(f" 'values' in dict_logprobs: {'values' in dict_logprobs if isinstance(dict_logprobs, dict) else 'N/A'}") + if ( + msg_token_ids is not None + and dict_logprobs + and "values" in dict_logprobs + ): + print(f" -> Using provided token_ids ({len(msg_token_ids)}) and logprobs.values ({len(dict_logprobs['values'])})") + token_ids[start:end] = msg_token_ids + logprobs[start:end] = dict_logprobs["values"] + assistant_mask[start:end] = [1] * len(msg_token_ids) + elif ( + dict_logprobs + and "content" in dict_logprobs + and dict_logprobs["content"] + ): + token_logprobs = dict_logprobs["content"] + try: + token_ids[start:end] = [ + int(lp["token"].split(":")[1]) for lp in token_logprobs + ] + except (IndexError, ValueError, KeyError): + token_ids[start:end] = [ + token_id if token_id is not None else tokenizer.eos_token_id + for token_id in tokenizer.convert_tokens_to_ids( + [ + lp.get("token") or tokenizer.eos_token + for lp in token_logprobs + ] + ) + ] + logprobs[start:end] = [lp["logprob"] for lp in token_logprobs] + assistant_mask[start:end] = [1] * len(token_logprobs) + else: + print(f" -> FALLBACK: re-tokenizing content, logprobs will be NaN") + content_token_ids = tokenizer.encode( + content, + add_special_tokens=False, + ) + token_ids[start:end] = content_token_ids + logprobs[start:end] = [float("nan")] * len(content_token_ids) + assistant_mask[start:end] = [1] * len(content_token_ids) else: choice = message assert choice.logprobs or allow_training_without_logprobs, ( diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py index 2ea33312..41054470 100644 --- a/src/art/rewards/ruler.py +++ b/src/art/rewards/ruler.py @@ -20,6 +20,7 @@ from rich import print import art +from art.utils.strip_logprobs import strip_logprobs class TrajectoryScore(BaseModel): @@ -287,9 +288,10 @@ async def ruler_score_group( new_trajectories.append(new_traj) # Extract message lists and preserve original rewards for comparison + # Strip logprobs to avoid sending huge token probability data to the judge message_lists: list[list[ChatCompletionMessageParam]] = [] for traj in new_trajectories: - message_lists.append(traj.messages()) + message_lists.append(strip_logprobs(traj.messages())) traj.metrics["independent_reward"] = traj.reward try: diff --git a/src/art/skypilot/backend.py b/src/art/skypilot/backend.py index 68bdfb78..667ad03a 100644 --- a/src/art/skypilot/backend.py +++ b/src/art/skypilot/backend.py @@ -104,7 +104,10 @@ async def initialize_cluster( ) print("Art server task already running, using it…") else: - art_server_task = sky.Task(name="art_server", run="uv run art") + art_server_task = sky.Task( + name="art_server", + run="source $HOME/.local/bin/env && uv sync --extra backend && uv run art", + ) clusters = await to_thread_typed( lambda: sky.stream_and_get( diff --git a/src/art/skypilot/utils.py b/src/art/skypilot/utils.py index 238e139e..a4363767 100644 --- a/src/art/skypilot/utils.py +++ b/src/art/skypilot/utils.py @@ -52,7 +52,7 @@ async def wait_for_task_to_start(cluster_name: str, task_name: str) -> None: task_status = await get_task_status(cluster_name, task_name) num_checks = 0 - while num_checks < 12: + while num_checks < 120: 
         task_status = await get_task_status(cluster_name, task_name)
         if task_status is None:
             raise ValueError(f"Task {task_name} not found in cluster {cluster_name}")
@@ -62,7 +62,7 @@ async def wait_for_task_to_start(cluster_name: str, task_name: str) -> None:
         num_checks += 1
 
     raise ValueError(
-        f"Task {task_name} in cluster {cluster_name} failed to start within 60s"
+        f"Task {task_name} in cluster {cluster_name} failed to start within 600s"
     )
diff --git a/src/art/trajectories.py b/src/art/trajectories.py
index 5ff48c94..af977f47 100644
--- a/src/art/trajectories.py
+++ b/src/art/trajectories.py
@@ -28,6 +28,7 @@ class PydanticException(pydantic.BaseModel):
 
 
 class History(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra='allow')
     messages_and_choices: MessagesAndChoices
     tools: Tools | None = None
 
@@ -36,6 +37,7 @@ def messages(self) -> Messages:
 
 
 class Trajectory(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra='allow')
     messages_and_choices: MessagesAndChoices
     tools: Tools | None = None
     additional_histories: list[History] = []
diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index ddacaafd..320b407a 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -167,7 +167,10 @@ def compute_loss(
         trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
         trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
         if loss.mean_entropy is not None:
-            trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())  # type: ignore
+            trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
+        trainer._metrics["train"]["frac_old_logprobs_valid"].append(loss.frac_old_logprobs_valid)
+        trainer._metrics["train"]["mean_importance_ratio"].append(loss.mean_importance_ratio.item())
+        trainer._metrics["train"]["clip_fraction"].append(loss.clip_fraction.item())
         if config.beta > 0.0:
             trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
         return loss.mean_policy_loss + config.beta * loss.mean_kl  # type: ignore
diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py
index 42b7fdab..b1dc2068 100644
--- a/src/art/utils/trajectory_logging.py
+++ b/src/art/utils/trajectory_logging.py
@@ -63,14 +63,14 @@ def trajectory_to_dict(trajectory: Trajectory) -> dict[str, Any]:
 
 def message_or_choice_to_dict(message_or_choice: MessageOrChoice) -> dict[str, Any]:
     # messages are sometimes stored as dicts, so we need to handle both cases
+    # IMPORTANT: copy dicts so the logprobs pop below doesn't mutate the original message (logprobs are needed for training)
     item_dict = (
-        message_or_choice
+        dict(message_or_choice)
         if isinstance(message_or_choice, dict)
         else message_or_choice.to_dict()
     )
 
     if "logprobs" in item_dict:
-        # item is a choice with logprobs, remove the logprobs
         item_dict.pop("logprobs")
 
     if "content" in item_dict and isinstance(item_dict["content"], Iterator):
diff --git a/uv.lock b/uv.lock
index a5c87d7e..de147490 100644
@@ -4124,9 +4124,11 @@ name = "openpipe-art"
 version = "0.5.3"
 source = { editable = "." }
 dependencies = [
+    { name = "fastapi" },
     { name = "litellm" },
     { name = "openai" },
     { name = "typer" },
+    { name = "uvicorn", extra = ["standard"] },
     { name = "weave" },
 ]
 
@@ -4188,6 +4190,7 @@ requires-dist = [
     { name = "accelerate", marker = "extra == 'backend'", specifier = "==1.7.0" },
     { name = "awscli", marker = "extra == 'backend'", specifier = ">=1.38.1" },
     { name = "bitsandbytes", marker = "extra == 'backend'", specifier = ">=0.45.2" },
+    { name = "fastapi" },
     { name = "gql", marker = "extra == 'backend'", specifier = "<4" },
     { name = "hf-xet", marker = "extra == 'backend'", specifier = ">=1.1.0" },
     { name = "langchain-core", marker = "extra == 'langgraph'", specifier = ">=0.3.51" },
@@ -4216,6 +4219,7 @@
     { name = "typer", specifier = ">=0.15.2" },
     { name = "unsloth", marker = "extra == 'backend'", specifier = "==2025.10.3" },
     { name = "unsloth-zoo", marker = "extra == 'backend'", specifier = "==2025.10.3" },
+    { name = "uvicorn", extras = ["standard"] },
    { name = "vllm", marker = "extra == 'backend'", specifier = ">=0.9.2,<=0.10.0" },
     { name = "wandb", marker = "extra == 'backend'", specifier = "==0.22.1" },
     { name = "weave", specifier = ">=0.51.51" },