이전 포스팅에서 horovod class 중 하나인 _DistributedOptimizer()
를 실행할 때, 추가적으로 _register_hook()
을 실행한다는 것을 파악했다. 본 포스팅에서는 위 내용을 다룰 예정이다.
def _register_hooks()
hook()
을 등록한다는게 조금 신기하여 추가적으로 분석을 진행했다.
def _register_hooks(self):
if self._groups is not None:
p_list = []
# Get list of parameters with grads
for param_group in self.param_groups:
for p in param_group['params']:
if p.requires_grad:
p_list.append(p)
# To ensure parameter order and group formation is consistent, broadcast p_list order
# from rank 0 and use for every worker
p_list_names = [self._parameter_names.get(p) for p in p_list]
p_list_names = broadcast_object(p_list_names, root_rank=0, process_set=self.process_set)
p_list = sorted(p_list, key=lambda p: p_list_names.index(self._parameter_names.get(p)))
# Form groups
if isinstance(self._groups, list):
p_groups = []
grouped_id = set()
p_list_ids = [id(p) for p in p_list]
for group in self._groups:
p_groups.append([p for p in group if id(p) in p_list_ids])
for p in p_groups[-1]:
grouped_id.add(id(p))
for p in p_list:
if id(p) not in grouped_id:
p_groups.append([p])
else:
p_groups = split_list(p_list, self._groups)
p_groups = [tuple(p) for p in p_groups]
for group in p_groups:
for p in group:
self._p_to_group[p] = group
self._group_counts[group] = 0
for param_group in self.param_groups:
for p in param_group['params']:
if p.requires_grad:
self._requires_update.add(p)
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_hook(p))
self._grad_accs.append(grad_acc)
위 과정을 쉽게 요약하면 다음과 같다.
p_list
를 초기화p_list
에 appendp_list_names
에 p_list
에 존재하는 파라미터의 이름을 넣는다.p_list_names
에 broadcast_object
함수의 리턴값을 넣는다.p_list
를 sorting 한 값을 p_list
에 넣는다.다음과 같은 궁금증이 생겼다.
self._requires_update
은 어떤 역할을 할까expand_as
는 어떤 함수일까grad_fn
은 무엇일까grad_acc
은 무엇일까self.param_groups
에서 굳이 'param' 으로 parameter 값을 저장하는 이유가 무엇일까self._requires_update
self._requires_update
는 paramter update를 필요로 하는 모델의 파라미터의 집합이다. 처음에 __init__()
에서 set()
으로 초기화 된다. 사용되는 범위는 (1) gradient update 를 요구로 하는 paramter의 저장, (2) allreduce가 complete 된 parameter와 비교를 통해 모든 paramter 가 sync가 완료되었는지 여부 판단이다.
set()
set()
함수는 여러 item 들을 {}
형태로 저장하는 method 이다. list
와 다르게 중괄호를 사용하며, dict
과는 다르게 key:value
형태가 아닌 key
값만 가지고 있다.
set() vs. list()
위 둘의 차이점은 []
를 사용하느냐, {}
를 사용하느냐 이다. {}
를 사용하면 데이터가 unordered
형태로 저장된다.
set() vs. dict()
위 둘의 차이점은 값을 key:value
형태로 저장하느냐 key
형태로 저장하느냐의 차이이다. unordered
데이터라는 것은 동일하나, set()
은 unindexed
라는 특성을 추가적으로 가지고 있다.
self.param_groups
self.param_groups
는 model의 파라미터 값을 저장하는 변수이다. Pytorch에서 파라미터는 모두 torch.Tensor()
라는 클래스로 표현된다. input 값은 물론, gradient 값 또한 동일한 클래스로 표현된다. 위에 표현된 전체함수에서 눈여겨 볼 부분은 다음의 코드이다. 이 코드는 torch.Tensor()
가 파라미터 값도 가지고 있지만 gradient 값도 함께 가지고 있다는 것을 의미한다.
grad_acc = p_tmp.grad_fn.next_functions[0][0]
analysis of torch.Tensor()
더 많은 분석을 위해 다음의 소스코드를 짜봤다.
import torch
data_cuda = torch.rand(4,4).cuda()
grad_cuda = torch.rand(4,4, requires_grad=True).cuda()
param_cuda = torch.nn.Linear(4,4, bias=False).cuda()
param_named = param_cuda.parameters()
print(data_cuda)
print(grad_cuda)
print(param_cuda)
for p in param_named:
print(p)
tensor([[0.4663, 0.5613, 0.8164, 0.6831],
[0.0765, 0.3705, 0.8062, 0.9530],
[0.8068, 0.9741, 0.6489, 0.7569],
[0.8212, 0.0679, 0.0354, 0.4271]], device='cuda:0')
tensor([[0.1298, 0.5248, 0.6926, 0.7999],
[0.7081, 0.7787, 0.6747, 0.7832],
[0.7012, 0.6970, 0.7179, 0.7022],
[0.2410, 0.5589, 0.0965, 0.2175]], device='cuda:0',
grad_fn=<ToCopyBackward0>)
Linear(in_features=4, out_features=4, bias=False)
Parameter containing:
tensor([[-0.3110, 0.1802, -0.0075, -0.1089],
[ 0.0464, -0.4302, 0.2586, -0.2815],
[-0.0306, -0.3935, -0.0750, -0.4062],
[-0.3453, 0.4289, 0.1087, 0.1371]], device='cuda:0',
requires_grad=True)
의외의 결과를 확인했다. grad_cuda
와 param_cuda
는 이론상 동일해야하는데, 다음의 두 가지 차이점이 있었다.
print()
된 결과가 다름 - Parameter containing 이 추가grad_fn
이라는 용어가 포함됨첫번째 사항 에 대한 이유는 nn.linear 를 구성하는 nn.Parameter(torch.Tensor)
가 다음과 같은 표현을 가지고 있다.
def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
두번째 사항 의 원인은 cuda()
때문인 것 같다. cuda()
를 사용하여 gpu 를 활용할 경우, nn.Parameter 기반의 클래스와 다르게 requires_grad
대신 grad_fn
를 출력한다. 왜 그런것인지 추가 분석을 해보겠다.
data = torch.rand(4, 4, requires_grad=True)
data_cuda = torch.rand(4, 4, requires_grad=True).cuda()
print(data)
print(data_cuda)
tensor([[-0.0047, -1.2327, 0.0741, 0.1372],
[-1.0705, 0.7324, 0.5281, -0.0609],
[-0.9813, -0.9194, -0.7792, 1.2773],
[ 0.1200, 0.9547, -1.9466, 0.4951]], requires_grad=True)
tensor([[-0.1016, -0.5323, 2.1495, 0.3639],
[-1.8335, 0.8463, -1.3403, 0.2063],
[ 0.0133, -0.9961, 0.4323, -0.1578],
[-1.1304, -0.7580, 1.3594, -0.1112]], device='cuda:0',
grad_fn=<ToCopyBackward0>)
torch/_tensor.py
위 디렉토리에 torch.Tensor 클래스가 정의되어 있다. 나의 첫 목적은, 클래스를 print 할 때 나타나는 값인 __repr__()
을 찾는 것이다.
class Tensor(torch._C._TensorBase):
...
def __repr__(self, *, tensor_contents=None):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__repr__, (self,), self,
tensor_contents=tensor_contents)
# All strings are unicode in Python 3.
return torch._tensor_str._str(self, tensor_contents=tensor_contents)
여기서 리턴하는 torch._tensor_str._str
이 뭔지 알아보겠다.
torch/_tensor_str.py
위 디렉토리에서 소스코드를 찾을 수 있었다. self
값과 tensor_contents
를 파라미터로 전달한다. 위 함수는 추가적으로 _str_intern(self, tensor_contents=tensor_contents)
함수를 호출하는데, 아주 함수가 복잡하다. 그리고 추가적으로 다음의 궁금증이 생겼다.
with torch.no_grad()
를 호출하는데 이유가 무엇일까.def _str(self, *, tensor_contents=None):
with torch.no_grad():
return _str_intern(self, tensor_contents=tensor_contents)
_str_intern()
code block 1
prefix를 채우는 단계다. 만약 위 경우처럼 plain tensor 일 경우 'tensor(' 를 삽입한다.
def _str_intern(inp, *, tensor_contents=None):
is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
if inp.is_nested:
prefix = "nested_tensor("
elif is_plain_tensor:
prefix = 'tensor('
else:
prefix = f"{type(inp).__name__}("
indent = len(prefix)
suffixes = []
custom_contents_provided = tensor_contents is not None
if custom_contents_provided:
tensor_str = tensor_contents
code block 2
이해를 하지 못했으나, 본 목적하고는 달라 분석을 생략했다.
# This is used to extract the primal value and thus disable the forward AD
# within this function.
# TODO(albanD) This needs to be updated when more than one level is supported
self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
code block 3
신기한 점은 suffix 부터 값을 채워나간다는 점이고, print 가 텐서 연산을 요구하기 때문에, , xla/lazy 기반의 tensor가 compilation을 초래한다는 사실이다. 따라서 print 이전에 모든 텐서를 cpu에 복사하는 모습이다.
# Note [Print tensor device]:
# A general logic here is we only print device when it doesn't match
# the device specified in default tensor type.
# Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
# torch._C._get_default_device() only returns either cpu or cuda.
# In other cases, we don't have a way to set them as default yet,
# and we should always print out device for them.
if self.device.type != torch._C._get_default_device()\
or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index)\
or (self.device.type == 'mps'):
suffixes.append('device=\'' + str(self.device) + '\'')
# Tensor printing performs tensor operations like slice, indexing, etc to make it in a
# representable format. These operations on xla/lazy tensor results in compilations. Hence,
# to avoid compilations, copying the tensor to cpu before printing.
if self.device.type == 'xla' or self.device.type == 'lazy':
self = self.to('cpu')
code block 4
주어진 텐서가 다른 타입을 가지고 있는지 검사하고, 다른 타입을 가지고 있다면 (i.e., sparse) print를 다르게 진행하는 과정이다. 생략하겠다.
# TODO: add an API to map real -> complex dtypes
_default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool)
if self.is_sparse:
suffixes.append('size=' + str(tuple(self.shape)))
suffixes.append('nnz=' + str(self._nnz()))
if not has_default_dtype:
suffixes.append('dtype=' + str(self.dtype))
if not custom_contents_provided:
indices_prefix = 'indices=tensor('
indices = self._indices().detach()
indices_str = _tensor_str(indices, indent + len(indices_prefix))
if indices.numel() == 0:
indices_str += ', size=' + str(tuple(indices.shape))
values_prefix = 'values=tensor('
values = self._values().detach()
values_str = _tensor_str(values, indent + len(values_prefix))
if values.numel() == 0:
values_str += ', size=' + str(tuple(values.shape))
tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
elif self.layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
suffixes.append('size=' + str(tuple(self.shape)))
suffixes.append('nnz=' + str(self._nnz()))
if not has_default_dtype:
suffixes.append('dtype=' + str(self.dtype))
if not custom_contents_provided:
compressed_indices_method, plain_indices_method = {
torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
}[self.layout]
if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
cdimname, pdimname = 'row', 'column'
else:
cdimname, pdimname = 'column', 'row'
compressed_indices_prefix = f'c{cdimname[:3]}_indices=tensor('
compressed_indices = compressed_indices_method(self).detach()
compressed_indices_str = _tensor_str(compressed_indices, indent + len(compressed_indices_prefix))
if compressed_indices.numel() == 0:
compressed_indices_str += ', size=' + str(tuple(compressed_indices.shape))
plain_indices_prefix = f'{pdimname[:3]}_indices=tensor('
plain_indices = plain_indices_method(self).detach()
plain_indices_str = _tensor_str(plain_indices, indent + len(plain_indices_prefix))
if plain_indices.numel() == 0:
plain_indices_str += ', size=' + str(tuple(plain_indices.shape))
values_prefix = 'values=tensor('
values = self.values().detach()
values_str = _tensor_str(values, indent + len(values_prefix))
if values.numel() == 0:
values_str += ', size=' + str(tuple(values.shape))
tensor_str = compressed_indices_prefix + compressed_indices_str + '),\n' + ' ' * indent +\
plain_indices_prefix + plain_indices_str + '),\n' + ' ' * indent +\
values_prefix + values_str + ')'
elif self.is_quantized:
suffixes.append('size=' + str(tuple(self.shape)))
if not has_default_dtype:
suffixes.append('dtype=' + str(self.dtype))
suffixes.append('quantization_scheme=' + str(self.qscheme()))
if self.qscheme() == torch.per_tensor_affine or self.qscheme() == torch.per_tensor_symmetric:
suffixes.append('scale=' + str(self.q_scale()))
suffixes.append('zero_point=' + str(self.q_zero_point()))
elif self.qscheme() == torch.per_channel_affine or self.qscheme() == torch.per_channel_symmetric \
or self.qscheme() == torch.per_channel_affine_float_qparams:
suffixes.append('scale=' + str(self.q_per_channel_scales()))
suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
suffixes.append('axis=' + str(self.q_per_channel_axis()))
if not custom_contents_provided:
tensor_str = _tensor_str(self.dequantize(), indent)
elif self.is_nested:
if not custom_contents_provided:
def indented_str(s, indent):
return "\n".join(f" {line}" for line in s.split("\n"))
strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
tensor_str = f"[\n{strs}\n]"
else:
if self.is_meta:
suffixes.append('size=' + str(tuple(self.shape)))
if self.dtype != torch.get_default_dtype():
suffixes.append('dtype=' + str(self.dtype))
# TODO: This implies that ellipses is valid syntax for allocating
# a meta tensor, which it could be, but it isn't right now
if not custom_contents_provided:
tensor_str = '...'
else:
if self.numel() == 0 and not self.is_sparse:
# Explicitly print the shape if it is not (0,), to match NumPy behavior
if self.dim() != 1:
suffixes.append('size=' + str(tuple(self.shape)))
# In an empty tensor, there are no elements to infer if the dtype
# should be int64, so it must be shown explicitly.
if self.dtype != torch.get_default_dtype():
suffixes.append('dtype=' + str(self.dtype))
if not custom_contents_provided:
tensor_str = '[]'
else:
if not has_default_dtype:
suffixes.append('dtype=' + str(self.dtype))
if not custom_contents_provided:
if self.layout != torch.strided:
tensor_str = _tensor_str(self.to_dense(), indent)
else:
tensor_str = _tensor_str(self, indent)
if self.layout != torch.strided:
suffixes.append('layout=' + str(self.layout))
code block 5
궁금한 사항을 여기서 찾을 수 있었다. 만약 self.grad_fn 이 None 이 아니면, requires_grad 대신 grad_fn을 출력하도록 되어있다. 그렇다면 cuda()
가 어떻게 grad_fn
을 초기화 하는 것일까.
# Use inp here to get the original grad_fn and not the one generated by the forward grad
# unpacking.
if inp.grad_fn is not None:
name = type(inp.grad_fn).__name__
if name == 'CppFunction':
name = inp.grad_fn.name().rsplit('::', 1)[-1]
suffixes.append('grad_fn=<{}>'.format(name))
elif inp.requires_grad:
suffixes.append('requires_grad=True')
if self.has_names():
suffixes.append('names={}'.format(self.names))
if tangent is not None:
suffixes.append('tangent={}'.format(tangent))
string_repr = _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
# Check if this instance is flagged as a parameter and change the repr accordingly.
# Unfortunately, this function has to be aware of this detail.
# NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
# this should be done for those as well to produce a valid repr.
if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
string_repr = f"Parameter({string_repr})"
return string_repr
grad_fn vs. cuda()
앞서 cuda()
가 grad_fn
을 초기화 시키는 것을 코드로 확인했다. 그렇다면 grad_fn
은 무엇이고 왜 초기화 하는 것일까. 이 점에 대해 추가 분석을 진행하겠다. torch.Tensor()
는 다음과 같이 정의되어 있다.
class Tensor(torch._C._TensorBase):
이에 대한 super class 는 아래와 같이 정의되어있다.
torch/_C/__init__.pyi
여기서 pyi
확장자에서 i
는 인터페이스를 뜻한다. 정의는 다음과 같다.
확인해보면, 모든 변수에 self 가 존재하지 않고, 함수 또한 비워져있다.
class _TensorMeta(type):
pass
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorBase(metaclass=_TensorMeta):
requires_grad: _bool
shape: Size
data: Tensor
names: List[str]
device: _device
dtype: _dtype
layout: _layout
real: Tensor
imag: Tensor
T: Tensor
H: Tensor
mT: Tensor
mH: Tensor
ndim: _int
output_nr: _int
_version: _int
_base: Optional[Tensor]
_cdata: _int
grad_fn: Any
_grad_fn: Any
_grad: Optional[Tensor]
_backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
def __abs__(self) -> Tensor: ...
def __add__(self, other: Any) -> Tensor: ...
...
torch/csrc/autograd/python_variable.cpp
C++ 까지 갈 줄은 몰랐는데, 끝까지 가보겠다. https://github.com/pytorch/pytorch.git 에서 소스코드를 clone 한 뒤 분석을 진행했다. 분석을 진행하기 앞서, autograd 문서를 참조했다.
/csrc/autograd/README.md
Our general model is that for any key data type that autograd manipulates, there are two implementations: a C++ type and a Python object type. For example, consider variables in autograd: we have both Variable in variable.h (the C++ type) and THPVariable in python_variable.h (the Python type.) (By the way, THP stands for TorcH Python, not to be confused with THPP, TorcH C++). Variable contains the payload of a variable, while THPVariable just contains a shared_ptr reference to Variable, as well as references to other Python objects which the Python runtime needs to know about. A lot of data accessor implementations in python_variable.cpp simply reach through to the underlying Variable and return the appropriate value.
일부만 분석해보자면, Variable
은 C++ type 이며, THPVariable
은 TorcH Python 을 의미하는 python type, 그리고 THPP
는 TorcH C++ type이다.
(getter)THPVariable_get_grad_fn
여기서 grad_fn
이 (getter)THPVariable_get_grad_fn
으로 정의 되었음을 확인했다.
// properties are registered here because we are currently only able to bind
// them manually. TODO: make declarable in native_functions
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPVariable_properties[] = {
{"_python_dispatch",
(getter)THPVariable_get_python_dispatch,
nullptr,
nullptr,
nullptr},
{"T", (getter)PropertyT::getter, nullptr, nullptr, nullptr},
{"H", (getter)PropertyH::getter, nullptr, nullptr, nullptr},
{"mT", (getter)PropertymT::getter, nullptr, nullptr, nullptr},
{"mH", (getter)PropertymH::getter, nullptr, nullptr, nullptr},
{"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
{"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},
{"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
{"_grad_fn",
(getter)THPVariable_get_grad_fn,
(setter)THPVariable_set_grad_fn,
nullptr,
nullptr},
{"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
{"retains_grad",
(getter)THPVariable_retains_grad,
nullptr,
nullptr,
nullptr},
{"data",
(getter)PropertyData::getter,
(setter)THPVariable_set_data,
nullptr,
nullptr},
{"_grad",
(getter)PropertyGrad::getter,
(setter)THPVariable_set_grad,
nullptr,
nullptr}, // Allows the python class to override .grad
{"grad",
(getter)PropertyGrad::getter,
(setter)THPVariable_set_grad,
nullptr,
nullptr},
{"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr},
{"volatile",
(getter)THPVariable_get_volatile,
(setter)THPVariable_set_volatile,
nullptr,
nullptr},
{"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr},
{"requires_grad",
(getter)THPVariable_get_requires_grad,
(setter)THPVariable_set_requires_grad,
nullptr,
nullptr},
{"_backward_hooks",
(getter)THPVariable_get_backwards_hooks,
(setter)THPVariable_set_backwards_hooks,
nullptr,
nullptr},
{"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
{"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr},
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
{"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_sparse_csr",
(getter)THPVariable_is_sparse_csr,
nullptr,
nullptr,
nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr},
{"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized",
(getter)THPVariable_is_quantized,
nullptr,
nullptr,
nullptr},
{"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
{"is_nested", (getter)THPVariable_is_nested, nullptr, nullptr, nullptr},
{"_has_symbolic_sizes_strides",
(getter)THPVariable_has_symbolic_sizes_strides,
nullptr,
nullptr,
nullptr},
{"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
{"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
{"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
{"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
{"names",
(getter)THPVariable_get_names,
(setter)THPVariable_set_names,
nullptr,
nullptr},
{"real",
(getter)PropertyReal::getter,
(setter)THPVariable_set_real,
nullptr,
nullptr},
{"imag",
(getter)PropertyImag::getter,
(setter)THPVariable_set_imag,
nullptr,
nullptr},
{nullptr}};
THPVariable_get_grad_fn
여기서 check_has_torch_function()
이 참이면 grad_fn
을 설정하고, 만약 아니면 THPVariable_Unpack()
을 사용하여 None
을 return 하는 것 같다.
PyObject* THPVariable_get_grad_fn(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "grad_fn");
}
const auto& var = THPVariable_Unpack(self);
if (!var.grad_fn()) {
Py_RETURN_NONE;
}
return functionToPyObject(var.grad_fn());
END_HANDLE_TH_ERRORS
}
torch/csrc/utils/disable_torch_function.cpp
namespace torch {
auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool {
if (!ignore_mode && at::impl::torch_function_mode_enabled())
return true;
PyTypeObject* tp = Py_TYPE(obj);
return (
!THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) &&
torch::torch_function_enabled() && has_torch_function_attr(obj));
}
} // namespace torch
... 여기까지 알아보도록 하고 다음 포스팅에서 이어서 작성하겠다.