Skip to content

Commit 83d9600

Browse files
pass-rate-based weighted sampler, tested with local and multi-node runs
1 parent b69d7e7 commit 83d9600

8 files changed

Lines changed: 687 additions & 8 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
hydra:
2+
searchpath:
3+
- file://verl/trainer/config
4+
5+
defaults:
6+
- ppo_trainer
7+
- _self_
8+
9+
# parameters added to enable PassRateWeightedSampler with DAPO; override parameters in verl/trainer/config/data/legacy_data.yaml
10+
data:
11+
gen_batch_size: ${data.train_batch_size}
12+
dataloader_num_workers: 0 # Recommended to set to 0 when using curriculum learning samplers (e.g., PassRateWeightedSampler) to prevent data caching before batches are reordered.
13+
sampler:
14+
pass_rate_temperature: 1.0 # temperature parameter for PassRateWeightedSampler, controls sharpness of weighting distribution
15+
use_ema: False # whether to use EMA smoothed pass rates for weighting
16+
ema_alpha: 0.1 # alpha parameter for EMA smoothing of pass rates
17+
18+
reward_model:
19+
reward_manager: dapo
20+
overlong_buffer:
21+
enable: False # Set explicitly so enabling is never forgotten
22+
len: 0
23+
penalty_factor: 0.0
24+
log: False
25+
26+
algorithm:
27+
filter_groups:
28+
_target_: verl.trainer.config.FilterGroupsConfig
29+
enable: False # Set explicitly so enabling is never forgotten
30+
metric: null # acc / score / seq_reward / seq_final_reward / ...
31+
max_num_gen_batches: 0 # Non-positive values mean no upper limit
32+
33+

recipe/dapo/dapo_ray_trainer.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444
from verl.utils.profiler import marked_timer
4545
from verl.utils.rollout_skip import RolloutSkip
46-
46+
from verl.utils.pass_rate_weighted_sampler import PassRateWeightedSampler
4747

