Checkpointing Flax NNX Models with Orbax (Part 2)
By Google for Developers
Key Concepts
- Optimizer State: Information maintained by an optimizer beyond model parameters, such as momentum, adaptive learning rates, and step counts.
- nnx.Optimizer: A Flax NNX class that bundles a model, an optimizer definition, and its state.
- nnx.Variable: NNX's mechanism for managing state, used by both models and optimizers.
- nnx.state: A function that extracts the state of an NNX object.
- ocp.args.Composite: An Orbax argument for saving multiple items (e.g., model parameters and optimizer state) together, under different names, within a single checkpoint.
- nnx.split with the nnx.Param filter: Used to extract only model parameters, excluding other model state such as batch statistics.
- ocp.args.StandardSave: Wraps PyTrees for standard saving.
- nnx.eval_shape: Used to create abstract versions of models and optimizers for restoration, preserving structure and type.
- ocp.args.StandardRestore: Used for restoring state from abstract templates.
- SPMD (Single Program, Multiple Data): JAX's paradigm for distributed computation, in which the same program runs on multiple devices with different chunks of data.
- Partition Specs: Instructions specifying how arrays should be partitioned or sharded across a device mesh.
- Sharded JAX Arrays: JAX arrays distributed across multiple devices.
- ocp.checkpoint_utils.construct_restore_args: Helper for creating restore arguments, especially for PyTree restores.
- jax.lax.with_sharding_constraint: A JAX function used to apply sharding specifications to abstract state leaves.
- Asynchronous Checkpointing: Save operations performed in a background thread to avoid blocking the main training loop.
- Atomic Saves: Orbax's guarantee that a checkpoint directory is only finalized after all files are written, preventing corrupted states.
- ocp.args.JsonSave: For saving non-PyTree data as JSON.
- TensorStore: An underlying storage library often used by Orbax for efficient I/O, especially with cloud storage.
Saving Optimizer State and Handling Distributed Models with Orbax
This document details advanced checkpointing techniques with Orbax, focusing on saving optimizer state and managing distributed, sharded models in JAX/Flax.
1. Saving Optimizer State
- Problem: During training, optimizers (like Adam or SGD) maintain their own state (e.g., momentum vectors, adaptive learning rates, step counts) in addition to model parameters.
- Solution:
  - Flax's nnx.Optimizer class bundles the model, the optimizer definition, and its state.
  - Crucially, nnx.Optimizer manages its internal state using nnx.Variable objects, just as models do.
  - This allows extracting the optimizer's state with nnx.state and saving it with Orbax.
- Key Process:
  - Extract Model Parameters: Use nnx.split with the nnx.Param filter to get only the model parameters when other state (like batch statistics) is handled separately.
  - Extract Optimizer State: Use nnx.state on the optimizer object to get its entire state.
  - Combine for Saving:
    - Define a checkpoint directory.
    - Use ocp.args.Composite to package multiple save items under distinct names (e.g., "params", "optimizer").
    - Create a dictionary save_items whose keys are the desired names and whose values are the corresponding states wrapped with ocp.args.StandardSave.
    - Pass this dictionary, unpacked, into ocp.args.Composite within the manager.save call.
    - Example: manager.save(step, args=ocp.args.Composite(params=ocp.args.StandardSave(model_params), optimizer=ocp.args.StandardSave(optimizer_state)))
  - Wait and Close: Follow the standard Orbax procedure (manager.wait_until_finished(), then manager.close()) to finalize the save.
2. Restoring Optimizer State
- Key Process:
  - Create Abstract Templates:
    - Generate abstract versions of both the model and the optimizer using nnx.eval_shape. This ensures the structure and type (e.g., the specific Optax optimizer) match what was saved.
    - From these abstract objects, obtain the graph definitions and abstract state PyTrees that serve as templates for restoration (e.g., abs_param_state, abs_optimizer_state).
  - Find Latest Step: Identify the latest checkpoint step to restore (for example, with manager.latest_step()).
  - Prepare Restore Arguments: As when saving, create a dictionary for ocp.args.Composite, but this time use ocp.args.StandardRestore with the abstract state templates as targets.
  - Restore: Call manager.restore with the prepared arguments. It returns the restored items (restored_items), from which the loaded parameter and optimizer state PyTrees can be retrieved by name.
  - Update Concrete Instances:
    - Create new, concrete instances of the model and optimizer.
    - Use nnx.update to populate these fresh instances with the restored data from restored_items.
3. Handling Distributed Sharded Models
- Context: Modern large-scale training distributes computation across many accelerators using JAX's SPMD paradigm. Arrays are partitioned, or "sharded", across a logical device mesh according to partition specs.
- Integration with NNX: Sharding instructions can be attached as metadata to nnx.Variable objects when defining a model.
- Orbax and Sharding:
  - Orbax understands sharded JAX arrays and saves them efficiently.
  - Restoration Challenge: At restore time, Orbax needs to know how to reassemble these sharded arrays, possibly on a different device topology.
- Restoring Sharded State:
  - Create a Sharding-Aware Abstract Template:
    - Start with an abstract model created via nnx.eval_shape.
    - Split it to get the plain abstract state.
    - Extract the desired sharding specifications (e.g., using nnx.get_partition_spec).
    - Apply these specifications to the state leaves using jax.lax.with_sharding_constraint, typically inside a jitted or otherwise traced function and within the mesh context. This produces an abstract state whose shape, dtype, and structure also encode sharding information.
  - Restore with Sharding Information: Pass this sharding-aware abstract state directly to Orbax's restore function. Orbax uses it to reconstruct the distributed arrays correctly on the target devices.
  - Merge Restored State: Merge the restored state back into the model.
- Code Snippet Illustration:
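```python
import jax
import orbax.checkpoint as ocp
from flax import nnx

# Helper function to get an abstract state template with the target sharding.
# `create_model` is a zero-argument function that builds the model; its
# nnx.Variables must carry sharding metadata (e.g. via nnx.with_partitioning).
def get_abstract_state_with_sharding(create_model, mesh):
  def create_sharded_model():
    model = create_model()                          # unsharded at this point
    state = nnx.state(model)                        # plain state PyTree
    sharding_specs = nnx.get_partition_spec(state)  # specs from the metadata
    sharded_state = jax.lax.with_sharding_constraint(state, sharding_specs)
    nnx.update(model, sharded_state)
    return model

  # nnx.eval_shape traces the function without allocating arrays; inside the
  # mesh context the constraints resolve to shardings on the abstract leaves.
  with mesh:
    abstract_model = nnx.eval_shape(create_sharded_model)
  graphdef, sharded_abstract_state = nnx.split(abstract_model)
  return graphdef, sharded_abstract_state

# ... later in the restore process ...
# restored_state = manager.restore(latest_step, args=ocp.args.Composite(
#     params=ocp.args.StandardRestore(sharding_aware_abstract_params),
#     optimizer=ocp.args.StandardRestore(sharding_aware_abstract_optimizer),
# ))
# nnx.update(model, restored_state['params'])
# nnx.update(optimizer, restored_state['optimizer'])
```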
4. Efficiency and Robustness Features
- Asynchronous Checkpointing:
  - Runs save operations in a background thread so they do not block the main training loop.
  - Requires explicitly waiting for completion later (e.g., with manager.wait_until_finished()).
- Atomic Saves:
  - The CheckpointManager guarantees that a checkpoint directory is only finalized once all of its files have been written successfully.
  - This prevents corrupted states from incomplete saves.
- Saving Non-PyTree Data:
  - Argument types like ocp.args.JsonSave allow saving data that isn't a PyTree (e.g., dataset iterator state) as JSON alongside the model state within a composite checkpoint.
- Underlying Storage:
  - Orbax often uses TensorStore, particularly for cloud storage, to optimize the reading and writing of large distributed arrays.
5. Conclusion and Key Takeaways
- Flax NNX: Provides an object-oriented approach to defining JAX models.
- Orbax: Offers robust tools for checkpointing NNX state.
- Core Workflow:
  - Saving: Split the model to extract its state, then use ocp.args.Composite to save multiple items (parameters, optimizer state) together.
  - Restoring: Use nnx.eval_shape to create abstract templates, then ocp.args.StandardRestore with these templates, and finally nnx.update to merge the restored data into concrete model and optimizer instances.
- Distributed Training:
- Pay close attention to sharding.
- Use Orbax utilities with correct abstract state and mesh context to ensure sharding is handled properly during restoration.
- Additional Features: Asynchronous saving, atomic commits, and JSON saving enhance efficiency and reliability.
- Resources: The video points to coding exercises, quick reference docs, slides, a YouTube playlist for the "Learning JAX" series, a Discord community, and documentation for JAX, Flax, and the JAX AI Stack.