Sharding Large Models with Tensor Parallelism


In a previous post, we covered how to parallelize training data across multiple GPUs. We saw that if you can't fit an entire batch on a single GPU, you can split the batch across multiple devices. But what happens when the model is so big that it doesn't fit on a single GPU even with a batch size of 1? Or the model fits but trains too slowly? In these cases we can parallelize the model itself, a technique called model parallelism. In this post, we'll mainly cover a type of model parallelism called tensor parallelism, which is commonly used to train large models today.

Pipeline parallelism and its limitations

There are many ways to parallelize a model. One of the most intuitive approaches is called pipelined model parallelism. In this approach, the model is split into multiple stages, and each stage is assigned to a different device. The output of one stage is fed as input to the next stage. For example, if you're pipeline parallelizing an 8-layer MLP across 8 devices, you would place one layer on each device. The output of the first layer would be fed as input to the second layer, and so on, with the output of the last layer being the output of the MLP (see the PipeDream paper for details). While simple, this approach is not very efficient: in both the forward and backward pass, each device must wait for the previous stage to finish before it can start its own work. This idle time, referred to as a bubble, is inefficient because a device sitting in a bubble is doing no useful work.
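
To make the staging concrete, here is a purely illustrative, single-process sketch of an 8-stage pipeline (not from the original post). In a real pipeline-parallel setup each stage's weights would live on a different device, and a stage can only start once the previous stage's output arrives.

import numpy as np

# one weight matrix per stage; in pipeline parallelism each would sit on its own device
stages = [np.random.randn(4, 4) for _ in range(8)]

def pipeline_forward(x, stages):
    # each stage must wait for the previous stage's output;
    # while one "device" works, the others sit idle (the bubble)
    for W in stages:
        x = np.tanh(np.dot(x, W))
    return x

out = pipeline_forward(np.random.randn(2, 4), stages)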


Efficient model sharding with tensor parallelism

How can we parallelize a model more efficiently? Rather than splitting the model layer by layer, we can split it into shards that are distributed across multiple devices and executed in parallel. This approach is called tensor parallelism or model sharding. It is usually more efficient than pipelining but can be more challenging to implement, because it requires careful consideration of how to split the different parts of the model. In this post, we'll cover how to parallelize an MLP using tensor parallelism, following the approach popularized by the Megatron paper.

Let's consider a 2-layer MLP with layer 1 parameterized by matrix A and layer 2 parameterized by matrix B. Now let's say we have a batch of data $X$ and we want to compute the output $Z$. In a typical MLP, we would compute $Y = f(XA)$ and $Z = YB$ where $f$ is the activation function. So how can we parallelize this model?


We can split matrix A into two equal parts column-wise, and matrix B into two equal parts row-wise. We can represent matrix A as $A = \begin{pmatrix} A_1 & A_2 \end{pmatrix}$, and matrix B as $B = \begin{pmatrix} B_1 \\ B_2 \end{pmatrix}$. Here, $A_1$ and $A_2$ are the two equal parts of matrix A, and $B_1$ and $B_2$ are the two equal parts of matrix B. If we substitute these matrices into the computation, we get:

$$Y_i = f(XA_i) \quad \text{for} \quad i \in \{1, 2\}$$

$$Z_i = Y_i B_i \quad \text{for} \quad i \in \{1, 2\}$$

and $Z = Z_1 + Z_2$, where the two computations that were happening on separate devices are joined with an allreduce operation. This relation follows from block matrix multiplication: since $Y$ is split column-wise and $B$ row-wise, $YB = Y_1B_1 + Y_2B_2$, so the two shards can be computed independently in parallel.
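
Written out with the block splits above, the forward pass decomposes as:

$$XA = X\begin{pmatrix} A_1 & A_2 \end{pmatrix} = \begin{pmatrix} XA_1 & XA_2 \end{pmatrix}, \qquad Y = f(XA) = \begin{pmatrix} Y_1 & Y_2 \end{pmatrix},$$

$$Z = YB = \begin{pmatrix} Y_1 & Y_2 \end{pmatrix} \begin{pmatrix} B_1 \\ B_2 \end{pmatrix} = Y_1 B_1 + Y_2 B_2.$$

Because $f$ is applied elementwise, it acts on each column block of $XA$ independently, which is why splitting $A$ column-wise (rather than row-wise) lets the nonlinearity stay local to each device.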

We have just described a tensor parallel forward pass. But what about the backward pass? The backward pass is a bit more complicated, because we need to compute the gradients of the loss with respect to the parameters. In the forward pass, we computed $Y = f(XA)$ and $Z = YB$. So the gradients of the loss with respect to $A_i$ and $B_i$ are:

$$\frac{\partial L}{\partial A_1} = \frac{\partial L}{\partial Y_1} \frac{\partial Y_1}{\partial A_1}$$

$$\frac{\partial L}{\partial A_2} = \frac{\partial L}{\partial Y_2} \frac{\partial Y_2}{\partial A_2}$$

$$\frac{\partial L}{\partial B_1} = \frac{\partial L}{\partial Z_1} \frac{\partial Z_1}{\partial B_1}$$

$$\frac{\partial L}{\partial B_2} = \frac{\partial L}{\partial Z_2} \frac{\partial Z_2}{\partial B_2}$$

where $\frac{\partial L}{\partial Y_1}$ and $\frac{\partial L}{\partial Y_2}$ are the gradients of the loss with respect to the two shards of the first layer's output. We can compute these gradients with the chain rule, for example $\frac{\partial L}{\partial Y_1} = \frac{\partial L}{\partial Z_1} \frac{\partial Z_1}{\partial Y_1}$, and backpropagate the error from the global loss value computed after the forward pass.

Illustrative example of tensor parallelism with Numpy

To make things more concrete, let's see how we can implement a tensor parallel MLP in Numpy. We'll use the same example as before, where we have a 2-layer MLP with layer 1 parameterized by matrix A and layer 2 parameterized by matrix B. We'll use the same notation as before, where $A_1$ and $A_2$ are the two equal parts of matrix A, and $B_1$ and $B_2$ are the two equal parts of matrix B. We'll also use the same activation function $f$ as before.

We'll start by defining a function that splits a matrix into two equal parts column-wise. We'll use the numpy.split function to do this. We'll also define a function that splits a matrix into two equal parts row-wise.

import numpy as np
 
def split_columnwise(A, num_splits):
    return np.split(A, num_splits, axis=1)
 
def split_rowwise(A, num_splits):
    return np.split(A, num_splits, axis=0)

Now, let's define a function that computes the forward pass of a tensor parallel MLP. We'll use the numpy.dot function to compute the matrix multiplications, and the numpy.sum function to compute the sum of the two parts of the output.

def normal_forward_pass(X, A, B, f):
    Y = f(np.dot(X, A))
    Z = np.dot(Y, B)
    return Z
 
def tensor_parallel_forward_pass(X, A, B, f):
    A1, A2 = split_columnwise(A, 2)
    B1, B2 = split_rowwise(B, 2)
    Y1 = f(np.dot(X, A1))
    Y2 = f(np.dot(X, A2))
    Z1 = np.dot(Y1, B1)
    Z2 = np.dot(Y2, B2)
    Z = np.sum([Z1, Z2], axis=0)
    return Z

We can now compute the forward pass of the MLP. We'll use the numpy.random.randn function to generate random matrices.

X = np.random.randn(2, 2)
A = np.random.randn(2, 2)
B = np.random.randn(2, 2)
Z = tensor_parallel_forward_pass(X, A, B, np.tanh)
Z_normal = normal_forward_pass(X, A, B, np.tanh)
print(np.allclose(Z, Z_normal)) # outputs: True

Suppose we are doing regression with this MLP and that the true targets are $[-0.5, 0.5]$ for each example in the batch. We can compute the loss as follows.

target = np.array([[-0.5, 0.5], [-0.5, 0.5]])
# loss function: sum of squared errors between predictions and targets
def L(Z, target):
    return np.sum((Z - target) ** 2)
loss = L(Z, target)

Now we can also compute the backward pass of the MLP with respect to this loss. First, let's forget about tensor parallelism and derive the backpropagation equations in code for the normal MLP. Specifically, we want to know how to adjust the weight matrices AA and BB to reduce the loss, or in other words, how to compute the gradients of the loss with respect to AA and BB.

def normal_backward_pass(X, A, B, f, target):
    # recompute forward pass to get activations
    Y = f(np.dot(X, A))
    Z = np.dot(Y, B)
    # compute gradients
    # gradient of loss with respect to Z: L = sum((Z - target) ** 2), so dLdZ = 2 * (Z - target)
    dLdZ = 2 * (Z - target)
    # gradient of loss with respect to B via chain rule
    # dLdB = dLdZ * dZdB = np.dot(Y.T, dLdZ)
    dLdB = np.dot(Y.T, dLdZ)
    # gradient of loss with respect to A via chain rule
    # dLdY = dLdZ * dZdY = np.dot(dLdZ, B.T)
    dLdY = np.dot(dLdZ, B.T)
    # dLdA = dLdY * dYdA = np.dot(X.T, dLdY * (1 - Y ** 2))
    # (1 - Y ** 2) is the derivative of the activation function f = np.tanh
    # derivative of tanh is 1 - tanh ** 2
    dLdA = np.dot(X.T, dLdY * (1 - Y ** 2))
    return dLdA, dLdB, Z

In the above code, we're using the chain rule to compute the gradients $\frac{\partial L}{\partial A}$ and $\frac{\partial L}{\partial B}$, where $L$ is the loss. Since the output of the second layer is used directly for prediction, the gradient $\frac{\partial L}{\partial B}$ is straightforward to compute. The tricky part is the $\frac{\partial L}{\partial A}$ term, because it requires differentiating through the activation function $f$. In this case, we're using the $\tanh$ activation function, $f(x) = \tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$, whose derivative is $1 - \tanh^2(x) = 1 - f^2(x)$.

To see how we arrived at the equation for $\frac{\partial L}{\partial A}$, let's step through the chain rule mathematically. First, we have $\frac{\partial L}{\partial A} = \frac{\partial L}{\partial Y} \frac{\partial Y}{\partial A}$. The term $\frac{\partial L}{\partial Y}$ is straightforward to compute, so we'll focus on $\frac{\partial Y}{\partial A}$. Let's express this derivative in terms of the output of the activation function $Y$ and the input to the activation function $X$. We have $Y = f(XA)$, so $\frac{\partial Y}{\partial A} = f'(XA)X$. Substituting $Y$ for $f(XA)$, we get $\frac{\partial Y}{\partial A} = f'(XA)X = (1 - f^2(XA))X = (1 - Y^2)X$. This is the same expression we used in the code above.
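
As a quick sanity check on this derivation, we can compare the analytic gradient against a finite-difference estimate (a minimal sketch, assuming the X, A, B, target, L, normal_forward_pass, and normal_backward_pass defined above):

eps = 1e-6
dLdA, dLdB, _ = normal_backward_pass(X, A, B, np.tanh, target)
num_dLdA = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        A_plus, A_minus = A.copy(), A.copy()
        A_plus[i, j] += eps
        A_minus[i, j] -= eps
        # central difference approximation of dL/dA[i, j]
        num_dLdA[i, j] = (L(normal_forward_pass(X, A_plus, B, np.tanh), target)
                          - L(normal_forward_pass(X, A_minus, B, np.tanh), target)) / (2 * eps)
print(np.allclose(dLdA, num_dLdA, atol=1e-5))  # outputs: True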

def tensor_parallel_backward_pass(X, A, B, f, target):
    # recompute forward pass to get activations
    A1, A2 = split_columnwise(A, 2)
    B1, B2 = split_rowwise(B, 2)
    Y1 = f(np.dot(X, A1))
    Y2 = f(np.dot(X, A2))
    Z1 = np.dot(Y1, B1)
    Z2 = np.dot(Y2, B2)
    Z = Z1 + Z2

    # compute gradients, same logic as in normal_backward_pass
    # this step is not parallelized: it needs the full Z = Z1 + Z2,
    # and since Z = Z1 + Z2 we have dLdZ1 = dLdZ2 = dLdZ
    dLdZ = 2 * (Z - target)

    dLdZ1 = dLdZ
    dLdZ2 = dLdZ

    dLdB1 = np.dot(Y1.T, dLdZ1)
    dLdB2 = np.dot(Y2.T, dLdZ2)

    dLdY1 = np.dot(dLdZ1, B1.T)
    dLdY2 = np.dot(dLdZ2, B2.T)

    dLdA1 = np.dot(X.T, dLdY1 * (1 - Y1 ** 2))
    dLdA2 = np.dot(X.T, dLdY2 * (1 - Y2 ** 2))

    # to sense check our results, reassemble the full gradients;
    # A was split column-wise and B row-wise, so the gradient shards
    # are concatenated on the corresponding axes
    dLdB = np.concatenate([dLdB1, dLdB2], axis=0)
    dLdA = np.concatenate([dLdA1, dLdA2], axis=1)

    return dLdA, dLdB
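
To confirm the two backward passes agree, we can compare their outputs directly (reusing the X, A, B, and target defined earlier):

dLdA_normal, dLdB_normal, _ = normal_backward_pass(X, A, B, np.tanh, target)
dLdA_tp, dLdB_tp = tensor_parallel_backward_pass(X, A, B, np.tanh, target)
print(np.allclose(dLdA_normal, dLdA_tp))  # outputs: True
print(np.allclose(dLdB_normal, dLdB_tp))  # outputs: True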

If you run the code above, you'll see that the normal and tensor parallel implementations match. From a quick glance at the code, we can spot a few subtle details that need to be carefully thought through. First, note that since $Z = Z_1 + Z_2$, the gradients $\frac{\partial L}{\partial Z} = \frac{\partial L}{\partial Z_1} = \frac{\partial L}{\partial Z_2}$ are equal. For this reason, this gradient needs to be computed from the full $Z$, i.e. after the allreduce of $Z_1$ and $Z_2$. After this, we can compute all the gradients for $A$ and $B$ in parallel. Another detail is that to sense check our results we concatenate the gradient shards of $A$ and $B$ on different axes. This is because $A$ was split column-wise while $B$ was split row-wise. From this example, we can see that tensor parallelism needs to be implemented carefully. It's very easy to introduce silent bugs, which is why tensor parallelism is generally more difficult to implement than data parallelism.

Note that this implementation in numpy is only illustrative. The computation in the code I showed is still done on a single device, but the tensor parallel implementations show the logic that needs to be implemented in a distributed setting. Modern autodiff software like PyTorch and JAX takes care of the details of distributing your computation across multiple devices, but it's important to know the underlying logic since you often still need to specify which parts of your computation should be done in parallel and on which devices.

Finally, how you parallelize tensors is highly dependent on the device topology: you would parallelize differently on an 8x2 grid than on a 4x4 grid, for instance. In large scale training, tensor parallelism is often combined with data parallelism. For instance, with an MxN device grid you might parallelize the data across the rows and the tensors across the columns.
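
As an illustration of such a 2D layout (this goes beyond the pmap examples in this post), here is a minimal sketch using JAX's sharding API, assuming 8 available devices arranged as a hypothetical 4x2 grid, with data parallelism along the rows and tensor parallelism along the columns:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# assumes exactly 8 devices; arrange them as a 4x2 grid
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=('data', 'model'))

X = jnp.ones((16, 8))  # batch, split along the 'data' axis
A = jnp.ones((8, 8))   # layer 1, split column-wise along the 'model' axis
B = jnp.ones((8, 8))   # layer 2, split row-wise along the 'model' axis

X = jax.device_put(X, NamedSharding(mesh, PartitionSpec('data', None)))
A = jax.device_put(A, NamedSharding(mesh, PartitionSpec(None, 'model')))
B = jax.device_put(B, NamedSharding(mesh, PartitionSpec('model', None)))

# jit propagates the shardings and inserts the necessary collectives
# (including the allreduce over the 'model' axis) automatically
Z = jax.jit(lambda x, a, b: jnp.dot(jnp.tanh(jnp.dot(x, a)), b))(X, A, B)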

A tensor parallel MLP in JAX

Let's see how we can implement a tensor parallel MLP in JAX. We'll use the same example as before, where we have a 2-layer MLP with layer 1 parameterized by matrix A and layer 2 parameterized by matrix B. We'll use the same notation as before, where $A_1$ and $A_2$ are the two equal parts of matrix A, and $B_1$ and $B_2$ are the two equal parts of matrix B. We'll also use the same activation function $f$ as before.

We'll start by defining a function that splits a matrix into two equal parts column-wise. We'll use the jax.numpy.split function to do this. We'll also define a function that splits a matrix into two equal parts row-wise.

import jax
import jax.numpy as jnp
 
def split_columnwise(A, num_splits):
    return jnp.split(A, num_splits, axis=1)
 
def split_rowwise(A, num_splits):
    return jnp.split(A, num_splits, axis=0)

Now, let's define a function that computes the forward pass of a tensor parallel MLP. Our function will take as input the data X and a shard of each weight matrix (e.g. A_1 and B_1), and we'll use the jax.pmap function to split the computation across multiple devices.

Recall how we parallelized data in our previous post.

import numpy as np
import jax
import jax.numpy as jnp
 
def linear_layer(x, w):
    return jnp.dot(x, w)
 
n = 16
d = 3
devices = 8
 
xs = jnp.array(np.random.rand(n, d))
ws = jnp.array(np.random.rand(d,))
 
# split the data across devices and replicate the weights on every device
x_parts = np.stack(jnp.split(xs, devices))
w_parts = jax.tree_util.tree_map(lambda x: np.stack([x for _ in range(devices)]), ws)
 
out = jax.pmap(linear_layer)(x_parts, w_parts)
print(out.shape)  # (8, 2): out has shape (n_devices, n_data // n_devices)

Here, the weights were replicated across devices while data was split and parallelized. Now we want to do the opposite - we want to replicate the data and parallelize the model.

import numpy as np
import jax
import jax.numpy as jnp
 
def split_columnwise(A, num_splits):
    return jnp.split(A, num_splits, axis=1)
 
def split_rowwise(A, num_splits):
    return jnp.split(A, num_splits, axis=0)
 
def forward(x, a, b):
    y = jnp.tanh(jnp.dot(x, a))
    z = jnp.dot(y, b)
    return z
 
devices = 2  # split the 2x2 matrices from the numpy example across 2 devices
 
# X, A, B are the same matrices as in the numpy example;
# shard the weights across devices and replicate the data on every device
A_parts = np.stack(split_columnwise(A, devices))
B_parts = np.stack(split_rowwise(B, devices))
X_parts = jax.tree_util.tree_map(lambda x: np.stack([x for _ in range(devices)]), X)
 
out = jax.pmap(forward)(X_parts, A_parts, B_parts)
 
z = jnp.sum(out, axis=0)

Given the same inputs for X, A, and B as in the numpy example, we get the same output, except this time JAX has parallelized the computation across two devices. We can also check that the gradients are correct.
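
One way to check the gradients is to let JAX compute them with jax.grad and compare against our manual numpy backward pass (a sketch, assuming the X, A, B, and target from the numpy example; JAX defaults to float32, so we compare with a loose tolerance):

def loss_fn(A, B, X, target):
    Y = jnp.tanh(jnp.dot(X, A))
    Z = jnp.dot(Y, B)
    return jnp.sum((Z - target) ** 2)

# differentiate the loss with respect to A and B
dLdA_jax, dLdB_jax = jax.grad(loss_fn, argnums=(0, 1))(jnp.array(A), jnp.array(B), jnp.array(X), jnp.array(target))
dLdA_np, dLdB_np, _ = normal_backward_pass(X, A, B, np.tanh, target)
print(np.allclose(dLdA_jax, dLdA_np, atol=1e-4))  # outputs: True
print(np.allclose(dLdB_jax, dLdB_np, atol=1e-4))  # outputs: True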

Cross device communication

In the numpy and JAX examples, we parallelized the MLP computation across two devices. However, we only addressed the parallelization logic and did not discuss how data is centralized and distributed across devices. For example, at the end of the forward pass we need to sum the outputs of the two devices, $Z = Z_1 + Z_2$, but $Z_1$ and $Z_2$ live on different devices. How do we perform the summation?

In distributed computing, this can be done with an AllReduce operation, which performs a reduction (e.g. a summation) across devices and distributes the result back to all of them. In JAX, we can use the jax.lax.psum function to perform an AllReduce.
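
For example, we can move the summation inside the pmapped function so that it runs as a cross-device collective instead of a jnp.sum over the stacked outputs (a sketch reusing X_parts, A_parts, and B_parts from above; 'devices' is just an axis name we choose here):

def forward_allreduce(x, a, b):
    y = jnp.tanh(jnp.dot(x, a))
    z_partial = jnp.dot(y, b)
    # allreduce: sum the partial outputs across all devices on the 'devices' axis
    return jax.lax.psum(z_partial, axis_name='devices')

z_replicated = jax.pmap(forward_allreduce, axis_name='devices')(X_parts, A_parts, B_parts)
z = z_replicated[0]  # every device now holds the full Z; take device 0's copy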

Besides the AllReduce, a Gather operation is needed whenever we want to reassemble sharded pieces on a single device, for example the concatenated gradients dLdA and dLdB that we used to sense check our results. Finally, to distribute data or gradient shards back to the devices, we need a Scatter operation.
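
In JAX, such a gather can be expressed with jax.lax.all_gather inside a pmapped function (a minimal sketch reusing X_parts and A_parts from above):

def hidden_shard(x, a):
    y_shard = jnp.tanh(jnp.dot(x, a))  # this device's column shard of Y
    # gather every device's shard; the stacked result is replicated on all devices
    return jax.lax.all_gather(y_shard, axis_name='devices')

gathered = jax.pmap(hidden_shard, axis_name='devices')(X_parts, A_parts)
print(gathered.shape)  # (n_devices, n_devices, batch_size, hidden_dim // n_devices)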

Megatron

What we've outlined in this post is a form of tensor parallelism that was introduced in the Megatron paper by NVIDIA. Unlike pipeline parallelism, Megatron-style tensor parallelism keeps idle time on the devices to a minimum, which is why it is commonly used for large scale training of language models.