PyTorch Compiler Introduction
Summary
This article introduces the PyTorch compiler, starting with code examples that demonstrate the acceleration gained from compilation. It then covers the background knowledge needed for PyTorch FX, and finally presents TorchDynamo in detail, including the Graph, the rewriting of Python bytecode, Guards, Caching, and more.
Preamble: This document is based on PyTorch 2.1. TorchDynamo is a continuously evolving module, so later versions may expose different APIs, but the core ideas remain the same.
Overview
torch.compile is a feature introduced in PyTorch 2.x for capturing computation graphs more accurately and accelerating program execution. It is written in Python, marking a gradual shift in PyTorch development from C++ toward Python.
torch.compile mainly relies on the following technologies:
- TorchDynamo (torch._dynamo): an internal API that uses CPython's Frame Evaluation API to safely capture PyTorch computation graphs.
- TorchInductor: the default torch.compile deep learning compiler, which generates efficient executable code for various backends. For NVIDIA and AMD GPUs, it is primarily built on OpenAI Triton.
- AOT Autograd (Ahead-Of-Time Autograd): captures user-level code and backpropagation at compile time. Deep learning frameworks generally build the forward and backward passes at runtime, whereas AOT Autograd captures the backward pass ahead of time, so that TorchInductor can accelerate both the forward computation and the backpropagation.
Some common backends include:
- Support for both training and inference:
  - inductor: the default TorchInductor backend.
  - cudagraphs: a CUDA-graph backend with AOT Autograd.
  - ipex: intel-extension-for-pytorch, a CPU backend.
  - onnxrt: a training backend based on ONNX Runtime, for CPU / GPU.
- Support for inference:
  - tensorrt: onnx-tensorrt, which uses ONNX Runtime to run TensorRT for accelerated inference.
  - tvm: uses Apache TVM for accelerated inference.
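Any of these backends is selected through the backend argument of torch.compile; a minimal sketch (assuming the chosen backend's package is installed):

```python
import torch
import torch._dynamo

# List the backends registered with Dynamo in the current environment.
print(torch._dynamo.list_backends())

def fn(x):
    return torch.sin(x) + torch.cos(x)

# "inductor" is the default; any name from the list above can be passed instead.
compiled = torch.compile(fn, backend="inductor")
print(compiled(torch.randn(8)))
```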
Getting Started
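The original walkthrough compares the eager and compiled versions of a small function along the following lines (a minimal sketch: the function, tensor size, and the helper name timed are illustrative, and a CUDA GPU is assumed):

```python
import time
import torch

def fn(x):
    a = torch.cos(x)
    b = torch.sin(a)
    return b

# torch.compile wraps fn; the first call triggers tracing and compilation.
new_fn = torch.compile(fn, backend="inductor")

x = torch.randn(10000, 10000, device="cuda")

def timed(f, x, iters=10):
    f(x)                       # warm-up (includes the one-time compilation cost for new_fn)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        f(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print("eager   :", timed(fn, x))
print("compiled:", timed(new_fn, x))
```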
In the code above, we compared the runtime of the original function with that of the compiled function. After the initial overhead (the extra initialization and compilation time required by the first call), the function runs more than twice as fast, which is very beneficial in real training scenarios where the same function is executed many times.
There are two main reasons for the acceleration. First, we obtain kernel fusion through compilation. What is fusion? TorchInductor emits Triton kernels by default, and we can inspect the generated Triton code by setting the environment variable TORCH_COMPILE_DEBUG=1 (the code may vary across hardware).
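The generated file contains a Triton kernel roughly like the sketch below (simplified here; the real Inductor output has more boilerplate and differs across hardware and versions):

```python
import triton
import triton.language as tl

@triton.jit
def fused_cos_sin_kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + xindex, xmask)   # single read of x
    tmp1 = tl.cos(tmp0)                       # a = cos(x), kept in registers
    tmp2 = tl.sin(tmp1)                       # b = sin(a), still in registers
    tl.store(out_ptr0 + xindex, tmp2, xmask)  # single write of the final result
```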
In the original function, for pointwise operations like cos and sin, x needs to be read once, computed, and written into a, then a is read again, computed, and written into b. With fusion, we perform only one read (tl.load into tmp0) and one write (tl.store). On newer GPUs the bottleneck is memory bandwidth (the speed of accessing data) rather than computation (floating-point throughput), so fusion provides a good performance optimization.
Secondly, inductor also provides support for CUDA graphs. A CUDA graph captures a sequence of operations (kernel launches, memory copies, etc.) and saves it as a graph. The same sequence can then be replayed many times, significantly reducing the overhead of launching the individual operations. The CUDA runtime may also optimize the graph to reduce synchronization or improve memory access patterns, further increasing efficiency.
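A simple way to opt into CUDA graphs with Inductor is the reduce-overhead compile mode (a minimal sketch, assuming a CUDA device; it helps most for small graphs that are replayed many times):

```python
import torch

def fn(x):
    return torch.sin(x) + torch.cos(x)

# mode="reduce-overhead" lets Inductor use CUDA graphs to cut per-launch overhead.
compiled = torch.compile(fn, mode="reduce-overhead")

x = torch.randn(1024, device="cuda")
for _ in range(10):
    compiled(x)
```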
FX
FX, standing for PyTorch Flexible eXtensions, is something we need to understand before diving into Dynamo.
FX is a sub-library of PyTorch, designed to help developers transform nn.Module model instances into an IR (Intermediate Representation). The IR is a more structured and analyzable graph.
Developers can perform visual analysis, model transformation, and optimization (such as removing unnecessary operations or merging layers) based on the IR. Once optimized, it can be converted back into PyTorch code or other formats through code generation, facilitating deployment to different platforms and backends.
FX primarily comprises three components: symbolic tracer, IR, and Python code generation. Let’s elucidate these three components with an example:
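A small example along the lines of the torch.fx documentation (the module, parameter shapes, and clamp bounds are illustrative):

```python
import torch
from torch import fx, nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.rand(3, 4))
        self.linear = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

# Symbolic tracing: run the module with Proxy inputs and record every operation.
symbolic_traced: fx.GraphModule = fx.symbolic_trace(module)

print(symbolic_traced.graph)  # the IR
print(symbolic_traced.code)   # the regenerated Python code
```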
The symbolic tracer feeds fake values (known as Proxies) into the model, and the operations performed on these proxies are recorded.
The IR (Intermediate Representation) is a container for the operations recorded during symbolic tracing, including the inputs, call sites (functions, methods, and nn.Module instances), and return values. For instance, the code above generates the following IR:
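For the module sketched above, the printed graph looks roughly as follows (exact formatting varies across PyTorch versions):

```
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
```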
Python code generation produces valid Python code that matches the semantics of the IR; the example above generates the following code:
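Again for the module above, the regenerated forward roughly reads:

```python
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
```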
These three components can be used in combination or separately (like using symbolic tracing alone for model analysis), serving as handy tools for developers.
Deep Dive to TorchDynamo
TorchDynamo is a Python-level JIT (just-in-time) compiler that uses CPython's Frame Evaluation API (PEP 523) to rewrite Python bytecode, extract sequences of PyTorch operations, and assemble them into an FX Graph, which is then compiled with a specified backend. By creating the FX Graph through bytecode analysis and mixing ordinary Python execution with the compiled backend, TorchDynamo stays easy to use while maintaining good performance.
The image below illustrates the mechanism of torch.compile.
Example
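The example below is essentially the classic one from the TorchDynamo documentation: a custom backend my_compiler that just prints the captured FX graph and then runs it eagerly:

```python
from typing import List
import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()  # pretty-prints the graph (needs the `tabulate` package)
    return gm.forward         # return a Python callable that runs the captured graph

@torch.compile(backend=my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
```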
Executing the above code, we can get:
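Abridged, the output has roughly the following shape (the table comes from print_tabular; the exact targets and graph contents depend on the PyTorch version):

```
my_compiler() called with FX graph:
opcode         name     target                   args          kwargs
-------------  -------  -----------------------  ------------  --------
placeholder    a        a                        ()            {}
placeholder    b        b                        ()            {}
call_function  abs_1    <built-in method abs>    (a,)          {}
...
my_compiler() called with FX graph:
...
my_compiler() called with FX graph:
...
```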
This output tells us that my_compiler was called three times, generating three graphs:
- everything before the branch: compute x and check whether b.sum() is less than 0;
- the True branch of the if: b = b * -1 followed by return x * b;
- the False branch of the if: just the return value return x * b.
What does Dynamo do?
To delve deeper into what Dynamo specifically does in the process above, we can add the following code to print more logs:
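The logging knobs have changed between releases; one option on PyTorch 2.x (an assumption of this write-up, equivalent to setting the TORCH_LOGS environment variable) is the torch._logging API:

```python
import logging
import torch

# Emit Dynamo tracing logs, the rewritten bytecode, and the generated guards.
# Roughly equivalent to: TORCH_LOGS="+dynamo,bytecode,guards"
torch._logging.set_logs(dynamo=logging.DEBUG, bytecode=True, guards=True)
```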
The output of the first graph is as follows:
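The full log is long; abridged, its shape is roughly the following (offsets, variable indices, and formatting differ across versions and machines):

```
Step 1: torchdynamo start tracing toy_example
...
ORIGINAL BYTECODE toy_example
  0 LOAD_FAST          a
  2 LOAD_GLOBAL        torch
  4 LOAD_METHOD        abs
  ...
MODIFIED BYTECODE toy_example
  0 LOAD_GLOBAL        __compiled_fn_0
  2 LOAD_FAST          a
  4 LOAD_FAST          b
  6 CALL_FUNCTION      2
  8 UNPACK_SEQUENCE    2
 10 STORE_FAST         x
 12 POP_JUMP_IF_FALSE  24
 14 LOAD_GLOBAL        __resume_at_30_1
  ...
 24 LOAD_GLOBAL        __resume_at_38_2
  ...
GUARDS:
  ...
```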
From the output, it is clear that Dynamo starts by tracing our function toy_example, then compiles it, generates the graph, and outputs it. Additionally, the output reveals the changes to the bytecode and the Guards that were declared.
Let’s first look at the bytecode:
In the original Python bytecode, instructions like LOAD_FAST are used to load values from local variables, LOAD_METHOD and CALL_METHOD are used to call methods, and BINARY_ADD and BINARY_MULTIPLY are used for addition and multiplication, respectively.
Dynamo modifies the Python bytecode, replacing the calculation of x and the check for b.sum() < 0 in the original bytecode with a call to the compiled function __compiled_fn_0. Based on the returned value, it then calls the generated __resume_at_30_1 or __resume_at_38_2, corresponding to the two branches in the original bytecode.
The __resume_at_xx functions come from the following template, and are used to continue executing the code from the point where the graph break occurred.
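The template, as sketched in the TorchDynamo design notes, looks roughly like this pseudo-bytecode:

```
__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...
```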
By generating __resume_at_xx, we force the remaining code to execute in a new Python frame, which recursively triggers Dynamo to run its capture process again.
How should we understand this recursion? When toy_example is executed for the first time, Dynamo initiates a capture process and generates optimized bytecode, including __compiled_fn_0 and the two resume functions. When we enter a resume function, Dynamo initiates a similar process to handle the remaining branches within it, and so on, until all of the code has been processed.
Guard
The output above also includes Guard:
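For this example the guard summary printed by Dynamo looks roughly like (formatting differs across versions):

```
GUARDS:
 - local 'a' TENSOR_MATCH
 - local 'b' TENSOR_MATCH
 - global 'torch' FUNCTION_MATCH
```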
Here, if any Guard fails (meaning the optimized code is no longer safe or correct, for example because runtime conditions have changed), the graph will be re-captured and re-compiled.
In this case, TENSOR_MATCH checks tensor object attributes like dtype, shape, device, requires_grad, dispatch_key, ndim, sizes, strides, etc. FUNCTION_MATCH checks the function object's id(obj), and possibly id(type(obj)), to ensure the correct function is called.
Caching
Caching is an important factor behind Dynamo's acceleration in the example above. While caching does not directly speed up the code, it prevents unnecessary re-compilation.
After modifying the Python bytecode, Dynamo caches it. Every time a new Frame is received for evaluation, Dynamo checks if objects referenced in the Frame have changed; if not, the cached user bytecode is directly used.
The process can be summarized as follows:
1. Dynamo receives a Python Frame, which contains the current state and context information of the code.
2. Dynamo optimizes the Python instructions, generating optimized bytecode.
3. For the objects captured in step 2, Dynamo creates tracking objects:
   - tracking on an output graph (an internal implementation of torch.fx.Tracer), and
   - Guards.
4. Dynamo generates a check_fn function, which is used to check these Guard objects.
5. When the program runs into the associated code, check_fn is called to check the bytecode in the Cache. If check_fn returns True, the cached bytecode is used directly; otherwise, optimized code is regenerated through re-compilation or a Graph Break.
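A minimal sketch of how guards drive the cache (the toy function f is illustrative): changing an input property that a TENSOR_MATCH guard checks invalidates the cached entry and triggers re-compilation.

```python
import torch

@torch.compile
def f(x):
    # Trivial function; what matters is how often Dynamo has to (re)compile it.
    return torch.sin(x) + 1

f(torch.randn(8))            # first call: trace, compile, and cache the bytecode
f(torch.randn(8))            # guards pass -> cached compiled code is reused
f(torch.randn(8).double())   # dtype change fails a TENSOR_MATCH guard -> recompilation
```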