Commit 061f101

Add a recipe for distributed optimizer with TorchScript (#1465)

Authored-by: wanchaol
Co-authored-by: Brian Johnson <brianjo@fb.com>
Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>

1 parent af68d68 commit 061f101
File tree: 2 files changed, +222 -0 lines changed

@@ -0,0 +1,214 @@
Distributed Optimizer with TorchScript support
==============================================================

.. note:: Distributed Optimizer with TorchScript support is introduced in PyTorch 1.8
    as a beta feature. This API is subject to change.

In this recipe, you will learn:

- The high-level idea of the distributed optimizer with TorchScript support and what this feature brings
- How to write a customized distributed optimizer that enables TorchScript support

Requirements
------------

- PyTorch 1.8+
- `Getting Started With Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`_

What is Distributed Optimizer?
------------------------------------

`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_ takes a list of remote
parameters (RRefs) and runs the optimizer locally on the workers where the parameters live. It is commonly used together
with Distributed RPC/Autograd for model parallel training. It can use any of the local optimizer algorithms (either
the pre-defined algorithms provided in ``torch.optim`` or custom ones) to apply the gradients on each worker.

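For orientation, below is a minimal sketch of that pattern under an RPC setup like the one in the
`Getting Started With Distributed RPC Framework` tutorial. The worker name ``"ps"``, the helper
``get_parameter_rrefs``, and the tiny ``nn.Linear`` layer are illustrative placeholders, not part of
this recipe:

::

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc
    from torch import optim
    from torch.distributed.optim import DistributedOptimizer


    def get_parameter_rrefs(module_rref):
        # runs on the module's owner and wraps each parameter in an RRef
        module = module_rref.local_value()
        return [rpc.RRef(p) for p in module.parameters()]


    # assumes rpc.init_rpc(...) has already been called on every process and that
    # a worker named "ps" exists (placeholder names for your own setup)
    remote_layer = rpc.remote("ps", torch.nn.Linear, args=(16, 16))
    param_rrefs = rpc.rpc_sync("ps", get_parameter_rrefs, args=(remote_layer,))

    with dist_autograd.context() as context_id:
        out = remote_layer.rpc_sync().forward(torch.randn(4, 16))
        dist_autograd.backward(context_id, [out.sum()])

        # the local optimizer (plain SGD here) runs on "ps", where the
        # parameters referenced by param_rrefs actually live
        dist_optim = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)
        dist_optim.step(context_id)
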
What is Distributed Optimizer with TorchScript support?
-------------------------------------------------------

Distributed optimizers are widely used in distributed model parallel training. In some
common use cases, training needs to be done in a multithreaded manner instead of a multiprocess
one, due to performance concerns and resource utilization (or at least partially multithreaded,
e.g. a parameter server hosting part of the model and parameters, with a new thread updating the
parameters per request). PyTorch itself does not support multithreaded training natively, as
it suffers from Python's Global Interpreter Lock (GIL), but it can leverage
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ to get rid of the GIL and run the
model in a multithreaded way.

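To make the GIL point concrete, here is a toy sketch, unrelated to the optimizer in this recipe
(the function name and tensor sizes are made up), that runs a scripted function from several Python
threads; since TorchScript code can execute without holding the GIL, the threads can make progress
concurrently:

::

    import threading
    import torch

    @torch.jit.script
    def toy_update(param: torch.Tensor, grad: torch.Tensor, lr: float):
        # compiled by TorchScript, so it can run without holding the Python GIL
        param.sub_(lr * grad)
        return param

    params = [torch.randn(1000, 1000) for _ in range(4)]
    grads = [torch.randn(1000, 1000) for _ in range(4)]

    threads = [
        threading.Thread(target=toy_update, args=(p, g, 0.01))
        for p, g in zip(params, grads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
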
For critical model training workloads, improving training performance is an
important topic. Researchers often want to implement different optimization strategies
on the graph representation (e.g. via operator fusion) or implement custom operator kernels
in order to speed up training.

Distributed Optimizer with TorchScript support helps get rid of the GIL and thus improves
PyTorch's training performance in multithreaded environments. It also unlocks the potential
to further enhance performance with the advanced compiler technologies that TorchScript
offers (e.g. CPU/GPU fusion).

How to write a customized distributed optimizer with TorchScript support?
-------------------------------------------------------------------------

The code below shows how to write a customized distributed optimizer given an existing local
optimizer implementation, which unlocks the TorchScript benefits, including GIL removal and
performance improvement opportunities.

Suppose that you already have a local optimizer that is currently used during training.
In this case we will use `quasi-hyperbolic momentum (QHM) <https://github.com/facebookresearch/qhoptim/blob/e81dea3f2765780cf4fbb90b87b22ba7604b8625/qhoptim/pyt/qhm.py#L12>`_
as an example to show how to enable TorchScript support; note that the same approach applies
to any custom optimizer that inherits from ``torch.optim.Optimizer``.

First, we need to separate the computation and state management from the optimizer implementation,
so that we can extract the computation part into a free function, which is TorchScript friendly.
This has two benefits:

1. The computation logic becomes easier to inspect; it allows us to quickly turn the parameter
   update/computation part into TorchScript and utilize the TorchScript IR to do further
   optimizations (operator fusion, etc.).
2. The distributed optimizer underneath uses a different mechanism to get gradients and update
   parameters (we store gradients separately instead of directly populating the ``param.grad``
   field during backward). Separating out the computation allows the distributed optimizer to
   run optimizer updates in multithreaded mode, as it eliminates the possible race condition
   on ``param.grad``.

::

    import torch
    from torch import Tensor
    from typing import List


    def qhm_update(params: List[Tensor],
                   dp_list: List[Tensor],
                   momentum_buffer_list: List[Tensor],
                   lr: float,
                   nu: float,
                   weight_decay: float,
                   weight_decay_type: str,
                   momentum: float):
        for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
            if weight_decay != 0:
                if weight_decay_type == "grad":
                    # L2 penalty applied to the gradient
                    d_p.add_(p, alpha=weight_decay)
                elif weight_decay_type == "direct":
                    # decoupled weight decay applied directly to the parameter
                    p.mul_(1.0 - lr * weight_decay)
                else:
                    raise ValueError("Invalid weight decay type provided")

            # update the momentum buffer: buf = momentum * buf + (1 - momentum) * d_p
            momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum)

            # QHM step: mix the momentum buffer and the raw gradient
            p.data.add_(momentum_buffer, alpha=-lr * nu)
            p.data.add_(d_p, alpha=-lr * (1.0 - nu))

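Since the function above is TorchScript friendly, you can optionally compile and smoke-test it on
its own; the tensor sizes and hyperparameter values below are arbitrary and purely illustrative:

::

    scripted_qhm_update = torch.jit.script(qhm_update)

    params = [torch.randn(5)]
    grads = [torch.randn(5)]
    momentum_buffers = [torch.zeros(5)]

    with torch.no_grad():
        scripted_qhm_update(params, grads, momentum_buffers,
                            0.1,     # lr
                            0.7,     # nu
                            0.0,     # weight_decay
                            "grad",  # weight_decay_type
                            0.999)   # momentum
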
Next we will define a distributed functional optimizer with TorchScript compatibility that manages
the optimizer states and calls into the TorchScript compatible update function we defined above.
Note that a few conventions are different from a normal custom optimizer: 1. we don't inherit from
``torch.optim.Optimizer``, as TorchScript does not support polymorphism; 2. ``step`` takes a list of
gradients instead of the loss closure.

::

    import torch
    from torch import Tensor
    from typing import List, Optional, Dict

    # define this as a TorchScript class
    @torch.jit.script
    class FunctionalQHM(object):
        def __init__(self,
                     params: List[Tensor],
                     lr: float,
                     momentum: float,
                     nu: float,
                     weight_decay: float = 0.0,
                     weight_decay_type: str = "grad"):
            if lr < 0.0:
                raise ValueError("Invalid learning rate: {}".format(lr))
            if momentum < 0.0:
                raise ValueError("Invalid momentum value: {}".format(momentum))
            if weight_decay < 0.0:
                raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
            if weight_decay_type not in ("grad", "direct"):
                raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

            self.defaults = {
                "lr": lr,
                "momentum": momentum,
                "nu": nu,
                "weight_decay": weight_decay,
            }
            self.weight_decay_type = weight_decay_type

            # NOTE: we only have one param_group here and don't allow user to add additional
            # param group as it's not a common use case.
            self.param_group = {"params": params}

            self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        def step(self, gradients: List[Optional[Tensor]]):
            params = self.param_group['params']
            params_with_grad = []
            grads = []
            momentum_buffer_list: List[Tensor] = []

            if len(params) != len(gradients):
                raise ValueError(
                    "the number of gradients passed in does not equal the number of parameters! "
                    + f"Params length: {len(params)}. "
                    + f"Gradients length: {len(gradients)}"
                )

            for param, gradient in zip(self.param_group['params'], gradients):
                if gradient is not None:
                    params_with_grad.append(param)
                    grads.append(gradient)
                    # lazily initialize the per-parameter state (momentum buffer) the first
                    # time we see this parameter, then reuse it on subsequent steps
                    if param not in self.state:
                        self.state[param] = {
                            'momentum_buffer': torch.zeros_like(param, memory_format=torch.preserve_format)
                        }
                    state = self.state[param]
                    momentum_buffer_list.append(state['momentum_buffer'])

            # calls into the update function we just defined
            with torch.no_grad():
                qhm_update(params_with_grad,
                           grads,
                           momentum_buffer_list,
                           self.defaults['lr'],
                           self.defaults['nu'],
                           self.defaults['weight_decay'],
                           self.weight_decay_type,
                           self.defaults['momentum'])

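Before wiring it into ``DistributedOptimizer``, you may want to exercise the functional optimizer
locally. Below is a small, hypothetical smoke test (the tensor shapes and hyperparameter values are
made up) that constructs ``FunctionalQHM`` and drives ``step`` with explicit gradients:

::

    from typing import List, Optional
    from torch import Tensor

    params = [torch.randn(4, 4), torch.randn(4)]
    # args: params, lr, momentum, nu, weight_decay, weight_decay_type
    opt = FunctionalQHM(params, 0.1, 0.999, 0.7, 0.0, "grad")

    # gradients are passed in explicitly rather than read from param.grad
    grads: List[Optional[Tensor]] = [torch.randn(4, 4), torch.randn(4)]
    for _ in range(3):
        opt.step(grads)
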
Finally, we register our newly defined distributed functional optimizer into the ``functional_optim_map``.
This is so that ``DistributedOptimizer`` will pick up our custom implementation instead of one of the
pre-defined defaults.

::

    from torch.distributed.optim import DistributedOptimizer

    # QHM here is the local optimizer class from the qhoptim package that we started from
    DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM

Now you can use the ``QHM`` optimizer as normal in distributed training by passing it to
`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_.

::

    ...
    remote_params_list = [...]
    dist_optim = DistributedOptimizer(
        QHM, remote_params_list, *args, **kwargs
    )

DistributedOptimizer will automatically transform the QHM optimizer into ``FunctionalQHM`` under the hood
and enable TorchScript support. This unlocks the performance boost from multithreaded training
and also opens up more potential for further improvements (e.g. TorchScript fusion).

Note that the majority of PyTorch built-in optimizers already use this methodology to speed up distributed
training. If you see a warning that some optimizer hasn't been converted yet, you can write your own conversion
by following this recipe.

recipes_source/recipes_index.rst (+8)

@@ -248,6 +248,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/cuda_rpc.html
    :tags: Distributed-Training

+.. customcarditem::
+   :header: Distributed Optimizer with TorchScript support
+   :card_description: How to enable TorchScript support for Distributed Optimizer.
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/distributed_optim_torchscript.html
+   :tags: Distributed-Training,TorchScript
+
 .. End of tutorial card section

 .. raw:: html
@@ -286,3 +293,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/distributed_rpc_profiling
    /recipes/zero_redundancy_optimizer
    /recipes/cuda_rpc
+   /recipes/distributed_optim_torchscript
