Scaling Up (Part 2)

By Google for Developers

Key Concepts

  • SPMD (Single Program, Multiple Data): A parallel programming paradigm where a single program is executed on multiple processors, each operating on different subsets of the data.
  • Mesh: In JAX, a jax.sharding.Mesh object defines the hardware topology, specifying how devices are arranged and named.
  • Partition Spec: A jax.sharding.PartitionSpec describes how a tensor's dimensions should be partitioned across the devices defined in a Mesh.
  • JAX JIT (Just-In-Time Compilation): A transformation that compiles Python functions built from JAX operations into optimized XLA (Accelerated Linear Algebra) computations.
  • Flax NNX (Neural Network eXtensions): A library for building neural networks in JAX, offering stateful modules that are compatible with JAX transformations.
  • Pytree: A data structure in JAX that represents nested collections of arrays, such as dictionaries, lists, and tuples.
  • Sharding Metadata: Annotations attached to variables (parameters) that specify how they should be partitioned.
  • Sharding Rules: A mechanism to map logical axis names (e.g., 'batch', 'hidden') to physical mesh axes (e.g., 'data', 'model'), promoting portability.
  • jax.lax.with_sharding_constraint: A JAX function that enforces sharding specifications on a given array or pytree during compilation.
  • nnx.jit and nnx.grad: NNX wrappers for JAX's jit and grad transformations, simplifying the handling of stateful modules.
  • nnx.with_metadata: An NNX helper function to attach metadata, including sharding hints, to initializers.
  • nnx.state: An NNX function to extract the functional state (a pytree) from an NNX module.
  • nnx.get_partition_spec: An NNX utility to extract sharding metadata from a state pytree as a matching pytree of PartitionSpec objects.
  • nnx.update: An NNX function to merge updated state back into a mutable module object.
  • with mesh context: A block in JAX that provides the necessary context for sharding operations to be applied across specified devices.
  • OOM (Out Of Memory): An error that occurs when a program requires more memory than is available on a device.
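To make the Mesh and PartitionSpec concepts concrete, here is a minimal sketch; the 1×N device reshape and the ('data', 'model') axis names are illustrative assumptions, and the same code runs unchanged on a single CPU host or a multi-chip accelerator slice:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange all visible devices into a 2D grid and name its axes.
# On a single-host CPU run this is a 1x1 grid; on 8 chips it could be 1x8.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Describe how a 2D array should be split: rows along the 'data' mesh
# axis, columns along the 'model' mesh axis.
spec = PartitionSpec('data', 'model')
x = jax.device_put(np.ones((8, 8), dtype=np.float32),
                   NamedSharding(mesh, spec))
print(x.sharding.spec)  # PartitionSpec('data', 'model')
```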

Scaling Flax NNX Models with JAX SPMD

This document details the process of scaling Flax NNX models using JAX's SPMD capabilities, focusing on integrating JAX's sharding primitives with Flax NNX, particularly the critical sharded initialization pattern.

Recap of Flax NNX and JAX Integration

  • Flax NNX Modules: Unlike pure JAX, NNX modules are stateful Python objects, similar to PyTorch modules, holding parameters and state as attributes.
  • Stateful Nature and JAX Compatibility: A key development in NNX (version 0.11+) is that NNX modules are now native JAX pytrees. This allows JAX transformations to directly understand and operate on them, bridging the gap between NNX's mutability and JAX's functional nature.
  • Reconciling Mutability and Functional Transformations: NNX offers two primary approaches:
    1. Functional API: Explicitly splitting modules into static definitions and dynamic state (a JAX-compatible pytree). This state is then passed through JAX transformations (jax.jit, jax.grad), and the results are merged back.
    2. NNX Wrappers: Using NNX's own wrappers like nnx.jit and nnx.grad. These handle the split-merge process internally, offering a more seamless user experience, though potentially with minor overhead.

Integrating NNX and SPMD for Sharding

The core mechanism for integrating sharding involves attaching PartitionSpec hints directly to model parameters as metadata within NNX modules.

  • Attaching Sharding Hints:
    • nnx.with_metadata: The preferred method involves wrapping the initializer with nnx.with_metadata, passing a sharding argument containing the tuple of partition axis names.
    • Direct Annotation: Alternatively, the sharding argument can sometimes be passed directly to nnx.Param.
  • Metadata Storage: This process stores the axis-name tuple in a sharding attribute on the variable's state. Crucially, this metadata does not shard the parameter at this stage; it serves as a hint for later compilation.

The Critical Sharded Initialization Pattern

A common pitfall when initializing large NNX models is that JAX might attempt to create all parameters on a single default device, leading to "out of memory" (OOM) errors. The solution is to perform initialization and sharding application within a JIT-compiled function.

Step-by-step process for sharded initialization:

  1. Define a JIT-compiled function: Use nnx.jit for convenience to wrap the initialization process.
  2. Instantiate the module: Inside the JIT function, create an instance of the NNX module. At this point, parameters are created but likely reside on device zero.
  3. Extract functional state: Use nnx.state to obtain the module's state as a JAX-compatible pytree.
  4. Extract PartitionSpec objects: Employ nnx.get_partition_spec to traverse the extracted state pytree. This utility identifies and collects all the sharding metadata previously attached to the parameters, producing a parallel pytree of PartitionSpec objects.
  5. Apply sharding constraints: Use jax.lax.with_sharding_constraint. This function takes the original state pytree and the newly created pytree of PartitionSpec objects, and instructs the JAX compiler to ensure that the final output state conforms to these sharding specifications.
  6. Update the module with sharded state: Utilize nnx.update. This function takes the sharded state (which JAX/XLA will ensure is correctly distributed across devices during execution) and merges it back into the original mutable NNX module object.
  7. Return the sharded model: The JIT-compiled function returns the updated module, now containing sharded parameters.

Execution Context:

  • The entire create_sharded_model function (or its equivalent) must be called within a with mesh block or a mesh context. This provides JAX with the necessary information about the hardware topology and device arrangement to fulfill the sharding constraints.
  • The resulting "sharded model" will have its parameters physically distributed across the defined accelerator mesh, effectively avoiding single-device OOM errors.

Using Sharding Rules for Portability

To enhance model portability and adaptability to different hardware configurations, NNX introduces "sharding rules."

  • Logical vs. Physical Axes: Instead of hard-coding physical device names (e.g., 'data', 'model') directly into module definitions, use logical axis names (e.g., 'batch', 'sequence', 'embed', 'hidden').
  • Sharding Rules Mapping: A separate mapping, defined as sharding_rules, translates these logical names to the actual physical mesh axes.
  • Applying Sharding Rules: These rules can be provided during variable initialization or attached later.
  • Benefits: This approach decouples model code from specific hardware layouts, making it easier to adapt to different sharding strategies by simply modifying the sharding_rules without altering the model definition itself.

Example of Sharding Rules:

# Define sharding rules mapping logical axis names to physical mesh axes
sharding_rules = (('batch', 'data'), ('hidden', 'model'))

# Use logical names in the NNX module definition
# ...
    self.w = nnx.Param(
        nnx.with_metadata(
            nnx.initializers.lecun_normal(),
            sharding=('batch', 'hidden'),   # logical axis names
            sharding_rules=sharding_rules,  # logical -> physical mapping
        )(rngs.params(), (1024, 1024))
    )
# ...

In this example, the logical 'batch' axis resolves to the mesh's 'data' axis and 'hidden' resolves to the 'model' axis when the partition specs are computed.
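Conceptually, applying sharding rules is just a lookup from logical to physical axis names. The following is a toy sketch of that resolution step, not NNX's actual implementation:

```python
from jax.sharding import PartitionSpec

# Logical -> physical mapping, stored as a dict for lookup.
rules = {'batch': 'data', 'hidden': 'model'}

def resolve(logical_axes, rules):
    # Map each logical axis to its physical mesh axis; unmapped logical
    # axes fall back to None, meaning the dimension is replicated.
    return PartitionSpec(*(rules.get(axis) for axis in logical_axes))

print(resolve(('batch', 'hidden'), rules))  # PartitionSpec('data', 'model')
```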

Conclusion and Next Steps

The core takeaway is the successful integration of JAX's sharding primitives with Flax NNX, specifically addressing the critical challenge of safely initializing massive sharded models. By annotating parameters with sharding metadata and executing the initialization within a JIT-compiled function under a with mesh context, the single-device OOM problem is solved, resulting in a model with parameters physically distributed across the accelerator mesh.

The next episode will build upon this foundation by demonstrating:

  • The complete distributed training loop.
  • Efficient data loading using libraries like grain.
  • Sharded checkpointing with tools like orbax.
  • A practical, end-to-end example of sharding a GPT-2 style transformer block.
