目录

Demystifying Dtype Promotion in PyTorch

本文深入探讨了 PyTorch 中的数据类型提升(dtype promotion)机制,包含 promotion 的基本规则、scalar 如何被 wrapped 成 tensor、TensorIterator 在计算数据类型时的作用等细节。

这篇文章使用O3-mini-high翻译,如有困惑请参考英文原文


我们从代码开始:

import torch

float_tensor = torch.ones(1, dtype=torch.float)
double_tensor = torch.ones(1, dtype=torch.double)
complex_float_tensor = torch.ones(1, dtype=torch.complex64)
complex_double_tensor = torch.ones(1, dtype=torch.complex128)
int_tensor = torch.ones(1, dtype=torch.int)
long_tensor = torch.ones(1, dtype=torch.long)
uint_tensor = torch.ones(1, dtype=torch.uint8)
double_tensor = torch.ones(1, dtype=torch.double)
bool_tensor = torch.ones(1, dtype=torch.bool)
long_zerodim = torch.tensor(1, dtype=torch.long)
int_zerodim = torch.tensor(1, dtype=torch.int)

>>> (int_tensor + 5).dtype
>>> (int_tensor + 5.5).dtype
>>> (int_tensor / 5).dtype
>>> (int_tensor + long_zerodim).dtype
>>> (long_tensor + int_tensor).dtype
>>> (bool_tensor + long_tensor).dtype
>>> (bool_tensor + uint_tensor).dtype
>>> (float_tensor + double_tensor).dtype
>>> (complex_float_tensor + complex_double_tensor).dtype
>>> (bool_tensor + int_tensor).dtype
>>> torch.add(long_tensor, float_tensor).dtype

你是否曾经好奇,为什么 PyTorch 中算术运算(比如 addsub 等)的输出张量会有不同的 dtype?本文将探讨这一主题,并在后续部分给出详细的答案。


在 PyTorch 中,当参与算术运算的输入张量具有不同的 dtype 时,会触发 dtype 提升(promotion)。提升规则主要基于以下几个准则:

  • 标量与张量:如果一个标量的 dtype 所处的类别比张量更高(注意:complex > floating > integral > boolean),则最终的 dtype 将被提升到足以容纳所有标量值的类型。

  • 零维张量:如果参与运算的 0 维张量(即标量张量)的类别高于其他有维度的张量,其 dtype 会被提升为能够存储 0 维张量的类型。

  • 多个有维度张量:如果没有更高类别的 0 维张量,则会提升为能够容纳所有有维度张量的 dtype。

  • 特殊情况:例如,对于 div(除法)操作,当整数张量除以整数标量时,结果会被提升为 float 类型。


接下来,我们深入 PyTorch 源码,看看它是如何实现 dtype 提升的。

考虑操作 int_tensor + 5,这里的 5 是一个常量标量。此时,标量 5 会被包装成一个 dtype 为 int64 的张量。
这种包装方式使得我们可以复用 add.Tensor 运算符,从而不必单独维护 add.Tensoradd.Scalar 两个版本。(需要注意的是,PyTorch 中的 add.Scalar 接口没有注册到 dispatcher 上,因此并未实际使用。)

下面展示了标量包装的具体实现:

// torch/csrc/autograd/generated/python_variable_methods.cpp
static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  const Tensor& self = THPVariable_Unpack(self_);
  static PythonArgParser parser({
    "add(Scalar alpha, Tensor other)|deprecated",
    "add(Tensor other, *, Scalar alpha=1)",
  }, /*traceable=*/true);

  ParsedArgs<2> parsed_args;
  auto _r = parser.parse(self_, args, kwargs, parsed_args);
  // ...
  switch (_r.idx) {
    case 0: {
      // [已废弃] aten::add(Tensor self, Scalar alpha, Tensor other) -> Tensor
      // ...
    }
    case 1: {
      // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
      auto dispatch_add = [](const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) -> at::Tensor {
        pybind11::gil_scoped_release no_gil;
        return self.add(other, alpha);
      };
      return wrap(dispatch_add(self, _r.tensor(0), _r.scalar(1)));
    }
  }
}

