Demystifying Dtype Promotion in PyTorch
Summary
This article offers an insightful look into dtype promotion in PyTorch, explaining how different data types are handled during tensor operations. It covers the fundamental rules of dtype promotion, the specifics of how scalar values are integrated into tensor operations, and the role of TensorIterator in computing dtypes.
0. Introduction
Let’s start with code:
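(A representative set of cases; the exact values are illustrative.)

```python
import torch

int_tensor = torch.ones(3, dtype=torch.int32)
long_zero_dim = torch.tensor(1, dtype=torch.int64)

print((int_tensor + 5).dtype)              # ?
print((int_tensor + 5.0).dtype)            # ?
print((int_tensor + long_zero_dim).dtype)  # ?
print((int_tensor / 2).dtype)              # ?
```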
Have you ever wondered about the data types (dtype) of the output tensors in PyTorch? We'll explore this topic and provide the answers later in the article.
1. Basic Rules of Dtype Promotion
In PyTorch, when the dtypes of the inputs to an arithmetic operation (such as add, sub, etc.) differ, dtype promotion occurs according to the following rules:
1. If the dtype of a scalar is of a higher category than that of a tensor operand (category order: complex > floating > integral > boolean), the result dtype is promoted to one that is large enough to hold all scalar values of that category.
2. If a zero-dimensional (0-dim) tensor operand has a higher category than the dimensioned operands, the result dtype is promoted to one that can hold the 0-dim tensor.
3. If there are no higher-category scalar or 0-dim tensor operands, the result dtype is promoted to one that can accommodate all dimensioned operands.
Special case: for operations like div, dividing an integer tensor by an integer scalar yields a float dtype, since true division always produces a floating-point result. The example below illustrates these rules.
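For instance, with the default dtype left at torch.float32:

```python
import torch

a = torch.ones(3, dtype=torch.int32)

# Rule 1: a float scalar outranks an integer tensor,
# so the result is promoted to the default float dtype.
print((a + 1.5).dtype)                                  # torch.float32

# Rule 2: a higher-category 0-dim tensor promotes the result ...
print((a + torch.tensor(1.5)).dtype)                    # torch.float32
# ... but a same-category 0-dim tensor does not.
print((a + torch.tensor(1, dtype=torch.int64)).dtype)   # torch.int32

# Rule 3: with no higher-category scalars or 0-dim tensors,
# promotion happens among the dimensioned operands.
print((a + torch.ones(3, dtype=torch.int64)).dtype)     # torch.int64

# Special case: true division of integers always yields a float.
print((a / 2).dtype)                                    # torch.float32
```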
2. PyTorch Implementation Details
Let’s delve into the PyTorch source code to understand how dtype promotion is implemented.
2.1 Wrapped Tensor
Consider the operation int_tensor + 5, where 5 is a constant scalar. In this scenario, the scalar 5 is wrapped into a tensor with dtype int64.
This wrapping approach enables reuse of the add.Tensor operator, so there is no need to maintain separate add.Tensor and add.Scalar operators. (Note: in PyTorch, the add.Scalar interface is not registered with the dispatcher and is therefore not used.)
Here’s how the scalar wrapping occurs:
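(A condensed sketch of the generated binding in torch/csrc/autograd/generated/python_variable_methods.cpp; other overloads and error handling are elided.)

```cpp
// Generated Python binding for Tensor.add (condensed).
static PyObject* THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs) {
  const Tensor& self = THPVariable_Unpack(self_);
  static PythonArgParser parser({
      "add(Tensor other, *, Scalar alpha=1)",
      // ... (deprecated overload elided)
  });
  // ...
  auto dispatch_add = [](const at::Tensor& self, const at::Tensor& other,
                         const at::Scalar& alpha) -> at::Tensor {
    pybind11::gil_scoped_release no_gil;
    return self.add(other, alpha);
  };
  // _r.tensor(0): if a Python number was passed for `other`,
  // it is wrapped into a 0-dim tensor here.
  return wrap(dispatch_add(self, _r.tensor(0), _r.scalar(1)));
}
```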
The crucial step is _r.tensor(0), where the scalar is converted into a 0-dim tensor:
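(Abridged from torch/csrc/utils/python_arg_parser.h and python_arg_parser.cpp.)

```cpp
// Fast path: a real tensor argument is unpacked directly.
inline at::Tensor PythonArgs::tensor(int i) {
  if (args[i] && THPVariable_CheckExact(args[i])) {
    return THPVariable_Unpack(args[i]);
  }
  return tensor_slow(i);  // Python numbers take the slow path.
}

// Slow path: a Python number becomes a Scalar, then a wrapped 0-dim tensor.
at::Tensor PythonArgs::tensor_slow(int i) {
  PyObject* obj = args[i];
  at::Scalar scalar;
  if (THPUtils_checkLong(obj)) {
    scalar = at::Scalar(THPUtils_unpackLong(obj));    // Python int -> int64
  } else if (THPUtils_checkDouble(obj)) {
    scalar = at::Scalar(THPUtils_unpackDouble(obj));  // Python float -> double
  }
  // ... (bool, complex, and error cases elided)
  at::Tensor tensor = scalar_to_tensor(scalar);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);  // mark as wrapped
  return tensor;
}
```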
And the conversion from a scalar to a tensor goes through fill:
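(Abridged from aten/src/ATen/native/TensorFactories.cpp.)

```cpp
// at::scalar_tensor (abridged): an empty 0-dim tensor is created,
// then filled with the scalar value via fill_.
Tensor scalar_tensor(const Scalar& s, /* dtype/layout/device options ... */) {
  // ...
  return at::empty({}, options).fill_(s);
}
```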
In scenarios where a C++ function (e.g., at::native::add_(...)) is called directly, the Scalar is wrapped in the same way:
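(Abridged from aten/src/ATen/native/BinaryOps.cpp and aten/src/ATen/ScalarOps.h.)

```cpp
// The Scalar overload simply redispatches to the Tensor overload.
Tensor& add_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.add_(wrapped_scalar_tensor(other), alpha);
}

// A Scalar becomes a 0-dim tensor marked as a wrapped number.
inline Tensor wrapped_scalar_tensor(const Scalar& scalar) {
  auto tensor = scalar_to_tensor(scalar);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}
```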
Inside the kernels of symmetric operations (where f(a, b) == f(b, a)), it is possible to improve computational efficiency by removing the wrapped tensor or CPU scalar tensor from the TensorIterator and treating it as an ordinary constant value.
For example:
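(A sketch modeled on gpu_kernel_with_scalars in aten/src/ATen/native/cuda/Loops.cuh; the f(a, b) == f(b, a) requirement in the prose refers to the symmetric variant, which reuses a single kernel for either scalar position. Type plumbing is elided.)

```cpp
// When one input is a CPU scalar tensor, read it out as a plain value,
// remove it from the TensorIterator, and bake it into the functor.
template <typename func_t>
void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  // ... (argument type deduction elided)
  if (iter.is_cpu_scalar(1)) {
    auto a = iter.scalar_value<arg1_t>(1);  // read the scalar on the host
    iter.remove_operand(1);                 // drop it from the iterator
    gpu_kernel(iter, [=] GPU_LAMBDA(arg2_t b) { return f(a, b); });
  } else if (iter.is_cpu_scalar(2)) {
    auto b = iter.scalar_value<arg2_t>(2);
    iter.remove_operand(2);
    gpu_kernel(iter, [=] GPU_LAMBDA(arg1_t a) { return f(a, b); });
  } else {
    gpu_kernel(iter, f);  // no CPU scalar: the generic path
  }
}
```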
2.2 Computing the dtypes
The computation of dtypes occurs within the TensorIterator. If you’re unfamiliar with TensorIterator, I recommend reading my article introducing it here.
In this article, we will focus on the implementation of dtype promotion.
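(A heavily abridged view of TensorIteratorBase::compute_types from aten/src/ATen/TensorIterator.cpp; most of the validation logic is elided.)

```cpp
// Compute a common dtype for the inputs, then record it as each operand's
// target_dtype so that outputs can be allocated with it later.
void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
  // ...
  if (config.promote_inputs_to_common_dtype_) {
    common_dtype_ = compute_common_dtype();
  }
  // ...
  for (auto& op : operands_) {
    // ...
    op.target_dtype = common_dtype_;
  }
  // ...
}
```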
In compute_types, PyTorch calculates the common_dtype_ based on the input tensors and configuration settings such as promote_inputs_to_common_dtype_. The resulting dtype is stored in op.target_dtype, which is later used in allocate_or_resize_outputs.
To understand how PyTorch implements dtype promotion, let's examine compute_common_dtype.
Note: promote_inputs_to_common_dtype_ must be set to true to enable this dtype inference mechanism. (Typically, the configuration of a TensorIterator is determined by macros such as BINARY_FLOAT_OP_CONFIG.)
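(Abridged from aten/src/ATen/TensorIterator.cpp and aten/src/ATen/native/TypeProperties.h/.cpp.)

```cpp
// Fold every input operand into a ResultTypeState, then reduce the
// state to a single common dtype.
ScalarType TensorIteratorBase::compute_common_dtype() {
  at::native::ResultTypeState state = {};
  for (const auto& op : operands_) {
    if (op.is_output) {
      continue;  // only inputs participate in dtype inference
    }
    state = at::native::update_result_type_state(op.tensor(), state);
  }
  common_dtype_ = at::native::result_type(state);
  TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
  return common_dtype_;
}

// One result dtype is tracked per kind of operand.
struct ResultTypeState {
  ScalarType dimResult = ScalarType::Undefined;      // dimensioned tensors
  ScalarType wrappedResult = ScalarType::Undefined;  // wrapped scalars
  ScalarType zeroResult = ScalarType::Undefined;     // plain 0-dim tensors
};

// Route each tensor's dtype into one of the three slots.
ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
  ResultTypeState new_state = in_state;
  ScalarType current = tensor.scalar_type();
  // (A wrapped float scalar is first mapped to the default dtype; elided.)
  if (tensor.dim() > 0) {
    new_state.dimResult = promote_skip_undefined(in_state.dimResult, current);
  } else if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    new_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, current);
  } else {
    new_state.zeroResult = promote_skip_undefined(in_state.zeroResult, current);
  }
  return new_state;
}
```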
For each input tensor, PyTorch invokes update_result_type_state to update the ResultTypeState. This state includes three result dtypes: dimResult (for dimensioned tensors), zeroResult (for 0-dim tensors that are not wrapped), and wrappedResult (for wrapped tensors).
The at::native::result_type function is then called to infer the common_dtype_:
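(Abridged from aten/src/ATen/native/TypeProperties.cpp.)

```cpp
// Combine two result dtypes, giving precedence to the higher-category one.
static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
  if (isComplexType(higher)) {
    return higher;
  } else if (!isComplexType(lower) && isFloatingType(higher)) {
    return higher;
  }
  if (higher == ScalarType::Bool || isFloatingType(lower) || isComplexType(lower)) {
    return promote_skip_undefined(higher, lower);
  }
  if (higher != ScalarType::Undefined) {
    return higher;
  }
  return lower;
}

ScalarType result_type(const ResultTypeState& in_state) {
  return combine_categories(in_state.dimResult,
      combine_categories(in_state.zeroResult, in_state.wrappedResult));
}
```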
In most cases, the precedence order of the three result dtypes is: dimResult > zeroResult > wrappedResult.
If the higher result dtype is a bool or the lower result dtype is a FloatingType, the dtype promotion function promote_skip_undefined is invoked:
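(Abridged from aten/src/ATen/native/TypeProperties.cpp and c10/core/ScalarType.cpp; the full table is elided.)

```cpp
// Undefined entries are skipped; otherwise defer to the promotion table.
static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
  if (a == ScalarType::Undefined) {
    return b;
  }
  if (b == ScalarType::Undefined) {
    return a;
  }
  return promoteTypes(a, b);
}

// promoteTypes is backed by a 2D lookup table indexed by the two dtypes,
// so promotion is a single array access rather than branching logic.
ScalarType promoteTypes(ScalarType a, ScalarType b) {
  // ... (special cases and the static table itself elided;
  //      e.g. the entry for (int8, uint8) is int16)
  return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
```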
In promote_skip_undefined, PyTorch employs a lookup table (promoteTypes) to execute dtype promotion efficiently.
3. Review the Answers
Having delved into the dtype promotion mechanism of PyTorch, let’s revisit and answer the questions posed earlier in the article.
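Using the rules and implementation details above, the answers for the representative cases from the introduction (assuming a default dtype of torch.float32):

```python
import torch

int_tensor = torch.ones(3, dtype=torch.int32)
long_zero_dim = torch.tensor(1, dtype=torch.int64)

# A wrapped int scalar is in the same category as the tensor, so
# wrappedResult loses to dimResult: the result stays int32.
print((int_tensor + 5).dtype)              # torch.int32

# A float scalar is of a higher category, so the result is promoted
# to the default float dtype.
print((int_tensor + 5.0).dtype)            # torch.float32

# A 0-dim tensor of the same category (zeroResult) also loses to
# dimResult: the result stays int32.
print((int_tensor + long_zero_dim).dtype)  # torch.int32

# div is configured as a "binary float op", which promotes integer
# inputs to float.
print((int_tensor / 2).dtype)              # torch.float32
```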