4343)
4444from verl .utils .profiler import marked_timer
4545from verl .utils .rollout_skip import RolloutSkip
46-
46+ from verl . utils . pass_rate_weighted_sampler import PassRateWeightedSampler
4747
4848class RayDAPOTrainer (RayPPOTrainer ):
4949 """
@@ -68,12 +68,19 @@ def fit(self):
6868 config = OmegaConf .to_container (self .config , resolve = True ),
6969 )
7070
71- self .global_steps = 0
71+ self .global_steps = 0
7272 self .gen_steps = 0
73-
7473 # load checkpoint before doing anything
7574 self ._load_checkpoint ()
7675
76+ # Extract pass rate tracker from sampler if using curriculum learning
77+ # The PassRateWeightedSampler owns the tracker internally but we need to manually update it during training
78+ # Currently, we only support PassRateWeightedSampler for curriculum learning
79+ self .pass_rate_tracker = None
80+ self .data_sampler = self .train_dataloader .sampler # train_dataloader is created in `RayPPOTrainer._create_dataloader()` and always has a sampler
81+ if isinstance (self .data_sampler , PassRateWeightedSampler ):
82+ self .pass_rate_tracker = self .data_sampler .pass_rate_tracker
83+
7784 # perform validation before training
7885 # currently, we only support validation using the reward_function.
7986 if self .val_reward_fn is not None and self .config .trainer .get ("val_before_train" , True ):
@@ -135,7 +142,6 @@ def fit(self):
135142 non_tensor_batch_keys = ["raw_prompt_ids" ],
136143 )
137144 gen_batch = gen_batch .repeat (repeat_times = self .config .actor_rollout_ref .rollout .n , interleave = True )
138-
139145 is_last_step = self .global_steps >= self .total_training_steps
140146
141147 with marked_timer ("step" , timing_raw ):
@@ -189,7 +195,6 @@ def fit(self):
189195 reward_extra_infos_dict = {}
190196
191197 new_batch .batch ["token_level_scores" ] = reward_tensor
192-
193198 if reward_extra_infos_dict :
194199 new_batch .non_tensor_batch .update (
195200 {k : np .array (v ) for k , v in reward_extra_infos_dict .items ()}
@@ -206,6 +211,47 @@ def fit(self):
206211 else :
207212 new_batch .batch ["token_level_rewards" ] = new_batch .batch ["token_level_scores" ]
208213
214+ # === Curriculum Learning: Update pass rate tracker for weighted resampling ===
215+ # When using PassRateWeightedSampler, track per-sample success rates to enable dynamic curriculum learning.
216+ # The sampler uses these pass rates to adjust sampling probabilities in the next epoch.
217+
218+ # NOTE: refactor the pass-rate-tracker update into a utility function later
219+ # 1. if sampler is an instance of PassRateWeightedSampler, self.pass_rate_tracker is not None
220+ # 2. `dataset_index` field is added to the RL dataset to identify samples
221+ if "dataset_index" in new_batch .non_tensor_batch and self .pass_rate_tracker is not None :
222+ dataset_indices = new_batch .non_tensor_batch ["dataset_index" ]
223+ # Sum token-level rewards to get sequence-level reward
224+ seq_rewards = new_batch .batch ["token_level_rewards" ].sum (dim = - 1 ).cpu ().numpy ()
225+ # Success is 1 if sequence reward > 0, else 0
226+ successes = (seq_rewards > 0 ).astype (float )
227+
228+ # Deduplicate: batch was repeated n times (interleaved), so we need to aggregate
229+ unique_indices , inverse_indices = np .unique (dataset_indices , return_inverse = True )
230+
231+ assert len (unique_indices ) > 0 , "No unique samples found in batch. Check data pipeline configuration."
232+ # Aggregate successes: take mean across rollouts for each sample
233+ aggregated_successes = np .zeros (len (unique_indices ), dtype = float )
234+ for i , _ in enumerate (unique_indices ):
235+ mask = inverse_indices == i # boolean array to indicate positions of unique index i
236+ aggregated_successes [i ] = np .mean (successes [mask ]) # take average success across rollouts for sample i
237+
238+ pass_rates = self .pass_rate_tracker .get_pass_rates ()
239+
240+ # Log curriculum metrics BEFORE updating tracker
241+ # Track improvement of hardest samples (across all samples, not just attempted)
242+ metrics ['curriculum/hardest_10pct_pass_rate' ] = float (np .percentile (pass_rates , 10 ))
243+ metrics ['curriculum/hardest_25pct_pass_rate' ] = float (np .percentile (pass_rates , 25 ))
244+ metrics ['curriculum/hardest_50pct_pass_rate' ] = float (np .percentile (pass_rates , 50 ))
245+ metrics ['curriculum/hardest_75pct_pass_rate' ] = float (np .percentile (pass_rates , 75 ))
246+
247+ # Batch-level statistics
248+ metrics ['curriculum/min_batch_pass_rate' ] = float (np .min (aggregated_successes ))
249+ metrics ['curriculum/mean_batch_pass_rate' ] = float (np .mean (aggregated_successes ))
250+ metrics ['curriculum/effective_batch_size' ] = np .sum (aggregated_successes > 0 )/ len (unique_indices )
251+
252+ # Update tracker with current batch results
253+ self .pass_rate_tracker .update (sample_indices = unique_indices .astype (int ), batch_pass_rate = aggregated_successes )
254+
209255 if not self .config .algorithm .filter_groups .enable :
210256 batch = new_batch
211257 else : # NOTE: When prompts after filtering is less than train batch size,
@@ -280,7 +326,6 @@ def fit(self):
280326 # === Updating ===
281327
282328 batch .batch ["response_mask" ] = compute_response_mask (batch )
283-
284329 # Balance the number of valid tokens across DP ranks.
285330 # NOTE: This usually changes the order of data in the `batch`,
286331 # which won't affect the advantage calculation (since it's based on uid),
@@ -342,6 +387,7 @@ def fit(self):
342387 actor_output = self .actor_rollout_wg .update_actor (batch )
343388 actor_output_metrics = reduce_metrics (actor_output .meta_info ["metrics" ])
344389 metrics .update (actor_output_metrics )
390+ print ("in critic warmup loop" )
345391
346392 # Log rollout generations if enabled
347393 rollout_data_dir = self .config .trainer .get ("rollout_data_dir" , None )
@@ -430,6 +476,31 @@ def _to_sequence(value):
430476 num_total_prompts = 0
431477 num_gen_batches = 0
432478
479+ # Add curriculum learning metrics to W&B
480+ if isinstance (self .data_sampler , PassRateWeightedSampler ):
481+ # Add 3D plot data for weight and count distributions (percentile-based)
482+ try :
483+ import wandb
484+ import pandas as pd
485+
486+ weight_3d_data = self .data_sampler .get_wandb_3d_plot_data (metric_type = 'weight' )
487+ count_3d_data = self .data_sampler .get_wandb_3d_plot_data (metric_type = 'count' )
488+
489+ # Add step to each data point for 3D visualization
490+ for point in weight_3d_data :
491+ point ['step' ] = self .global_steps
492+ for point in count_3d_data :
493+ point ['step' ] = self .global_steps
494+
495+ metrics ['curriculum/weight_distribution_3d' ] = wandb .Table (
496+ dataframe = pd .DataFrame (weight_3d_data )
497+ ) if weight_3d_data else None
498+ metrics ['curriculum/count_distribution_3d' ] = wandb .Table (
499+ dataframe = pd .DataFrame (count_3d_data )
500+ ) if count_3d_data else None
501+ except ImportError :
502+ pass # wandb or pandas not available
503+
433504 # TODO: make a canonical logger that supports various backend
434505 logger .log (data = metrics , step = self .global_steps )
435506
0 commit comments