关键在于调用了 _r.tensor(0),此处将标量转换成了一个 0 维张量。

// torch/csrc/utils/python_arg_parser.h
inline at::Tensor PythonArgs::tensor(int i) {
  // ...
  return tensor_slow(i);
}

// torch/csrc/utils/python_arg_parser.cpp
at::Tensor PythonArgs::tensor_slow(int i) {
  PyObject* obj = args[i];
  if (!obj) {
    return at::Tensor();
  }
  // ...
  bool save_symint = false;
  at::Scalar scalar;
  if (PyBool_Check(obj)) {
    scalar = at::Scalar(THPUtils_unpackBool(obj));
  } else if (THPUtils_checkLong(obj)) {
    scalar = at::Scalar(THPUtils_unpackLong(obj));
  } else if (THPUtils_checkDouble(obj)) {
    scalar = at::Scalar(THPUtils_unpackDouble(obj));
  } // ... 其他 dtype 的处理 ...
  // ...
  at::Tensor tensor = scalar_to_tensor(scalar);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  // ...
  return tensor;
}

而将标量转换成张量的过程是通过 fill 完成的:

// torch/include/ATen/ScalarOps.h
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // 针对 CPU 标量张量的快速路径
  if (device == at::kCPU) {
    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  }
  // ...
}

// aten/src/ATen/ScalarOps.cpp
Tensor scalar_tensor_static(const Scalar& s, c10::optional<ScalarType> dtype_opt, c10::optional<Device> device_opt) {
  // ...
  Tensor result = at::detail::empty_cpu(
      {}, dtype_opt, c10::nullopt, device_opt, c10::nullopt, c10::nullopt);
  scalar_fill(result, s);
  return result;
}

在某些 C++ 函数(例如 at::native::add_(...))被调用时,Scalar 同样会被包装。

// aten/src/ATen/native/BinaryOps.cpp
Tensor& add_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.add_(wrapped_scalar_tensor(other), alpha);
}

在内核层面(例如 f(a, b) == f(b, a) 的情况),可以通过剥除包装张量和 CPU 标量张量,从而将其视为普通常量,以提升计算效率。示例如下:

// aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu
void bitwise_and_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_cuda", [&]() {
    BitwiseAndFunctor<scalar_t> f;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, f);
  });
}

template <typename scalar_t, typename return_t = scalar_t, typename func_t>
void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  // ...
  if (iter.is_cpu_scalar(1)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(1);
    iter.remove_operand(1);
    device_guard.reset_device(iter.device(1));
  } else if (iter.is_cpu_scalar(2)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(2);
    iter.remove_operand(2);
  }

  if (iter.ninputs() == 2) {
    gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
  } else {
    AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> unary_f(f, scalar_val);
    gpu_kernel(iter, unary_f);
  }
}

数据类型的计算主要发生在 TensorIterator 内部。如果你不熟悉 TensorIterator,建议先阅读我之前介绍它的文章 这里

在这篇文章中,我们将重点探讨 dtype 提升的实现过程。

// aten/src/ATen/TensorIterator.cpp
void TensorIteratorBase::build(TensorIteratorConfig& config) {
  // ...
  // 计算结果张量的 dtype 和设备
  compute_types(config);
  // ...
}

