Distributed Optimizer with TorchScript support
==============================================================

.. note:: Distributed Optimizer with TorchScript support is introduced in PyTorch 1.8
    as a beta feature. This API is subject to change.

In this recipe, you will learn:

- The high-level idea of the distributed optimizer with TorchScript support and what this feature brings
- How to write a customized distributed optimizer that enables TorchScript support


Requirements
------------

- PyTorch 1.8+
- `Getting Started With Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`_


What is Distributed Optimizer?
------------------------------------

`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_ takes a list of remote
parameters (RRef) and runs the optimizer locally on the workers where the parameters live. It is commonly used together
with Distributed RPC/Autograd for model parallel training. It can use any local optimizer algorithm (either a
pre-defined algorithm provided in ``torch.optim`` or a custom-defined one) to apply the gradients on each worker.
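
For context, this is the typical usage pattern (a minimal sketch; ``param_rrefs``, ``model``,
``inputs``, ``targets``, and ``loss_fn`` are placeholders that would come from your own RPC setup,
as described in the Distributed RPC tutorial linked above):

::

    import torch
    import torch.distributed.autograd as dist_autograd
    from torch.distributed.optim import DistributedOptimizer

    # ``param_rrefs`` is assumed to be a list of RRefs to parameters that live on remote workers
    dist_optim = DistributedOptimizer(
        torch.optim.SGD,   # any local optimizer class, pre-defined or custom
        param_rrefs,
        lr=0.05,
    )

    with dist_autograd.context() as context_id:
        # the forward pass may hop across workers via RPC
        loss = loss_fn(model(inputs), targets)
        dist_autograd.backward(context_id, [loss])
        # each worker runs the local optimizer on the parameters it hosts
        dist_optim.step(context_id)
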


What is Distributed Optimizer with TorchScript support?
------------------------------------------------------------

Distributed optimizers are widely used in distributed model parallel training. In some
common use cases, training needs to be done in a multithreaded manner instead of a multiprocess
one due to performance concerns and resource utilization (or at least partially multithreaded,
e.g. a Parameter Server hosting part of the model and parameters, with a new thread updating the
parameters per request). PyTorch itself does not natively support multithreaded training, as
it suffers from Python's Global Interpreter Lock (GIL), but it can leverage
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ to get rid of the GIL and run the
model in a multithreaded way.
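
As a rough illustration of that multithreaded setup (the toy module, shapes, and thread count here
are made up for the sketch), a scripted module can be called from several Python threads, and its
forward runs through the TorchScript runtime, which does not need to hold the GIL for the duration
of the call:

::

    import threading
    import torch

    # script a small model so its forward runs through the TorchScript interpreter
    scripted = torch.jit.script(torch.nn.Linear(128, 128))
    batches = [torch.randn(32, 128) for _ in range(4)]

    # each thread calls the same scripted module on its own batch
    threads = [threading.Thread(target=scripted, args=(x,)) for x in batches]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
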

For critical model training workloads, improving training performance is an
important topic. Researchers often want to implement different optimization strategies
on the graph representation (e.g. via operator fusion) or implement custom operator kernels
in order to speed up training.

Distributed Optimizer with TorchScript support helps get rid of the GIL and thus improves
PyTorch's training performance in a multithreaded environment. It also unlocks the potential
to further enhance performance with the advanced compiler technologies that TorchScript
offers (e.g. CPU/GPU fusion).


How to write a customized distributed optimizer with TorchScript support?
---------------------------------------------------------------------------

The code below shows how to write a customized distributed optimizer given an existing local
optimizer implementation, which unlocks the TorchScript benefits, including GIL removal and
performance improvement opportunities.

Suppose that you already have a local optimizer that is used during training. In this case, we will use
`quasi-hyperbolic momentum (QHM) <https://github.com/facebookresearch/qhoptim/blob/e81dea3f2765780cf4fbb90b87b22ba7604b8625/qhoptim/pyt/qhm.py#L12>`_
as an example to show how to enable TorchScript support. Note that this also applies
to any custom optimizer that inherits from ``torch.optim.Optimizer``.

First, we need to separate the computation and state management from the optimizer implementation,
so that we can extract the computation part into a free function, which is
TorchScript friendly. This has two benefits: 1. The computation logic becomes easier to inspect;
it allows us to quickly turn the parameter update/computation part into TorchScript and utilize the
TorchScript IR for further optimizations (operator fusion, etc.). 2. The distributed optimizer
uses a different mechanism under the hood to get gradients and update parameters (we store
gradients separately instead of directly populating the ``param.grad`` field during the backward pass).
Separating out the computation allows the distributed optimizer to run the optimizer update in
multithreaded mode, as it eliminates possible race conditions on ``param.grad``.


::

    import torch
    from torch import Tensor
    from typing import List


    def qhm_update(params: List[Tensor],
                   dp_list: List[Tensor],
                   momentum_buffer_list: List[Tensor],
                   lr: float,
                   nu: float,
                   weight_decay: float,
                   weight_decay_type: str,
                   momentum: float):

        for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
            if weight_decay != 0:
                if weight_decay_type == "grad":
                    d_p.add_(p, alpha=weight_decay)
                elif weight_decay_type == "direct":
                    p.mul_(1.0 - lr * weight_decay)
                else:
                    raise ValueError("Invalid weight decay type provided")

            # buf = momentum * buf + (1 - momentum) * d_p
            momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum)

            # p = p - lr * (nu * buf + (1 - nu) * d_p)
            p.data.add_(momentum_buffer, alpha=-lr * nu)
            p.data.add_(d_p, alpha=-lr * (1.0 - nu))

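
Because ``qhm_update`` is now a free function with full type annotations, it can also be compiled
on its own. This is a quick way to verify that it is TorchScript friendly and to look at the IR
that further optimizations (such as fusion) would operate on. The tensor shapes and hyperparameter
values below are arbitrary, illustrative choices:

::

    # compile the update function by itself and inspect its TorchScript IR
    scripted_update = torch.jit.script(qhm_update)
    print(scripted_update.graph)

    # run it on a couple of toy tensors as a quick sanity check
    params = [torch.randn(4), torch.randn(4)]
    grads = [torch.randn(4), torch.randn(4)]
    bufs = [torch.zeros(4), torch.zeros(4)]
    scripted_update(params, grads, bufs, 0.1, 0.7, 0.0, "grad", 0.999)
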


Next we will define a distributed functional optimizer with TorchScript compatibility that manages
the optimizer states and calls into the TorchScript compatible update function we defined above.
Note that a few conventions are different from normal custom optimizers: 1. We don't inherit from
``torch.optim.Optimizer``, as TorchScript does not support polymorphism. 2. ``step`` takes a list of
gradients instead of a loss closure.

::

    import torch
    from torch import Tensor
    from typing import List, Optional, Dict

    # define this as a TorchScript class
    @torch.jit.script
    class FunctionalQHM(object):
        def __init__(self,
                     params: List[Tensor],
                     lr: float,
                     momentum: float,
                     nu: float,
                     weight_decay: float = 0.0,
                     weight_decay_type: str = "grad"):
            if lr < 0.0:
                raise ValueError("Invalid learning rate: {}".format(lr))
            if momentum < 0.0:
                raise ValueError("Invalid momentum value: {}".format(momentum))
            if weight_decay < 0.0:
                raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
            if weight_decay_type not in ("grad", "direct"):
                raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

            self.defaults = {
                "lr": lr,
                "momentum": momentum,
                "nu": nu,
                "weight_decay": weight_decay,
            }
            self.weight_decay_type = weight_decay_type

            # NOTE: we only have one param_group here and don't allow user to add additional
            # param group as it's not a common use case.
            self.param_group = {"params": params}

            self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        def step(self, gradients: List[Optional[Tensor]]):
            params = self.param_group['params']
            params_with_grad = []
            grads = []
            momentum_buffer_list: List[Tensor] = []

            if len(params) != len(gradients):
                raise ValueError(
                    "the gradients passed in do not match the parameters! "
                    + f"Params length: {len(params)}. "
                    + f"Gradients length: {len(gradients)}"
                )

            for param, gradient in zip(self.param_group['params'], gradients):
                if gradient is not None:
                    params_with_grad.append(param)
                    grads.append(gradient)
                    # lazily create the per-parameter state so the momentum buffer
                    # persists across steps instead of being reset every time
                    if param not in self.state:
                        self.state[param] = {
                            'momentum_buffer': torch.zeros_like(param, memory_format=torch.preserve_format)
                        }
                    state = self.state[param]
                    momentum_buffer_list.append(state['momentum_buffer'])

            # calls into the update function we just defined
            with torch.no_grad():
                qhm_update(params_with_grad,
                           grads,
                           momentum_buffer_list,
                           self.defaults['lr'],
                           self.defaults['nu'],
                           self.defaults['weight_decay'],
                           self.weight_decay_type,
                           self.defaults['momentum'])

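
Although ``FunctionalQHM`` is meant to be driven by ``DistributedOptimizer``, it can also be
exercised directly in a single process, which is a convenient way to sanity check the implementation
before wiring it into RPC. The shapes and hyperparameter values below are arbitrary:

::

    # construct the functional optimizer directly and feed it an explicit list of
    # gradients, mirroring how DistributedOptimizer drives it on each worker
    params = [torch.randn(3, 4, requires_grad=True), torch.randn(4, requires_grad=True)]
    opt = FunctionalQHM(params, 0.1, 0.999, 0.7)   # lr, momentum, nu

    loss = sum((p * p).sum() for p in params)
    loss.backward()

    # gradients are passed in explicitly rather than read from param.grad
    opt.step([p.grad for p in params])
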


Finally, we register our newly defined distributed functional optimizer into the ``functional_optim_map``.
This is so that ``DistributedOptimizer`` will pick up our custom implementation instead of the
pre-defined default ones.

::

    from torch.distributed.optim import DistributedOptimizer

    # ``QHM`` here is the local optimizer class linked above (from the qhoptim package)
    DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM

Now you can use the ``QHM`` optimizer as normal in distributed training by passing it to
`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_.


::

    ...
    remote_params_list = [...]
    dist_optim = DistributedOptimizer(
        QHM, remote_params_list, *args, **kwargs
    )

``DistributedOptimizer`` will automatically transform the ``QHM`` optimizer into ``FunctionalQHM`` under the hood
and enable TorchScript support. This unlocks the performance boost from multithreaded training and
also opens up more potential for further improvements (e.g. TorchScript fusion).

Note that the majority of PyTorch built-in optimizers already use this methodology to speed up distributed
training. If you see a warning that some optimizer has not been converted yet, you can write your own
conversion by following this recipe.