Lightning Thunder makes PyTorch programs faster, both on single accelerators and in distributed settings.
Thunder aims to be usable, understandable, and extensible. It can achieve significant speedups over standard PyTorch eager code through the compounding effects of its optimizations and the use of best-in-class executors.
This extension directory shows how Thunder can be used with LitGPT.
Warning
This document is an early-access development version that is currently intended for internal use only. We recommend that users check out the Lightning Thunder project directly, which provides more up-to-date usage information.
To try Lightning Thunder with your model, simply thunder.jit() it:
from litgpt import GPT
import thunder
import torch
# Use only two layers to keep the traces shorter for the demonstration
model = GPT.from_name("Llama-2-7b-hf", n_layer=2).cuda()
model = thunder.jit(model)
x = torch.randint(model.max_seq_length, (2, 5), device="cuda")
y = model(x) # forward, this may take a bit
This will require some compilation time on the first forward call.
The JIT will acquire a Python program (what we call a "trace") from the Python program it was given (GPT, a torch.nn.Module in this example).
This process targets PyTorch operators (like Tensor.view(), +, or torch.nn.functional.scaled_dot_product_attention()) and, optionally, custom operators (more about that later).
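As a quick sanity check, the jitted module can be compared against its eager counterpart. This is a minimal sketch, continuing from the snippet above and assuming that thunder.jit wraps the module without copying its parameters:
# Minimal sanity-check sketch: the jitted module shares the eager module's
# parameters, so outputs should match up to small numerical differences
# introduced by fused kernels.
eager_model = GPT.from_name("Llama-2-7b-hf", n_layer=2).cuda()
jit_model = thunder.jit(eager_model)
x = torch.randint(eager_model.max_seq_length, (2, 5), device="cuda")
torch.testing.assert_close(jit_model(x), eager_model(x), rtol=1e-3, atol=1e-3)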
We can visualize the thunder trace generated under the hood:
forward_trace = thunder.last_traces(model)[-1].python()
print(forward_trace)
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
# args: "Collection"
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, \
t18, t19, = args
del args
t24 = torch.nn.functional.embedding(t0, t19, None, None, 2.0, False, False) # t24: "cuda:0 f32[2, 5, 4096]"
t20 = torch_slice_prim_impl(t1, [0, 0], [5, 128], [1, 1]) # t20: "cuda:0 f32[5, 128]"
t21 = torch_slice_prim_impl(t2, [0, 0], [5, 128], [1, 1]) # t21: "cuda:0 f32[5, 128]"
t200 = torch.unsqueeze(t11, 0) # t200: "cuda:0 f32[1, 4096]"
t201 = torch.unsqueeze(t200, 1) # t201: "cuda:0 f32[1, 1, 4096]"
del t200
t33 = Tensor.expand(t201, (2, 5, 4096)) # t33: "cuda:0 f32[2, 5, 4096]"
del t201
t229 = torch.unsqueeze(t13, 0) # t229: "cuda:0 f32[1, 4096]"
t230 = torch.unsqueeze(t229, 1) # t230: "cuda:0 f32[1, 1, 4096]"
del t229
t84 = Tensor.expand(t230, (2, 5, 4096)) # t84: "cuda:0 f32[2, 5, 4096]"
del t230
t232 = torch.unsqueeze(t12, 0) # t232: "cuda:0 f32[1, 4096]"
t233 = torch.unsqueeze(t232, 1) # t233: "cuda:0 f32[1, 1, 4096]"
del t232
t104 = Tensor.expand(t233, (2, 5, 4096)) # t104: "cuda:0 f32[2, 5, 4096]"
del t233
t253 = torch.unsqueeze(t14, 0) # t253: "cuda:0 f32[1, 4096]"
t254 = torch.unsqueeze(t253, 1) # t254: "cuda:0 f32[1, 1, 4096]"
del t253
t155 = Tensor.expand(t254, (2, 5, 4096)) # t155: "cuda:0 f32[2, 5, 4096]"
del t254
t256 = torch.unsqueeze(t10, 0) # t256: "cuda:0 f32[1, 4096]"
t257 = torch.unsqueeze(t256, 1) # t257: "cuda:0 f32[1, 1, 4096]"
del t256
t175 = Tensor.expand(t257, (2, 5, 4096)) # t175: "cuda:0 f32[2, 5, 4096]"
del t257
t221 = torch.unsqueeze(t20, 0) # t221: "cuda:0 f32[1, 5, 128]"
del t20
t222 = torch.unsqueeze(t221, 1) # t222: "cuda:0 f32[1, 1, 5, 128]"
del t221
t49 = Tensor.expand(t222, (2, 32, 5, 128)) # t49: "cuda:0 f32[2, 32, 5, 128]"
del t222
t224 = torch.unsqueeze(t21, 0) # t224: "cuda:0 f32[1, 5, 128]"
del t21
t225 = torch.unsqueeze(t224, 1) # t225: "cuda:0 f32[1, 1, 5, 128]"
del t224
t51 = Tensor.expand(t225, (2, 32, 5, 128)) # t51: "cuda:0 f32[2, 32, 5, 128]"
del t225
[t30, t34] = nvFusion0(t24, t33)
t35 = torch.nn.functional.linear(t34, t3, None) # t35: "cuda:0 f32[2, 5, 12288]"
t36 = torch.reshape(t35, (2, 5, 32, 3, 128)) # t36: "cuda:0 f32[2, 5, 32, 3, 128]"
del t35
t37 = torch.permute(t36, (0, 2, 3, 1, 4)) # t37: "cuda:0 f32[2, 32, 3, 5, 128]"
del t36
(t38, t39, t40) = torch.split(t37, (1, 1, 1), 2)
del t37
t41 = torch.reshape(t38, (2, 32, 5, 128)) # t41: "cuda:0 f32[2, 32, 5, 128]"
del t38
t42 = torch.reshape(t39, (2, 32, 5, 128)) # t42: "cuda:0 f32[2, 32, 5, 128]"
del t39
t43 = torch.reshape(t40, (2, 32, 5, 128)) # t43: "cuda:0 f32[2, 32, 5, 128]"
del t40
t44 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t44: "cuda:0 f32[2, 32, 5, 128]"
t54 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t54: "cuda:0 f32[2, 32, 5, 128]"
t64 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t64: "cuda:0 f32[2, 32, 5, 0]"
del t41
t66 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t66: "cuda:0 f32[2, 32, 5, 0]"
del t42
t46 = torch_slice_prim_impl(t44, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t46: "cuda:0 f32[2, 32, 5, 64]"
t45 = torch_slice_prim_impl(t44, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t45: "cuda:0 f32[2, 32, 5, 64]"
t55 = torch_slice_prim_impl(t54, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t55: "cuda:0 f32[2, 32, 5, 64]"
t56 = torch_slice_prim_impl(t54, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t56: "cuda:0 f32[2, 32, 5, 64]"
[t47, t57] = nvFusion1(t46, t56)
del t46, t56
t48 = torch.cat((t47, t45), -1) # t48: "cuda:0 f32[2, 32, 5, 128]"
del t47, t45
t58 = torch.cat((t57, t55), -1) # t58: "cuda:0 f32[2, 32, 5, 128]"
del t57, t55
[t53, t63] = nvFusion2(t44, t48, t49, t51, t54, t58)
del t44, t48, t54, t58
t65 = torch.cat((t53, t64), -1) # t65: "cuda:0 f32[2, 32, 5, 128]"
del t53, t64
t67 = torch.cat((t63, t66), -1) # t67: "cuda:0 f32[2, 32, 5, 128]"
del t63, t66
(t68, t69, t70, t71) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t65, t67, t43, None, 0.0, True, 0.08838834764831843)
t72 = torch.permute(t68, (0, 2, 1, 3)) # t72: "cuda:0 f32[2, 5, 32, 128]"
t73 = torch.reshape(t72, (2, 5, 4096)) # t73: "cuda:0 f32[2, 5, 4096]"
del t72
t74 = torch.nn.functional.linear(t73, t15, None) # t74: "cuda:0 f32[2, 5, 4096]"
[t75, t81, t85] = nvFusion3(t24, t74, t84)
del t74
t86 = torch.nn.functional.linear(t85, t5, None) # t86: "cuda:0 f32[2, 5, 11008]"
t87 = torch.nn.functional.linear(t85, t7, None) # t87: "cuda:0 f32[2, 5, 11008]"
[t93] = nvFusion4(t86, t87)
t94 = torch.nn.functional.linear(t93, t16, None) # t94: "cuda:0 f32[2, 5, 4096]"
[t101, t105, t95] = nvFusion5(t104, t75, t94)
del t94
t106 = torch.nn.functional.linear(t105, t4, None) # t106: "cuda:0 f32[2, 5, 12288]"
t107 = torch.reshape(t106, (2, 5, 32, 3, 128)) # t107: "cuda:0 f32[2, 5, 32, 3, 128]"
del t106
t108 = torch.permute(t107, (0, 2, 3, 1, 4)) # t108: "cuda:0 f32[2, 32, 3, 5, 128]"
del t107
(t109, t110, t111) = torch.split(t108, (1, 1, 1), 2)
del t108
t112 = torch.reshape(t109, (2, 32, 5, 128)) # t112: "cuda:0 f32[2, 32, 5, 128]"
del t109
t113 = torch.reshape(t110, (2, 32, 5, 128)) # t113: "cuda:0 f32[2, 32, 5, 128]"
del t110
t114 = torch.reshape(t111, (2, 32, 5, 128)) # t114: "cuda:0 f32[2, 32, 5, 128]"
del t111
t135 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t135: "cuda:0 f32[2, 32, 5, 0]"
t137 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t137: "cuda:0 f32[2, 32, 5, 0]"
t115 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t115: "cuda:0 f32[2, 32, 5, 128]"
del t112
t125 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t125: "cuda:0 f32[2, 32, 5, 128]"
del t113
t116 = torch_slice_prim_impl(t115, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t116: "cuda:0 f32[2, 32, 5, 64]"
t117 = torch_slice_prim_impl(t115, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t117: "cuda:0 f32[2, 32, 5, 64]"
t127 = torch_slice_prim_impl(t125, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t127: "cuda:0 f32[2, 32, 5, 64]"
t126 = torch_slice_prim_impl(t125, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t126: "cuda:0 f32[2, 32, 5, 64]"
[t118, t128] = nvFusion6(t117, t127)
del t117, t127
t129 = torch.cat((t128, t126), -1) # t129: "cuda:0 f32[2, 32, 5, 128]"
del t128, t126
t119 = torch.cat((t118, t116), -1) # t119: "cuda:0 f32[2, 32, 5, 128]"
del t118, t116
[t124, t134] = nvFusion7(t115, t119, t125, t129, t49, t51)
del t115, t119, t125, t129
t136 = torch.cat((t124, t135), -1) # t136: "cuda:0 f32[2, 32, 5, 128]"
del t124, t135
t138 = torch.cat((t134, t137), -1) # t138: "cuda:0 f32[2, 32, 5, 128]"
del t134, t137
(t139, t140, t141, t142) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t136, t138, t114, None, 0.0, True, 0.08838834764831843)
t143 = torch.permute(t139, (0, 2, 1, 3)) # t143: "cuda:0 f32[2, 5, 32, 128]"
t144 = torch.reshape(t143, (2, 5, 4096)) # t144: "cuda:0 f32[2, 5, 4096]"
del t143
t145 = torch.nn.functional.linear(t144, t17, None) # t145: "cuda:0 f32[2, 5, 4096]"
[t146, t152, t156] = nvFusion8(t145, t155, t95)
del t145
t158 = torch.nn.functional.linear(t156, t8, None) # t158: "cuda:0 f32[2, 5, 11008]"
t157 = torch.nn.functional.linear(t156, t6, None) # t157: "cuda:0 f32[2, 5, 11008]"
[t164] = nvFusion9(t157, t158)
t165 = torch.nn.functional.linear(t164, t18, None) # t165: "cuda:0 f32[2, 5, 4096]"
[t166, t172, t176] = nvFusion10(t146, t165, t175)
del t165
t177 = torch.nn.functional.linear(t176, t9, None) # t177: "cuda:0 f32[2, 5, 32000]"
return {'output': t177, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19], 'flat_output': (t177,)}, ((t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95), (False, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2))
This is a straight-lined version of GPT.forward that has been optimized. Since it's running on CUDA, the NvFuser executor has created regions (look for "nvFusion") that fuse multiple operators together; a quick way to count them is sketched after the list below.
Operator fusion is very desirable on modern hardware and helps in overhead-bound or memory-bound settings by:
- Launching fewer kernels, thus reducing kernel launch overhead.
- Reducing the number of memory accesses, since intermediate values are reused within a fused kernel instead of round-tripping through global memory.
- Minimizing host-device communication.
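A rough way to see how much was fused is to count the fusion-region calls in the printed trace. This is just a sketch that inspects the trace's text, not an official API:
# Rough sketch: every "nvFusion" call in the final forward trace corresponds to
# one generated kernel that replaces several PyTorch-level operations.
trace_str = str(thunder.last_traces(model)[-1])
num_fusions = trace_str.count("= nvFusion")
print(f"nvFuser created {num_fusions} fusion regions in the forward trace")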
Thunder also uses a multi-level intermediate representation. If we let it print all levels:
forward_trace = thunder.last_traces(model)[-1]
print(forward_trace)
We can see, as comments, the primitives that compose the fusion regions. For instance, this is the region associated with the RMSNorm implementation:
[t146, t152, t156] = nvFusion8(t145, t155, t95)
# t146 = prims.add(t145, t95) # t146: "cuda:0 f32[2, 5, 4096]"
# t147 = prims.mul(t146, t146) # t147: "cuda:0 f32[2, 5, 4096]"
# t148 = prims.sum(t147, (2,)) # t148: "cuda:0 f32[2, 5]"
# t149 = prims.broadcast_in_dim(t148, [2, 5, 1], [0, 1]) # t149: "cuda:0 f32[2, 5, 1]"
# t150 = prims.div(t149, 4096.0) # t150: "cuda:0 f32[2, 5, 1]"
# t151 = prims.add(t150, 1e-05) # t151: "cuda:0 f32[2, 5, 1]"
# t152 = prims.rsqrt(t151) # t152: "cuda:0 f32[2, 5, 1]"
# t153 = prims.broadcast_in_dim(t152, (2, 5, 4096), (0, 1, 2)) # t153: "cuda:0 f32[2, 5, 4096]"
# t154 = prims.mul(t146, t153) # t154: "cuda:0 f32[2, 5, 4096]"
# t156 = prims.mul(t154, t155) # t156: "cuda:0 f32[2, 5, 4096]"
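Written out with regular PyTorch ops, this fused region computes an RMSNorm of the residual sum, roughly as in the sketch below (illustrative only; the actual fusion also returns intermediates such as t146 and t152 that are saved for the backward pass):
# Illustrative sketch of what nvFusion8 computes, with eps = 1e-5 as in the trace.
def rmsnorm_of_residual(t145, t95, t155, eps=1e-5):
    t146 = t145 + t95                                    # residual connection
    mean_sq = (t146 * t146).mean(dim=-1, keepdim=True)   # mean of squares
    t152 = torch.rsqrt(mean_sq + eps)                    # reciprocal RMS
    return t146 * t152 * t155                            # normalize, then scale by the weight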
Similarly, we can visualize the backward trace:
backward_trace = thunder.last_backward_traces(model)[-1].python()
print(backward_trace)
@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, C1, = saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t178, = cotangents
clear_collection(cotangents)
del cotangents
t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, \
t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, \
t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, \
t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95, = C0
clear_collection(C0)
del C0
b1, b2, b41, b91, f101, f106, f40, f42, f51, f56, f6, f90, f92, i0, i23, i73, \
= C1
clear_collection(C1)
del C1
t639 = torch.reshape(t178, (-1, 32000)) # t639: "cuda:0 f32[10, 32000]"
del t178
t643 = torch.permute(t639, (1, 0)) # t643: "cuda:0 f32[32000, 10]"
t644 = torch.reshape(t176, (-1, 4096)) # t644: "cuda:0 f32[10, 4096]"
del t176
t669 = torch.reshape(t164, (-1, 11008)) # t669: "cuda:0 f32[10, 11008]"
del t164
t686 = torch.reshape(t156, (-1, 4096)) # t686: "cuda:0 f32[10, 4096]"
del t156
t720 = torch.reshape(t144, (-1, 4096)) # t720: "cuda:0 f32[10, 4096]"
del t144
t776 = torch.reshape(t105, (-1, 4096)) # t776: "cuda:0 f32[10, 4096]"
del t105
t802 = torch.reshape(t93, (-1, 11008)) # t802: "cuda:0 f32[10, 11008]"
del t93
t819 = torch.reshape(t85, (-1, 4096)) # t819: "cuda:0 f32[10, 4096]"
del t85
t853 = torch.reshape(t73, (-1, 4096)) # t853: "cuda:0 f32[10, 4096]"
del t73
t911 = torch.reshape(t34, (-1, 4096)) # t911: "cuda:0 f32[10, 4096]"
del t34
t640 = torch.matmul(t639, t9) # t640: "cuda:0 f32[10, 4096]"
del t639, t9
t645 = torch.matmul(t643, t644) # t645: "cuda:0 f32[32000, 4096]"
del t643, t644
t641 = torch.reshape(t640, (2, 5, 4096)) # t641: "cuda:0 f32[2, 5, 4096]"
del t640
[t648, t663] = nvFusion0(f106, t166, t172, t175, t641)
del f106, t166, t172, t175, t641
t664 = torch.reshape(t663, (-1, 4096)) # t664: "cuda:0 f32[10, 4096]"
t668 = torch.permute(t664, (1, 0)) # t668: "cuda:0 f32[4096, 10]"
t665 = torch.matmul(t664, t18) # t665: "cuda:0 f32[10, 11008]"
del t664, t18
t670 = torch.matmul(t668, t669) # t670: "cuda:0 f32[4096, 11008]"
del t668, t669
t666 = torch.reshape(t665, (2, 5, 11008)) # t666: "cuda:0 f32[2, 5, 11008]"
del t665
[t672, t680] = nvFusion1(t157, t158, t666)
del t157, t158, t666
t681 = torch.reshape(t672, (-1, 11008)) # t681: "cuda:0 f32[10, 11008]"
del t672
t685 = torch.permute(t681, (1, 0)) # t685: "cuda:0 f32[11008, 10]"
t688 = torch.reshape(t680, (-1, 11008)) # t688: "cuda:0 f32[10, 11008]"
del t680
t692 = torch.permute(t688, (1, 0)) # t692: "cuda:0 f32[11008, 10]"
t689 = torch.matmul(t688, t6) # t689: "cuda:0 f32[10, 4096]"
del t688, t6
t682 = torch.matmul(t681, t8) # t682: "cuda:0 f32[10, 4096]"
del t681, t8
t694 = torch.matmul(t692, t686) # t694: "cuda:0 f32[11008, 4096]"
del t692
t687 = torch.matmul(t685, t686) # t687: "cuda:0 f32[11008, 4096]"
del t685, t686
t683 = torch.reshape(t682, (2, 5, 4096)) # t683: "cuda:0 f32[2, 5, 4096]"
del t682
t690 = torch.reshape(t689, (2, 5, 4096)) # t690: "cuda:0 f32[2, 5, 4096]"
del t689
[t698, t714] = nvFusion2(f101, t146, t152, t155, t663, t683, t690)
del f101, t146, t152, t155, t663, t683, t690
t715 = torch.reshape(t714, (-1, 4096)) # t715: "cuda:0 f32[10, 4096]"
t719 = torch.permute(t715, (1, 0)) # t719: "cuda:0 f32[4096, 10]"
t716 = torch.matmul(t715, t17) # t716: "cuda:0 f32[10, 4096]"
del t715, t17
t721 = torch.matmul(t719, t720) # t721: "cuda:0 f32[4096, 4096]"
del t719, t720
t717 = torch.reshape(t716, (2, 5, 4096)) # t717: "cuda:0 f32[2, 5, 4096]"
del t716
t722 = torch.reshape(t717, (2, 5, 32, 128)) # t722: "cuda:0 f32[2, 5, 32, 128]"
del t717
t723 = torch.permute(t722, (0, 2, 1, 3)) # t723: "cuda:0 f32[2, 32, 5, 128]"
del t722
(t724, t725, t726, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t723, t136, t138, t114, None, t139, t140, t141, t142, f90, b91, scale=f92)
del t723, t136, t138, t114, t139, t140, t141, t142, f90, b91, f92
t765 = torch.reshape(t726, (2, 32, 1, 5, 128)) # t765: "cuda:0 f32[2, 32, 1, 5, 128]"
del t726
t727 = torch_slice_prim_impl(t725, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t727: "cuda:0 f32[2, 32, 5, 128]"
del t725
t730 = torch_slice_prim_impl(t724, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t730: "cuda:0 f32[2, 32, 5, 128]"
del t724
[t747, t764] = nvFusion3(t49, t51, t727, t730)
del t727, t730
t766 = torch.reshape(t747, (2, 32, 1, 5, 128)) # t766: "cuda:0 f32[2, 32, 1, 5, 128]"
del t747
t767 = torch.reshape(t764, (2, 32, 1, 5, 128)) # t767: "cuda:0 f32[2, 32, 1, 5, 128]"
del t764
t768 = torch.cat((t767, t766, t765), i73) # t768: "cuda:0 f32[2, 32, 3, 5, 128]"
del t767, t766, t765, i73
t769 = torch.permute(t768, (0, 3, 1, 2, 4)) # t769: "cuda:0 f32[2, 5, 32, 3, 128]"
del t768
t770 = torch.reshape(t769, (2, 5, 12288)) # t770: "cuda:0 f32[2, 5, 12288]"
del t769
t771 = torch.reshape(t770, (-1, 12288)) # t771: "cuda:0 f32[10, 12288]"
del t770
t775 = torch.permute(t771, (1, 0)) # t775: "cuda:0 f32[12288, 10]"
t777 = torch.matmul(t775, t776) # t777: "cuda:0 f32[12288, 4096]"
del t775, t776
t772 = torch.matmul(t771, t4) # t772: "cuda:0 f32[10, 4096]"
del t771, t4
t773 = torch.reshape(t772, (2, 5, 4096)) # t773: "cuda:0 f32[2, 5, 4096]"
del t772
[t780, t796] = nvFusion4(f56, t101, t104, t714, t773, t95)
del f56, t101, t104, t714, t773, t95
t797 = torch.reshape(t796, (-1, 4096)) # t797: "cuda:0 f32[10, 4096]"
t801 = torch.permute(t797, (1, 0)) # t801: "cuda:0 f32[4096, 10]"
t798 = torch.matmul(t797, t16) # t798: "cuda:0 f32[10, 11008]"
del t797, t16
t803 = torch.matmul(t801, t802) # t803: "cuda:0 f32[4096, 11008]"
del t801, t802
t799 = torch.reshape(t798, (2, 5, 11008)) # t799: "cuda:0 f32[2, 5, 11008]"
del t798
[t805, t813] = nvFusion5(t799, t86, t87)
del t799, t86, t87
t814 = torch.reshape(t805, (-1, 11008)) # t814: "cuda:0 f32[10, 11008]"
del t805
t818 = torch.permute(t814, (1, 0)) # t818: "cuda:0 f32[11008, 10]"
t821 = torch.reshape(t813, (-1, 11008)) # t821: "cuda:0 f32[10, 11008]"
del t813
t825 = torch.permute(t821, (1, 0)) # t825: "cuda:0 f32[11008, 10]"
t822 = torch.matmul(t821, t5) # t822: "cuda:0 f32[10, 4096]"
del t821, t5
t815 = torch.matmul(t814, t7) # t815: "cuda:0 f32[10, 4096]"
del t814, t7
t827 = torch.matmul(t825, t819) # t827: "cuda:0 f32[11008, 4096]"
del t825
t820 = torch.matmul(t818, t819) # t820: "cuda:0 f32[11008, 4096]"
del t818, t819
t816 = torch.reshape(t815, (2, 5, 4096)) # t816: "cuda:0 f32[2, 5, 4096]"
del t815
t823 = torch.reshape(t822, (2, 5, 4096)) # t823: "cuda:0 f32[2, 5, 4096]"
del t822
[t831, t847] = nvFusion6(f51, t75, t796, t81, t816, t823, t84)
del f51, t75, t796, t81, t816, t823, t84
t848 = torch.reshape(t847, (-1, 4096)) # t848: "cuda:0 f32[10, 4096]"
t852 = torch.permute(t848, (1, 0)) # t852: "cuda:0 f32[4096, 10]"
t849 = torch.matmul(t848, t15) # t849: "cuda:0 f32[10, 4096]"
del t848, t15
t854 = torch.matmul(t852, t853) # t854: "cuda:0 f32[4096, 4096]"
del t852, t853
t850 = torch.reshape(t849, (2, 5, 4096)) # t850: "cuda:0 f32[2, 5, 4096]"
del t849
t855 = torch.reshape(t850, (2, 5, 32, 128)) # t855: "cuda:0 f32[2, 5, 32, 128]"
del t850
t856 = torch.permute(t855, (0, 2, 1, 3)) # t856: "cuda:0 f32[2, 32, 5, 128]"
del t855
(t857, t858, t859, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t856, t65, t67, t43, None, t68, t69, t70, t71, f40, b41, scale=f42)
del t856, t65, t67, t43, t68, t69, t70, t71, f40, b41, f42
t900 = torch.reshape(t859, (2, 32, 1, 5, 128)) # t900: "cuda:0 f32[2, 32, 1, 5, 128]"
del t859
t863 = torch_slice_prim_impl(t857, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t863: "cuda:0 f32[2, 32, 5, 128]"
del t857
t860 = torch_slice_prim_impl(t858, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t860: "cuda:0 f32[2, 32, 5, 128]"
del t858
[t882, t899] = nvFusion7(t49, t51, t860, t863)
del t49, t51, t860, t863
t902 = torch.reshape(t899, (2, 32, 1, 5, 128)) # t902: "cuda:0 f32[2, 32, 1, 5, 128]"
del t899
t901 = torch.reshape(t882, (2, 32, 1, 5, 128)) # t901: "cuda:0 f32[2, 32, 1, 5, 128]"
del t882
t903 = torch.cat((t902, t901, t900), i23) # t903: "cuda:0 f32[2, 32, 3, 5, 128]"
del t902, t901, t900, i23
t904 = torch.permute(t903, (0, 3, 1, 2, 4)) # t904: "cuda:0 f32[2, 5, 32, 3, 128]"
del t903
t905 = torch.reshape(t904, (2, 5, 12288)) # t905: "cuda:0 f32[2, 5, 12288]"
del t904
t906 = torch.reshape(t905, (-1, 12288)) # t906: "cuda:0 f32[10, 12288]"
del t905
t910 = torch.permute(t906, (1, 0)) # t910: "cuda:0 f32[12288, 10]"
t907 = torch.matmul(t906, t3) # t907: "cuda:0 f32[10, 4096]"
del t906, t3
t912 = torch.matmul(t910, t911) # t912: "cuda:0 f32[12288, 4096]"
del t910, t911
t908 = torch.reshape(t907, (2, 5, 4096)) # t908: "cuda:0 f32[2, 5, 4096]"
del t907
[t915, t931] = nvFusion8(f6, t24, t30, t33, t847, t908)
del f6, t24, t30, t33, t847, t908
t932 = torch.torch.ops.aten.embedding_backward(t931, t0, i0, -1, b1, b2) # t932: "cuda:0 f32[32000, 4096]"
del t931, t0, i0, b1, b2
return (None, None, None, t912, t777, t827, t694, t820, t687, t645, t648, t915, t780, t831, t698, t854, t803, t721, t670, t932)
These traces are long and require some familiarity with the model implementation to follow, but they allow you to:
- Inspect exactly which operations run, including their decompositions.
- Inspect tensor sizes, devices, data types, and conversions.
- Apply transformations to the traces, since the computation is completely decoupled from the data (see the sketch below).
- Inspect the backward operations generated for each forward operation to understand what autograd is doing.
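For example, thunder.last_traces does not return a single trace but the whole sequence of traces produced while the program was transformed, so you can follow each step of the pipeline. A small sketch:
# Sketch: walk the sequence of forward traces recorded for the last call.
# The first entry is the trace initially acquired by the JIT; the last one is
# the trace that was actually executed (which is why we index [-1] above).
traces = thunder.last_traces(model)
print(f"{len(traces)} forward traces recorded")
print(traces[0])   # before transformations and optimizations
print(traces[-1])  # final, executable trace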
Transforms are one of Thunder's core features. For example, they enable easy data-parallel distribution, that is, replicated data parallelism (DDP) and fully-sharded data parallelism (FSDP).
We provide ready-to-use Fabric strategies that integrate Thunder DDP/FSDP. Under the hood, the code is quite straightforward:
model = thunder.distributed.ddp(model)
# or
# model = thunder.distributed.fsdp(model)
model = thunder.jit(model)
After applying the DDP transformation, the backward trace will include the expected all-reduce collectives:
p1022 = torch_all_reduce_prim_impl(t1021, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True, False) # p1022: "FUTURE cuda:0 f32[16797696]"
...
t1059 = torch_wait_prim_impl(p1025) # t1059: "cuda:0 f32[131072000]"
With L.Fabric, this is how to use them:
from extensions.thunder.strategies import ThunderFSDPStrategy, ThunderDDPStrategy
# fully-sharded data parallel
strategy = ThunderFSDPStrategy(
sharding_strategy="ZERO3",
bucketing_strategy="BLOCK",
executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"),
state_dict_type="full",
)
# replicated data parallel
strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"))
fabric = L.Fabric(devices=devices, strategy=strategy)
fabric.launch()
model = fabric.setup(model) # JIT is called here
In the case of FSDP, the traces include all-gathers in the forward pass and reduce-scatters in the backward pass. In other words, Thunder automatically introduces the collective operations necessary to support data parallelism.
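If you are not using Fabric, the same transformations can be applied directly in a script launched with torchrun. A minimal sketch, assuming a standard NCCL process group and one process per GPU:
# Minimal sketch of Thunder FSDP without Fabric (assumes a torchrun launch).
import os
import torch
import torch.distributed as dist
import thunder
import thunder.distributed
from litgpt import GPT

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = GPT.from_name("Llama-2-7b-hf", n_layer=2).cuda()
model = thunder.distributed.fsdp(model)  # shard parameters across ranks
model = thunder.jit(model)               # jit after applying the transform, as above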
Thunder allows you to define a priority list of executors that map operators to implementations:
import thunder
model = thunder.jit(
model,
executors=["sdpa", "torchcompile_cat", "nvfuser", "torch"]
)
Notice how torch.compile is a valid executor. This executor registers improved implementations for a few operators so that the fastest available implementation is used for each operation.
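The list is a priority order: executors claim, from left to right, the operators they can execute, and anything left over falls back to the default torch executor. A small sketch of the effect, assuming model is the eager GPT module from the first example (not an already-jitted one):
# Sketch: with only the default torch executor, no fusion regions are created,
# so the executed trace contains plain torch calls instead of nvFusion regions.
baseline = thunder.jit(model, executors=["torch"])
x = torch.randint(model.max_seq_length, (2, 5), device="cuda")
baseline(x)  # trigger compilation
print(str(thunder.last_traces(baseline)[-1]).count("= nvFusion"))  # expect 0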
Lightning Thunder provides extension points to integrate fast kernels for operators in your model without having to modify your implementation.
For instance, the Unsloth project provides several Triton kernels that can be used with LitGPT:
- Cross-entropy loss
- SwiGLU (part of LLaMAMLP)
- RoPE
The unsloth directory contains a custom executor that registers these operators for LitGPT.
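The general registration pattern looks roughly like the sketch below. It is simplified: my_ex and fast_gelu are made-up names for illustration, and the exact thunder.extend API should be double-checked against the Thunder version you are using.
# Hedged sketch of registering a custom operator with an OperatorExecutor.
import torch
import thunder
from thunder.extend import OperatorExecutor, register_executor
from thunder.core.proxies import TensorProxy

my_ex = OperatorExecutor("my_ex", version="0.1")  # hypothetical executor name
register_executor(my_ex)

def fast_gelu_meta(x: TensorProxy) -> TensorProxy:
    # The meta function only describes the output (shape/dtype/device), no compute.
    return TensorProxy(like=x)

def fast_gelu_impl(x: torch.Tensor) -> torch.Tensor:
    # A hand-written (e.g. Triton) kernel would go here; torch stands in for it.
    return torch.nn.functional.gelu(x, approximate="tanh")

fast_gelu = my_ex.register_operator("fast_gelu", meta=fast_gelu_meta, fn=fast_gelu_impl)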
We can enable this executor by adding it to the list of executors passed to thunder.jit. The order matters: we want its custom operators to be claimed before NvFuser creates its fusion regions.
import thunder
model = thunder.jit(
model,
executors=["sdpa", "unsloth", "torchcompile_cat", "nvfuser", "torch"]
)
Doing this, the model trace now includes the Unsloth kernel calls:
def augmented_forward_fn(*args):
...
(t121, _, _, _, _, _) = unsloth_apply_rope(t120, t21, t22)
...
(t189, t190) = unsloth_cross_entropy(t187, t188)
...
def backward_fn(saved_for_backward, cotangents):
...
t652 = unsloth_cross_entropy_backward(t651, t187, t188, t190) # t652: "cuda:0 f32[6, 320]"
...
t763 = unsloth_apply_rope_backward(t757, t21, t22, 1, 8, 4) # t763: "cuda:0 f32[2, 4, 3, 16]"
We provide a dedicated copy of the pre-training script that uses this executor. Given the Unsloth results below, these hand-written kernels do not seem to be worth it, showcasing the power of automated fusion compilers like NvFuser.
Warning
Lightning Thunder is alpha and not ready for production runs. Feel free to try it out, but expect a few bumps along the way. We expect speed and memory usage to improve as development continues.
We provide a version of the main pre-training script that integrates Thunder and uses TinyLlama, a 1.1B-parameter LLM.
| Setting | Compiler | Executors | Devices | ms/iter @ step 10 | Memory (GB) |
|---|---|---|---|---|---|
| Fully-sharded ZeRO 3 | Eager | - | 8 | 456.57 | 22.13 |
| Fully-sharded ZeRO 3 | torch | - | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 333.56 | 21.40 |
| Replicated | Eager | - | 8 | 569.46 | 32.04 |
| Replicated | torch | - | 8 | Not supported | Not supported |
| Replicated | Thunder | sdpa, torchcompile | 8 | 426.44 | 22.19 |
| Replicated | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 356.01 | 27.42 |
| - | Eager | - | 1 | 447.65 | 29.84 |
| - | torch | - | 1 | Not supported | Not supported |
| - | Thunder | sdpa, torchcompile | 1 | 373.37 | 22.19 |
| - | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 322.25 | 27.42 |
| Unsloth | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 331.92 | 25.19 |
Reproduction details
Config:
out_dir: out/pretrain-thunder
data: TinyStories
tokenizer_dir: checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0
logger_name: csv
Commands:
litgpt download --repo_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tokenizer_only true
python extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 1
--compiler torch (torch.compile without Thunder) is not included because it does not support compiling the _FabricModule due to this issue: pytorch/pytorch#112787 (comment)
The CUDA devices are all NVIDIA A100-SXM4-40GB.
Python version: 3.10.12 [GCC 11.4.0] (64-bit runtime)
Is debug build: False
CUDA used to build PyTorch: 12.1
CUDA runtime version: 12.3.107
Nvidia driver version: 545.23.08
pytorch-triton==3.0.0+45fff310c8
torch==2.4.0.dev20240427+cu121
lightning==2.3.0.dev20240328
lightning-thunder==0.2.0.dev20240505
nvfuser_cu121==0.2.3.dev20240428