void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
  common_dtype_ = ScalarType::Undefined;
  ScalarType output_dtype = ScalarType::Undefined;
  bool has_different_input_dtypes = false;
  bool has_undefined_outputs = false;

  for (auto& op : operands_) {
    if (!op.is_type_defined()) {
      // ...
      if (config.static_dtype_.has_value()) {
        op.target_dtype = config.static_dtype_.value();
      } else {
        has_undefined_outputs = true;
      }
      // ...
    }
    // ... 

    if (!op.is_output) {
      // 判断输入张量是否存在不同的 dtype
      if (op.target_dtype != common_dtype_) {
        if (common_dtype_ == ScalarType::Undefined) {
          common_dtype_ = op.target_dtype;
        } else {
          has_different_input_dtypes = true;
        }
      }
    } else {
      // 判断输出张量是否存在不同的 dtype
      if (op.target_dtype != output_dtype) {
        if (output_dtype == ScalarType::Undefined) {
          output_dtype = op.target_dtype;
        }
        // ...
      }
    }
  }

  // ...

  if (!has_undefined_outputs && !config.check_all_same_device_ &&
      !config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
      !config.enforce_safe_casting_to_output_) {
    // 如果无法推断出 common_dtype_ 则置为 Undefined
    common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
    return;
  }

  // 如果需要,计算一个公共 dtype
  if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) {
    common_dtype_ = compute_common_dtype();
  }

  // 对于类似 `div` 操作,将整数输入提升为默认的 float 类型
  if (config.promote_integer_inputs_to_float_ &&
      c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
    common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
  }

  // ...
  for (auto& op : operands_) {
    bool is_type_defined = op.is_type_defined();
    if (!is_type_defined) {
      op.target_dtype = common_dtype_;
    }
    // ...
  }
}

compute_types 中,PyTorch 根据输入张量及配置参数(例如 promote_inputs_to_common_dtype_)计算出 common_dtype_,并将该结果存储到 op.target_dtype 中,后续在 allocate_or_resize_outputs 中使用。

为了更深入理解 dtype 提升的实现,我们来看看 compute_common_dtype 的具体逻辑:

// aten/src/ATen/TensorIterator.cpp
ScalarType TensorIteratorBase::compute_common_dtype() {
  at::native::ResultTypeState state = {};
  for (const auto& op : operands_) {
    if (op.is_output) {
      continue;
    }
    state = at::native::update_result_type_state(op.tensor(), state);
  }
  common_dtype_ = at::native::result_type(state);
  TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
  return common_dtype_;
}

在更新每个张量的结果状态时(ResultTypeState),PyTorch 会区分三种情况:

  • dimResult:用于普通(有维度)张量;
  • zeroResult:用于未包装的 0 维张量;
  • wrappedResult:用于包装后的 0 维张量。

at::native::result_type 函数会根据这三种结果状态,推导出最终的 common_dtype_

// aten/src/ATen/native/TypeProperties.cpp
ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
  if (!tensor.defined()) {
    return in_state;
  }
  ResultTypeState new_state = in_state;
  ScalarType current = tensor.scalar_type();
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    // 对于包装张量,使用默认的 complex/float 类型
    if(isComplexType(current)) {
      // 默认:complex<float>
      current = typeMetaToScalarType(at::get_default_complex_dtype());
    }
    else if(isFloatingType(current)) {
      // 默认:float
      current = typeMetaToScalarType(at::get_default_dtype());
    }
  }
  if ( tensor.dim() > 0 ) {
    // 普通张量
    new_state.dimResult = promote_skip_undefined(in_state.dimResult, current);
  } else if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    // 包装张量(标量)
    new_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, current);
  } else {
    // 非包装的 0 维张量
    new_state.zeroResult = promote_skip_undefined(in_state.zeroResult, current);
  }
  return new_state;
}

// torch/include/ATen/native/TypeProperties.h
struct ResultTypeState {
  c10::ScalarType dimResult = ScalarType::Undefined;
  c10::ScalarType wrappedResult = ScalarType::Undefined;
  c10::ScalarType zeroResult = ScalarType::Undefined;
};

接下来,通过调用 at::native::result_type 得到最终的 common_dtype_

// aten/src/ATen/native/TypeProperties.cpp
ScalarType result_type(const ResultTypeState& in_state) {
  return combine_categories(in_state.dimResult, combine_categories(in_state.zeroResult, in_state.wrappedResult));
}

static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
  if(isComplexType(higher)) {
    return higher;
  } else if (isComplexType(lower)) {
    // 如果 higher 为浮点类型,则保持其数值类型
    if (isFloatingType(higher)) {
      return toComplexType(higher);
    }
    // 如果输入为整数,则 lower(complex 类型)优先
    return lower;
  } else if (isFloatingType(higher)) {
    return higher;
  }
  if (higher == ScalarType::Bool || isFloatingType(lower)) {
    return promote_skip_undefined(higher, lower);
  }
  if (higher != ScalarType::Undefined) {
    return higher;
  }
  return lower;
}

