Scaling Up (Part 2)
By Google for Developers
Key Concepts
- SPMD (Single Program, Multiple Data): A parallel programming paradigm where a single program is executed on multiple processors, each operating on different subsets of the data.
- Mesh: In JAX, a jax.sharding.Mesh object defines the hardware topology, specifying how devices are arranged and named.
- Partition Spec: A jax.sharding.PartitionSpec describes how a tensor's dimensions should be partitioned across the devices defined in a Mesh.
- JAX JIT (Just-In-Time Compilation): A transformation that compiles Python and NumPy code into optimized XLA (Accelerated Linear Algebra) computations.
- Flax NNX (Neural Network eXtensions): A library for building neural networks in JAX, offering stateful modules that are compatible with JAX transformations.
- PyTree: A data structure in JAX that represents nested collections of arrays, such as dictionaries, lists, and tuples.
- Sharding Metadata: Annotations attached to variables (parameters) that specify how they should be partitioned.
- Sharding Rules: A mechanism to map logical axis names (e.g., 'batch', 'hidden') to physical mesh axes (e.g., 'data', 'model'), promoting portability.
- jax.lax.with_sharding_constraint: A JAX primitive that enforces sharding specifications on a given array or pytree during compilation.
- nnx.jit and nnx.grad: NNX wrappers for JAX's jit and grad transformations, simplifying the handling of stateful modules.
- nnx.with_metadata: An NNX helper function to attach metadata, including sharding hints, to initializers.
- nnx.state: An NNX function to extract the functional state (a pytree) from an NNX module.
- nnx.get_partition_spec: An NNX utility (in nnx.spmd) to extract sharding metadata from a pytree.
- nnx.update: An NNX function to merge updated state back into a mutable module object.
- with mesh context: A block in JAX that provides the necessary context for sharding operations to be applied across specified devices.
- OOM (Out Of Memory): An error that occurs when a program requires more memory than is available on a device.
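To make the Mesh and PartitionSpec concepts concrete, here is a minimal sketch (assuming eight accelerators arranged as a 2x4 grid; the axis names 'data' and 'model' are illustrative):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the available devices into a 2x4 grid with named axes.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('data', 'model'))

# Shard a (batch, hidden) array: rows split across 'data', columns across 'model'.
spec = PartitionSpec('data', 'model')
x = jax.device_put(jnp.ones((8, 1024)), NamedSharding(mesh, spec))
jax.debug.visualize_array_sharding(x)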
Scaling Flax NNX Models with JAX SPMD
This document details how to scale Flax NNX models using JAX's SPMD capabilities, focusing on integrating JAX's sharding primitives with NNX, particularly the critical sharded initialization pattern.
Recap of Flax NNX and JAX Integration
- Flax NNX Modules: Unlike pure JAX, NNX modules are stateful Python objects, similar to PyTorch modules, holding parameters and state as attributes.
- Stateful Nature and JAX Compatibility: A key development in NNX (version 0.11+) is that NNX modules are now native JAX pytrees. This allows JAX transformations to directly understand and operate on them, bridging the gap between NNX's mutability and JAX's functional nature.
- Reconciling Mutability and Functional Transformations: NNX offers two primary approaches:
  - Functional API: Explicitly splitting modules into static definitions and dynamic state (a JAX-compatible pytree). This state is then passed through JAX transformations (jax.jit, jax.grad), and the results are merged back.
  - NNX Wrappers: Using NNX's own wrappers like nnx.jit and nnx.grad. These handle the split-merge process internally, offering a more seamless user experience, though potentially with minor overhead. Both approaches are sketched below.
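A minimal sketch contrasting the two approaches (using a small nnx.Linear as a stand-in module):

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

# Functional API: split into a static definition and dynamic state,
# push the state through the transformation, and merge inside it.
graphdef, state = nnx.split(model)

@jax.jit
def forward(state, x):
    module = nnx.merge(graphdef, state)
    return module(x)

y = forward(state, jnp.ones((2, 4)))

# NNX wrapper: nnx.jit performs the split/merge internally, so the
# stateful module can be passed in directly.
@nnx.jit
def forward_nnx(module, x):
    return module(x)

y = forward_nnx(model, jnp.ones((2, 4)))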
Integrating NNX and SPMD for Sharding
The core mechanism for integrating sharding involves attaching PartitionSpec hints directly to model parameters as metadata within NNX modules.
- Attaching Sharding Hints:
  - nnx.with_metadata: The preferred method involves wrapping the initializer with nnx.with_metadata, passing a sharding argument containing the PartitionSpec tuple.
  - Direct Annotation: Alternatively, the sharding argument can sometimes be passed directly to nnx.Param.
- Metadata Storage: This process stores the PartitionSpec tuple in a sharding attribute on the variable's state. Crucially, this metadata does not shard the parameter at this stage; it serves as a hint for future compilation. A sketch of both styles follows.
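A minimal sketch of both annotation styles (assuming physical mesh axes named 'data' and 'model'; the Block module and shapes are illustrative placeholders):

import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        init_fn = nnx.initializers.lecun_normal()
        # Preferred: wrap the initializer so the created Param carries a
        # sharding hint as metadata. Nothing is physically sharded yet.
        self.w = nnx.Param(
            nnx.with_metadata(init_fn, sharding=(None, 'model'))(
                rngs.params(), (1024, 1024)
            )
        )
        # Direct annotation: pass the sharding metadata to the variable itself.
        self.b = nnx.Param(jnp.zeros((1024,)), sharding=('model',))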
The Critical Sharded Initialization Pattern
A common pitfall when initializing large NNX models is that JAX might attempt to create all parameters on a single default device, leading to "out of memory" (OOM) errors. The solution is to perform initialization and sharding application within a JIT-compiled function.
Step-by-step process for sharded initialization:
- Define a JIT-compiled function: Use nnx.jit for convenience to wrap the initialization process.
- Instantiate the module: Inside the JIT function, create an instance of the NNX module. At this point, parameters are created but likely reside on device zero.
- Extract functional state: Use nnx.state to obtain the module's state as a JAX-compatible pytree.
- Extract PartitionSpec objects: Employ nnx.get_partition_spec to traverse the extracted state pytree. This utility function identifies and collects all the sharding metadata (the PartitionSpec tuples) previously attached to the parameters, creating a parallel pytree of PartitionSpec objects.
- Apply sharding constraints: Use jax.lax.with_sharding_constraint. This function takes the original state pytree and the newly created pytree of PartitionSpec objects. It instructs the JAX compiler to ensure that the final output state conforms to these specified sharding configurations.
- Update the module with sharded state: Utilize nnx.update. This function takes the sharded state (which JAX/XLA will ensure is correctly distributed across devices during execution) and merges it back into the original mutable NNX module object.
- Return the sharded model: The JIT-compiled function returns the updated module, now containing sharded parameters. The full pattern is sketched below.
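Put together, the pattern looks roughly like the following sketch, modeled on the Flax NNX SPMD guide (MyModel is a hypothetical NNX module whose parameters already carry sharding metadata):

from flax import nnx
import jax

@nnx.jit
def create_sharded_model():
    model = MyModel(rngs=nnx.Rngs(0))       # Created, but not yet sharded.
    state = nnx.state(model)                # 1. Extract the state pytree.
    pspecs = nnx.get_partition_spec(state)  # 2. Collect the attached PartitionSpecs.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)  # 3. Constrain.
    nnx.update(model, sharded_state)        # 4. Merge the sharded state back.
    return model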
Execution Context:
- The entire create_sharded_model function (or its equivalent) must be called within a with mesh block or a mesh context. This provides JAX with the necessary information about the hardware topology and device arrangement to fulfill the sharding constraints.
- The resulting "sharded model" will have its parameters physically distributed across the defined accelerator mesh, effectively avoiding single-device OOM errors.
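Calling the function under the mesh context then yields a model whose parameters are laid out across devices (again assuming eight devices in a 2x4 grid):

from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), axis_names=('data', 'model'))
with mesh:
    sharded_model = create_sharded_model()
# Each parameter is now physically distributed across the mesh and is
# never materialized in full on a single device.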
Using Sharding Rules for Portability
To enhance model portability and adaptability to different hardware configurations, NNX introduces "sharding rules."
- Logical vs. Physical Axes: Instead of hard-coding physical device names (e.g., 'data', 'model') directly into module definitions, use logical axis names (e.g., 'batch', 'sequence', 'embed', 'hidden').
- Sharding Rules Mapping: A separate mapping, defined as sharding_rules, translates these logical names to the actual physical mesh axes.
- Applying Sharding Rules: These rules can be provided during variable initialization or attached later.
- Benefits: This approach decouples model code from specific hardware layouts, making it easier to adapt to different sharding strategies by simply modifying the sharding_rules without altering the model definition itself.
Example of Sharding Rules:
from flax import nnx

# Define sharding rules mapping logical to physical axes
sharding_rules = (('batch', 'data'), ('hidden', 'model'))

# Use logical names in the NNX module definition
class MyModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # Logical names 'batch' and 'hidden' are attached as metadata,
        # together with the rules that resolve them to 'data' and 'model'.
        self.my_param = nnx.Param(
            nnx.with_metadata(
                nnx.initializers.lecun_normal(),
                sharding=('batch', 'hidden'),   # logical axis names
                sharding_rules=sharding_rules,
            )(rngs.params(), (1024, 1024))
        )
In this example, 'batch' in the sharding annotation will be mapped to the 'data' axis of the mesh, and 'hidden' will be mapped to the 'model' axis.
Conclusion and Next Steps
The core takeaway is the successful integration of JAX's sharding primitives with Flax NNX, specifically addressing the critical challenge of safely initializing massive sharded models. By annotating parameters with sharding metadata and executing the initialization within a JIT-compiled function under a with mesh context, the single-device OOM problem is solved, resulting in a model with parameters physically distributed across the accelerator mesh.
The next episode will build upon this foundation by demonstrating:
- The complete distributed training loop.
- Efficient data loading using libraries like grain.
- Sharded checkpointing with tools like orbax.
- A practical, end-to-end example of sharding a GPT-2 style transformer block.