Stanford CS336 Language Modeling from Scratch | Spring 2025 | Lecture 16: Alignment - RL


Key Concepts

  • RLHF (Reinforcement Learning from Human Feedback): Training language models using human preferences as a reward signal.
  • DPO (Direct Preference Optimization): An algorithm for optimizing language models based on pairwise preference data, framed as a supervised learning problem.
  • SimPO (Simple Preference Optimization): A DPO variant that normalizes update size by response length and removes the reference policy.
  • Overoptimization: A phenomenon where optimizing a policy too much on a proxy reward leads to divergence from true human preferences.
  • Calibration: The extent to which a model's predicted probabilities reflect the true likelihood of events. RLHF models often exhibit poor calibration.
  • Verifiable Rewards: Rewards that can be automatically and reliably evaluated, enabling the application of traditional RL techniques to language models.
  • PPO (Proximal Policy Optimization): A policy gradient method that uses clipping to constrain policy updates and a value function for variance reduction.
  • GRPO (Group Relative Policy Optimization): A simplified RL algorithm that replaces the advantage estimator in PPO with a z-score of rewards within a group.
  • Baseline: A constant or random variable subtracted from the reward to reduce variance in policy gradient estimation.
  • CoT (Chain of Thought): A sequence of intermediate reasoning steps generated by a language model to solve a problem.
  • PRM (Process Reward Model): A model that provides intermediate rewards for each step in a chain of thought.
  • MCTS (Monte Carlo Tree Search): A search algorithm used to explore possible actions and their consequences.
  • Thinking Mode Fusion: A technique to combine thinking and non-thinking models into a single model, allowing for controlled inference.

RLHF Recap and Limitations

  • Objective: Maximize an underlying reward based on pairwise preference data.
  • DPO: Rewrites the reward as a ratio of policies, enabling optimization via supervised learning.
    • Updates increase the likelihood of good examples and decrease the likelihood of bad examples.
    • Equation: Gradient steps are scaled by beta (the regularization strength), with higher weight when the implicit reward model ranks the preference pair incorrectly.
  • SimPO: Normalizes update size by response length and removes the reference policy.
  • Empirical Findings: RL findings are highly contingent on the specific setting (environment, base model, preferences).
    • Example: AI2's work showed varying results for DPO and PPO depending on the SFT method.
  • Overoptimization: As the policy is optimized, the reward model diverges from real human preferences.
    • Occurs due to the noisiness and complexities of human preferences.
    • Evidenced by studies showing overoptimization with human and noisy AI feedback, but not with clean AI feedback.
  • Calibration Issues: RLHF models are often less calibrated and more overconfident than models trained with supervised learning.
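The DPO update described above can be sketched as a per-pair loss. This is a minimal NumPy sketch assuming the summed token log-probabilities of each response are already available; function and argument names are illustrative:

```python
import numpy as np

def dpo_loss(logp_chosen, logp_rejected,
             ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """DPO loss for one preference pair; inputs are summed token
    log-probs of each full response under policy / reference."""
    # Implicit reward of a response = beta * log(pi / pi_ref).
    margin = beta * ((logp_chosen - ref_logp_chosen)
                     - (logp_rejected - ref_logp_rejected))
    # -log sigmoid(margin): near zero when the policy already prefers
    # the chosen response, large (strong gradient) when it is wrong.
    return float(np.log1p(np.exp(-margin)))
```

The gradient of this loss is scaled by beta and weighted by sigmoid(-margin), which matches the bullet above: updates are largest when the implicit reward model ranks the pair incorrectly.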

Transition to Reinforcement Learning from Verifiable Rewards

  • Motivation: Human feedback is difficult to optimize and scale.
  • Alternative: Apply RL in domains with true, quickly evaluable rewards (e.g., AlphaGo, AlphaFold).
  • Goal: Bring the successes of RL to language modeling by focusing on domains with verifiable rewards.

Proximal Policy Optimization (PPO)

  • Policy Gradient: Optimizes the expected reward under the policy through gradient descent.
    • Equation: ∇θ E_πθ[r] = E_πθ[r · ∇θ log πθ(y)]; gradient steps increase the probability of positive-reward samples and decrease it for negative-reward samples.
  • Inefficiency: Purely on-policy, requiring a rollout for each gradient step.
  • TRPO (Trust Region Policy Optimization): Allows updates from stale samples using importance sampling correction.
  • PPO: Clips the importance ratio (rather than enforcing an explicit KL constraint) to keep policies close.
    • Equation: PPO clip objective, L = E[min(ρ·A, clip(ρ, 1-ε, 1+ε)·A)], where ρ is the new-to-old policy probability ratio and A is the advantage.
  • Implementation Details: Requires a value function to compute expected rewards and lower variance.
  • Complexity: PPO in practice is complex, with many implementation details and variants.
  • Reward Shaping: Constructing per-token losses to provide easier-to-learn signals for the RL algorithm.
    • KL terms are computed per token, while the true reward is computed at the last token.
  • Generalized Advantage Estimation (GAE): Used for variance reduction in policy gradients.
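The clipped surrogate objective can be sketched for a single action as follows (a minimal NumPy sketch; names are illustrative):

```python
import numpy as np

def ppo_clip_loss(logp_new, logp_old, advantage, eps=0.2):
    """PPO clipped surrogate loss for one action (to be minimized)."""
    ratio = np.exp(logp_new - logp_old)          # importance weight pi_new/pi_old
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1 - eps, 1 + eps) * advantage
    # Pessimistic bound: take the smaller of the two objectives,
    # then negate so gradient descent maximizes it.
    return float(-np.minimum(unclipped, clipped))
```

Once the ratio moves outside [1-ε, 1+ε] in the direction the advantage favors, the clipped term dominates and the gradient vanishes, which is what keeps the updated policy close to the old one.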

Group Relative Policy Optimization (GRPO)

  • Motivation: Simplify PPO by removing the value model and GAE.
  • Advantage Estimator Replacement: Replaces GAE with a z-score of rewards within a group (responses to the same input).
    • Equation: Advantage = (Reward - Mean Reward) / Standard Deviation of Rewards.
  • Group Definition: A set of responses to the same input question.
  • Baseline: The mean reward within a group serves as a natural baseline, accounting for problem difficulty.
  • KL Divergence Estimation: Uses a control variate scheme to reduce variance in KL divergence estimation.
  • Online Case: In the pure online case, clipping disappears, and the algorithm simplifies to policy gradients.
  • Implementation: Involves computing rewards, normalizing by group, computing KL terms, and updating gradients.
  • Advantage Computation: Simple z-score calculation with a small epsilon added to the standard deviation for numerical stability.
  • Performance: GRPO works well, as demonstrated in the DeepSeek Math paper.
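The group-normalized advantage and a low-variance KL estimator can be sketched as follows (illustrative NumPy sketch; the "k3" form shown is the common control-variate choice for estimating the KL term, and epsilon guards against zero-variance groups):

```python
import numpy as np

def grpo_advantages(rewards, eps=1e-4):
    """Group-relative advantage: z-score of rewards across the G
    responses sampled for the same prompt."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

def k3_kl_estimate(logp, ref_logp):
    """Per-token KL(pi || pi_ref) estimate via the k3 control-variate
    form: ratio - log(ratio) - 1, which is always >= 0."""
    log_ratio = ref_logp - logp
    return np.exp(log_ratio) - log_ratio - 1.0
```

Every token in a response shares that response's single group-relative advantage, so no value function is needed.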

GRPO vs. PPO: Differences and Analysis

  • Key Difference: Replacement of the advantage estimator with the z-score.
  • Baseline Validity: Dividing by the standard deviation breaks the contract of a baseline needing to be a zero-mean variable independent of the draw.
  • Length Normalization: Dividing per-token losses by output length is likewise inconsistent with the policy gradient theorem.
  • Standard Deviation Impact: Amplifies rewards when the standard deviation is small (problems are too easy or too hard), biasing towards these problems.
  • Length Normalization Impact: Incentivizes long responses when the model gets a question wrong and short responses when it gets it right, leading to aggressive "BS-ing."
  • Fixes: Removing standard deviation division and length normalization can lead to shorter output lengths and higher rewards.
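A quick numeric illustration of the standard-deviation issue, assuming binary correctness rewards (a sketch, not any paper's exact setup):

```python
import numpy as np

def z_advantages(rewards, eps=1e-4):
    """GRPO-style group z-score of rewards."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

# 50/50 group: std = 0.5, advantages come out to roughly +/-1.
mixed = z_advantages([1, 1, 0, 0])
# Nearly solved group: tiny std, so the lone failure is amplified ~2.6x.
easy = z_advantages([1, 1, 1, 1, 1, 1, 1, 0])
```

The lone failure in the nearly solved group receives a much larger-magnitude advantage than any response in the mixed group, so low-variance prompts (ones that are very easy or very hard) dominate the gradient.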

Case Studies: R1, Kimi k1.5, and Qwen 3

R1

  • Significance: Replicates qualitative properties of the o1 recipe with extreme simplicity.
  • Starting Point: Builds on DeepSeek Math, using GRPO with outcome supervision.
  • R1-Zero (Controlled Setting): RL directly on top of a pre-trained and mid-trained model.
    • Rewards: Accuracy and format rewards.
    • Results: Performance close to o1.
    • Observations: CoT length increases predictably, and backtracking ("aha moment") emerges.
  • R1 (Unrestricted Setting): Includes supervised fine-tuning and post-training.
    • SFT Initialization: Fine-tunes on long CoTs to encourage reasoning.
    • Language Consistency Reward: Prevents language mixing in CoTs.
    • Post-Training: Instruction tuning and pairwise preference tuning.
  • Performance: Matches o1 performance across various tasks.
  • Distillation: Distills knowledge into smaller models by fine-tuning them on chains of thought from the larger model.
  • Negative Results: PRMs and MCTS did not significantly improve performance.

Kimi k1.5

  • Significance: Achieves results similar to R1 using outcome-based rewards and a different RL algorithm.
  • Data Curation:
    • Balances across domains.
    • Excludes multiple-choice and true/false questions.
    • Filters for difficulty using a pass rate threshold (excludes examples that the base model can easily answer).
  • RL Algorithm:
    • Maximizes rewards while staying close to the base policy.
    • Uses a squared loss to drive the left and right sides of an optimal policy equation close together.
    • Employs a baseline reward (average reward within the batch).
  • Length Reward:
    • Incentivizes short CoTs while maintaining performance.
    • Reward is based on the position of the length within the range of lengths in the batch.
    • Turned on later in training to avoid stalling RL.
  • Curriculum:
    • Uses assigned difficulty labels and progresses from easy to hard.
    • Samples problems proportional to one minus the success rate.
  • Reward Models:
    • Uses ground truth solutions for code problems.
    • Uses a reward model for math problems to compare LM outputs to human-written answers.
  • Infrastructure:
    • Uses separate workers for RL updates and inference.
    • Employs message passing to transfer weights and data.
  • Scaling: Shows scaling of performance and length of responses over iterations.
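The batch-relative length reward described above can be sketched as follows (a sketch of the scheme; the exact coefficients here follow the form reported in the Kimi k1.5 paper, so treat them as illustrative):

```python
def length_reward(length, min_len, max_len, correct):
    """Batch-relative length reward: lam is +0.5 at the shortest
    response length in the batch and -0.5 at the longest."""
    if max_len == min_len:
        return 0.0
    lam = 0.5 - (length - min_len) / (max_len - min_len)
    # Correct answers are rewarded for being short; incorrect answers
    # are only ever penalized for being long, never rewarded.
    return lam if correct else min(0.0, lam)
```

Because the range [min_len, max_len] is recomputed per batch, the reward adapts as the policy's response lengths drift over training; it is switched on only later in training so that it does not stall early RL progress.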

Qwen 3

  • Significance: Most recent RL-for-reasoning model, building upon previous works.
  • Overall Picture: Similar to R1 and Kimi k1.5, with SFT, reasoning RL, and RLHF.
  • Data Curation:
    • Filters for difficulty using best-of-N sampling.
    • Removes data similar to validation data.
    • Manually filters SFT data for correct reasoning.
  • RL Data Size: Uses only 3,995 examples for RL.
  • Thinking Mode Fusion:
    • Trains a single model to both think and not think.
    • Fine-tunes the model with think and no-think tags.
    • Allows for controlled inference by terminating the thinking process.
  • Ablation: Shows performance at different stages (reasoning RL, thinking mode fusion, general RL).
  • Trade-offs: Suggests a trade-off between optimizing for general-purpose instruction following and math performance.
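The controlled inference enabled by thinking mode fusion can be illustrated with a toy prompt builder. The tag and flag names below are illustrative, not Qwen 3's exact chat template:

```python
def build_prompt(question, thinking=True):
    """Sketch of thinking-mode control via chat-template tags.
    In no-think mode the assistant turn is seeded with an empty
    think block, so the model skips straight to the answer."""
    flag = "/think" if thinking else "/no_think"
    prompt = f"<|user|>{question} {flag}<|assistant|>"
    if not thinking:
        prompt += "<think></think>"  # terminate thinking immediately
    return prompt
```

The same mechanism supports mid-generation control: injecting the closing think tag at any point forces the model to stop reasoning and produce its final answer, which is how a thinking budget can be enforced at inference time.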

Conclusion

RL is a powerful tool for language models, especially in domains with verifiable rewards. GRPO is a simple algorithm that enables RL on these domains. Successful recipes involve SFT, reasoning RL, and RLHF, with various implementation tricks and trade-offs to consider.