在大多数情况下,三种结果状态的优先级为:dimResult > zeroResult > wrappedResult
当更高的结果 dtype 为 bool 或者更低的结果 dtype 为浮点类型时,将调用 promote_skip_undefined 进行 dtype 提升:

// aten/src/ATen/native/TypeProperties.cpp
static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
  if (a == ScalarType::Undefined) {
    return b;
  }
  if (b == ScalarType::Undefined) {
    return a;
  }
  return promoteTypes(a, b);
}

// c10/core/ScalarType.cpp
constexpr auto u1 = ScalarType::Byte;
constexpr auto i1 = ScalarType::Char;
constexpr auto i2 = ScalarType::Short;
constexpr auto i4 = ScalarType::Int;
constexpr auto i8 = ScalarType::Long;
constexpr auto f2 = ScalarType::Half;
constexpr auto f4 = ScalarType::Float;
constexpr auto f8 = ScalarType::Double;
constexpr auto c2 = ScalarType::ComplexHalf;
constexpr auto c4 = ScalarType::ComplexFloat;
constexpr auto c8 = ScalarType::ComplexDouble;
constexpr auto b1 = ScalarType::Bool;
constexpr auto bf = ScalarType::BFloat16;
constexpr auto ud = ScalarType::Undefined;

ScalarType promoteTypes(ScalarType a, ScalarType b) {
  // 此处依据 NumPy 的 promote_types 生成
  if (a == ud || b == ud) {
    return ScalarType::Undefined;
  }
  if (a == b) {
    return a;
  }

  // ...

  auto ix_a = dtype2index[static_cast<int64_t>(a)];
  auto ix_b = dtype2index[static_cast<int64_t>(b)];

  // 该查找表与 index2dtype 保持一致
  static constexpr std::array<std::array<ScalarType, index2dtype.size()>, index2dtype.size()>
      _promoteTypesLookup = {{
      /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  bf*/
      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf},
      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf},
      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf},
      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf},
      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf},
      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4},
      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4},
      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8},
      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4},
      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4},
      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf},
      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf},
  }};
  // clang-format on
  return _promoteTypesLookup[ix_a][ix_b];
}

promote_skip_undefined 中,PyTorch 借助查找表来高效完成 dtype 的提升。


在详细探讨了 PyTorch 的 dtype 提升机制后,我们回顾一下最初提出的各个例子,并给出答案:

>>> (int_tensor + 5).dtype
# 5 被包装成 int64 张量,但其优先级不高于有维度张量,因此结果为 int32
torch.int32

>>> (int_tensor + 5.5).dtype
# 5.5 被包装为 double 张量,经过 `update_result_type_state` 获取默认的 float 类型,
# 最后提升为 float
torch.float32

>>> (int_tensor / 5).dtype
# 5 被包装为 long 张量,初步计算得到 int 类型,但由于 `div` 操作将整数提升为 float,
# 最终输出为 float
torch.float32

>>> (int_tensor + long_zerodim).dtype
# 0 维张量的优先级低于 int_tensor,因此无 dtype 提升
torch.int32

>>> (long_tensor + int_tensor).dtype
# 提升为 long 类型
torch.int64

>>> (bool_tensor + long_tensor).dtype
# 提升为 long 类型
torch.int64

>>> (bool_tensor + uint_tensor).dtype
# 提升为 uint8 类型
torch.uint8

>>> (float_tensor + double_tensor).dtype
# 提升为 double 类型
torch.float64

>>> (complex_float_tensor + complex_double_tensor).dtype
# 提升为 complex128 类型
torch.complex128

>>> (bool_tensor + int_tensor).dtype
# 提升为 int 类型
torch.int32

>>> torch.add(long_tensor, float_tensor).dtype
# 提升为 float 类型
torch.float32