4848
class RayDAPOTrainer(RayPPOTrainer):
4949
"""
@@ -68,12 +68,19 @@ def fit(self):
6868
config=OmegaConf.to_container(self.config, resolve=True),
6969
)
7070

71-
self.global_steps = 0
71+
self.global_steps = 0
7272
self.gen_steps = 0
73-
7473
# load checkpoint before doing anything
7574
self._load_checkpoint()
7675

76+
# Extract pass rate tracker from sampler if using curriculum learning
77+
# The PassRateWeightedSampler owns the tracker internally but we need to manually update it during training
78+
# Currently, we only support PassRateWeightedSampler for curriculum learning
79+
self.pass_rate_tracker = None
80+
self.data_sampler = self.train_dataloader.sampler # train_dataloader is created in `RayPPOTrainer._create_dataloader()` and always has a sampler
81+
if isinstance(self.data_sampler, PassRateWeightedSampler):
82+
self.pass_rate_tracker = self.data_sampler.pass_rate_tracker
83+
7784
# perform validation before training
7885
# currently, we only support validation using the reward_function.
7986
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
@@ -135,7 +142,6 @@ def fit(self):
135142
non_tensor_batch_keys=["raw_prompt_ids"],
136143
)
137144
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
138-
139145
is_last_step = self.global_steps >= self.total_training_steps
140146

141147
with marked_timer("step", timing_raw):
@@ -189,7 +195,6 @@ def fit(self):
189195
reward_extra_infos_dict = {}
190196

191197
new_batch.batch["token_level_scores"] = reward_tensor
192-
193198
if reward_extra_infos_dict:
194199
new_batch.non_tensor_batch.update(
195200
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
@@ -206,6 +211,47 @@ def fit(self):
206211
else:
207212
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
208213

214+
# === Curriculum Learning: Update pass rate tracker for weighted resampling ===
215+
# When using PassRateWeightedSampler, track per-sample success rates to enable dynamic curriculum learning.
216+
# The sampler uses these pass rates to adjust sampling probabilities in the next epoch.
217+
218+
# TODO: extract the pass rate tracker update into a utility function later
219+
# 1. if sampler is an instance of PassRateWeightedSampler, self.pass_rate_tracker is not None
220+
# 2. the `dataset_index` field is added to the RL dataset to identify samples
221+
if "dataset_index" in new_batch.non_tensor_batch and self.pass_rate_tracker is not None:
222+
dataset_indices = new_batch.non_tensor_batch["dataset_index"]
223+
# Sum token-level rewards to get sequence-level reward
224+
seq_rewards = new_batch.batch["token_level_rewards"].sum(dim=-1).cpu().numpy()
225+
# Success is 1 if sequence reward > 0, else 0
226+
successes = (seq_rewards > 0).astype(float)
227+
228+
# Deduplicate: batch was repeated n times (interleaved), so we need to aggregate
229+
unique_indices, inverse_indices = np.unique(dataset_indices, return_inverse=True)
230+
231+
assert len(unique_indices) > 0, "No unique samples found in batch. Check data pipeline configuration."
232+
# Aggregate successes: take mean across rollouts for each sample
233+
aggregated_successes = np.zeros(len(unique_indices), dtype=float)
234+
for i, _ in enumerate(unique_indices):
235+
mask = inverse_indices == i # boolean array to indicate positions of unique index i
236+
aggregated_successes[i] = np.mean(successes[mask]) # take average success across rollouts for sample i
237+
238+
pass_rates = self.pass_rate_tracker.get_pass_rates()
239+
240+
# Log curriculum metrics BEFORE updating tracker
241+
# Track improvement of hardest samples (across all samples, not just attempted)
242+
metrics['curriculum/hardest_10pct_pass_rate'] = float(np.percentile(pass_rates, 10))
243+
metrics['curriculum/hardest_25pct_pass_rate'] = float(np.percentile(pass_rates, 25))
244+
metrics['curriculum/hardest_50pct_pass_rate'] = float(np.percentile(pass_rates, 50))
245+
metrics['curriculum/hardest_75pct_pass_rate'] = float(np.percentile(pass_rates, 75))
246+
247+
# Batch-level statistics
248+
metrics['curriculum/min_batch_pass_rate'] = float(np.min(aggregated_successes))
249+
metrics['curriculum/mean_batch_pass_rate'] = float(np.mean(aggregated_successes))
250+
metrics['curriculum/effective_batch_size'] = np.sum(aggregated_successes > 0)/len(unique_indices)
251+
252+
# Update tracker with current batch results
253+
self.pass_rate_tracker.update(sample_indices=unique_indices.astype(int), batch_pass_rate=aggregated_successes)
254+
209255
if not self.config.algorithm.filter_groups.enable:
210256
batch = new_batch
211257
else: # NOTE: When prompts after filtering is less than train batch size,
@@ -280,7 +326,6 @@ def fit(self):
280326
# === Updating ===
281327

282328
batch.batch["response_mask"] = compute_response_mask(batch)
283-
284329
# Balance the number of valid tokens across DP ranks.
285330
# NOTE: This usually changes the order of data in the `batch`,
286331
# which won't affect the advantage calculation (since it's based on uid),
@@ -342,6 +387,7 @@ def fit(self):
342387
actor_output = self.actor_rollout_wg.update_actor(batch)
343388
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
344389
metrics.update(actor_output_metrics)
390+
print("in critic warmup loop")
345391

346392
# Log rollout generations if enabled
347393
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
@@ -430,6 +476,31 @@ def _to_sequence(value):
430476
num_total_prompts = 0
431477
num_gen_batches = 0
432478

479+
# Add curriculum learning metrics to W&B
480+
if isinstance(self.data_sampler, PassRateWeightedSampler):
481+
# Add 3D plot data for weight and count distributions (percentile-based)
482+
try:
483+
import wandb
484+
import pandas as pd
485+
486+
weight_3d_data = self.data_sampler.get_wandb_3d_plot_data(metric_type='weight')
487+
count_3d_data = self.data_sampler.get_wandb_3d_plot_data(metric_type='count')
488+
489+
# Add step to each data point for 3D visualization
490+
for point in weight_3d_data:
491+
point['step'] = self.global_steps
492+
for point in count_3d_data:
493+
point['step'] = self.global_steps
494+
495+
metrics['curriculum/weight_distribution_3d'] = wandb.Table(
496+
dataframe=pd.DataFrame(weight_3d_data)
497+
) if weight_3d_data else None
498+
metrics['curriculum/count_distribution_3d'] = wandb.Table(
499+
dataframe=pd.DataFrame(count_3d_data)
500+
) if count_3d_data else None
501+
except ImportError:
502+
pass # wandb or pandas not available
503+
433504
# TODO: make a canonical logger that supports various backend
434505
logger.log(data=metrics, step=self.global_steps)
435506

0 commit comments

Comments
 (0)