
Normalization: Source Code Analysis and Reproduction

Preface

Using 2D image data as the running example, this article analyzes the PyTorch source code of BN (Batch Normalization), LN (Layer Normalization), IN (Instance Normalization), and WN (Weight Normalization), in the hope that walking through the source will deepen our understanding of these methods.

I will first briefly go through the Python implementations of these methods in PyTorch, then take a look at the underlying C++ implementations, and finally reproduce PyTorch's results based on what we have learned.

The following figure, taken from the GN (Group Normalization) paper, illustrates how these methods differ:

[Figure: all_norms — comparison of normalization methods, from the GN paper]

BN

BN is one of the earliest Normalization methods and is still widely used in today's networks. It can improve training efficiency and reduce training instability, but it is not well suited to generative tasks, and the batch size has a large impact on how well it works. Its formula is \[ y = \dfrac{x - \mathbf{E}[x]}{\sqrt{\mathbf{Var}[x] + \epsilon}} * \gamma + \beta \] i.e. the input tensor is normalized and then passed through an affine transformation.

The BN Class

In PyTorch, nn.BatchNorm2d inherits from the base class _BatchNorm (which in turn derives from _NormBase); the subclass only adds a check on the input dimensionality, so we go straight to the implementation of the BN base class:

class _BatchNorm(_NormBase):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )

Setting aside for a moment the exponential moving average controlled by momentum, the key takeaway is this: during training, or when track_running_stats=False, the mean and variance of the current input are used; otherwise self.running_mean and self.running_var are used. A few things to note about BN:

  1. The core Normalization operation is implemented by an underlying C++ function
  2. eps is added to guard against division by zero when the variance is very small
  3. The affine transformation and the momentum are optional
  4. According to the official documentation, the running statistics are updated as \(x_{new} = (1 - momentum) \times \hat{x} + momentum \times x_t\), where \(\hat{x}\) is the estimated (running) statistic and \(x_t\) is the statistic of the current batch (see the sketch after this list)
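
To make the update rule concrete, here is a minimal sketch of my own (not PyTorch source). It assumes the documented behaviour that the running variance is updated with the unbiased batch variance and that the buffers start at 0 and 1; when momentum=None, the factor becomes 1 / num_batches_tracked (a cumulative average) instead.

import torch
import torch.nn as nn

x_demo = torch.randn(32, 3, 8, 8)
bn_demo = nn.BatchNorm2d(3, momentum=0.1)
bn_demo.train()
bn_demo(x_demo)  # one training-mode forward pass updates the running buffers

batch_mean = x_demo.mean(dim=(0, 2, 3))
batch_var = x_demo.var(dim=(0, 2, 3), unbiased=True)  # assumed: unbiased variance for the buffer update
expected_mean = (1 - bn_demo.momentum) * 0.0 + bn_demo.momentum * batch_mean  # buffers start at 0 / 1
expected_var = (1 - bn_demo.momentum) * 1.0 + bn_demo.momentum * batch_var

print(torch.allclose(bn_demo.running_mean, expected_mean, atol=1e-6),
      torch.allclose(bn_demo.running_var, expected_var, atol=1e-6))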

Next, let's look at the _NormBase base class:

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

From this we learn the shapes of the various parameters:

  • According to the official documentation, num_features should equal \(C\) in the \((N, C, H, W)\) input
  • The affine \(weight\) and \(bias\) both have size \(C\) (the number of channels) and are initialized to \(1\) and \(0\) respectively.
  • running_mean and running_var also have size \(C\) and are initialized to \(0\) and \(1\) respectively.

With this we have a basic picture of how Batch Norm is put together, but the most crucial part is still unexplained, and it lives in an underlying C++ function. Let's keep going:

def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, training: bool = False, momentum: float = 0.1, eps: float = 1e-5) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(batch_norm, (input,), input, running_mean, running_var, weight=weight, bias=bias, training=training, momentum=momentum, eps=eps)
    if training:
        _verify_batch_size(input.size())

    return torch.batch_norm(
        input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
    )

Source Code Analysis

The most basic C++ implementation can be found in Pytorch - Github.

Honestly, this part of the source left me a bit lost; after digging around I could not pin down the key implementation, so let's go straight to the reproduction instead. The source does not have to be understood in exhaustive detail anyway.

Reproduction

We write the following reproduction code to check whether the result matches what we expect:

import torch
import torch.nn as nn

x = torch.randn(32, 3, 224, 224).cuda()
print("BatchNorm:")
bn = nn.BatchNorm2d(3).cuda()
print(bn.running_mean, bn.running_var, bn.weight, bn.bias)
xbn = torch.batch_norm(x, bn.weight, bn.bias, bn.running_mean, bn.running_var, training=False, momentum=0., eps=bn.eps, cudnn_enabled=False)
ybn = (x - bn.running_mean.reshape(1, -1, 1, 1)) / torch.sqrt(bn.running_var + bn.eps).reshape(1, -1, 1, 1) * bn.weight.reshape(1, -1, 1, 1) + bn.bias.reshape(1, -1, 1, 1)
print(torch.allclose(xbn, ybn), torch.mean(xbn - ybn), xbn.mean(), ybn.mean(), '\n')

The printed result is True, showing that our understanding is correct; there is only a tiny numerical difference between the manual computation and the output of the BN routine, which is to be expected.

BatchNorm:
tensor([0., 0., 0.], device='cuda:0') tensor([1., 1., 1.], device='cuda:0') Parameter containing:
tensor([1., 1., 1.], device='cuda:0', requires_grad=True) Parameter containing:
tensor([0., 0., 0.], device='cuda:0', requires_grad=True)
True tensor(2.1325e-09, device='cuda:0', grad_fn=<MeanBackward0>) tensor(-0.0399, device='cuda:0', grad_fn=<MeanBackward0>) tensor(-0.0399, device='cuda:0', grad_fn=<MeanBackward0>)
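
The check above exercises the inference path with the (still freshly initialized) running statistics. As a complement, here is a sketch of my own for the training path, where the batch statistics themselves are used; it relies on the documented detail that normalization uses the biased variance:

bn.train()
xbn_train = bn(x)  # training mode: normalize with the statistics of this batch
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance for normalization
ybn_train = (x - mu) / torch.sqrt(var + bn.eps) * bn.weight.reshape(1, -1, 1, 1) + bn.bias.reshape(1, -1, 1, 1)
print(torch.allclose(xbn_train, ybn_train, atol=1e-6))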

LN

The LN Class

The Layer Norm implementation does not involve running_mean or running_var, so it is relatively simple.

class LayerNorm(Module):
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

From the definition above we see that weight and bias both have the same size as normalized_shape, which for an \((N, C, H, W)\) input should be \((C, H, W)\), and that they are initialized to \(1\) and \(0\). The forward pass again calls an underlying C++ function.

def layer_norm(input: Tensor, normalized_shape: List[int], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps
        )
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
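
One detail worth stressing: torch.layer_norm computes the statistics over the last len(normalized_shape) dimensions of the input. Passing (C, H, W) therefore gives one mean/variance per sample, while passing only the last dimension (the common choice in Transformers) gives one mean/variance per position. A small sketch of my own to illustrate:

x2 = torch.randn(32, 3, 224, 224)
ln_full = nn.LayerNorm(x2.shape[1:])  # stats over (C, H, W): one mean/var per sample
ln_last = nn.LayerNorm(x2.shape[-1])  # stats over the last dim only: one mean/var per (n, c, h)
print(ln_full.weight.shape, ln_last.weight.shape)  # torch.Size([3, 224, 224]) torch.Size([224])
y_last = (x2 - x2.mean(-1, keepdim=True)) / torch.sqrt(x2.var(-1, unbiased=False, keepdim=True) + ln_last.eps)
print(torch.allclose(ln_last(x2), y_last * ln_last.weight + ln_last.bias, atol=1e-6))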

Reproduction

print("LayerNorm:")
ln = nn.LayerNorm(x.size()[1:]).cuda()
xln = torch.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps, False)
print(ln.normalized_shape, ln.weight.size(), ln.bias.size(), ln.eps)
yln = (x - x.mean(dim=(1, 2, 3), keepdim=True,)) / torch.sqrt(x.var(dim=(1, 2, 3), unbiased=False, keepdim=True) + ln.eps) * ln.weight.unsqueeze(0) + ln.bias.unsqueeze(0)
print(torch.allclose(xln, yln), torch.mean(xln - yln), xln.mean(), yln.mean(), '\n')

The output matches our expectation:

LayerNorm:
(3, 224, 224) torch.Size([3, 224, 224]) torch.Size([3, 224, 224]) 1e-05
True tensor(1.3846e-10, device='cuda:0', grad_fn=<MeanBackward0>) tensor(-1.3178e-09, device='cuda:0', grad_fn=<MeanBackward0>) tensor(-1.3938e-09, device='cuda:0', grad_fn=<MeanBackward0>)

IN

The IN Class

The IN base class is shown below:

class _InstanceNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_InstanceNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super(_InstanceNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        return F.instance_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats, self.momentum, self.eps)

Its num_features argument, like BN's, is the \(C\) in \((N,C,H,W)\). Note that the IN base class, like the BN base class, inherits from _NormBase, and IN's interface is essentially the same as BN's. There is a reason for this: the forward pass again calls an underlying C++ function, and that function turns out to be quite simple.
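
One interface difference is worth noting, and it is visible in the signature above: _InstanceNorm defaults to affine=False and track_running_stats=False, whereas _BatchNorm defaults both to True. A quick check (my own snippet):

bn_tmp, in_tmp = nn.BatchNorm2d(3), nn.InstanceNorm2d(3)
print(bn_tmp.weight is None, bn_tmp.running_mean is None)  # False False
print(in_tmp.weight is None, in_tmp.running_mean is None)  # True True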

Source Code Analysis

The key part is the call to at::batch_norm, i.e. IN is computed through BN's interface function:

Tensor instance_norm(
    const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
    bool use_input_stats, double momentum, double eps, bool cudnn_enabled) {
  TORCH_CHECK(use_input_stats || (running_mean.defined() && running_var.defined()),
              "Expected running_mean and running_var to be defined when use_input_stats is false");
  std::vector<int64_t> shape = input.sizes().vec();
  int64_t b = input.size(0);
  int64_t c = input.size(1);
  shape[1] = b * c;
  shape[0] = 1;

  Tensor weight_ = repeat_if_defined(weight, b);
  Tensor bias_ = repeat_if_defined(bias, b);
  Tensor running_mean_ = repeat_if_defined(running_mean, b);
  Tensor running_var_ = repeat_if_defined(running_var, b);

  auto input_reshaped = input.contiguous().view(shape);
  auto out = at::batch_norm(input_reshaped, weight_, bias_, running_mean_, running_var_,
                            use_input_stats, momentum, eps, cudnn_enabled);

  // we alias running_mean and running_var because they are const but we want to modify their data
  if (running_mean.defined()) {
    at::alias(running_mean).copy_(running_mean_.view({ b, c }).mean(0, false));
  }
  if (running_var.defined()) {
    at::alias(running_var).copy_(running_var_.view({ b, c }).mean(0, false));
  }

  return out.view(input.sizes());
}

In short, the \((N,C,H,W)\) input is reshaped to \((1, N \times C, H, W)\) and then fed to the BN interface for the actual computation. We verify this below.

Reproduction

print("InstanceNorm:")
In = nn.InstanceNorm2d(3).cuda()
print(In.running_mean, In.running_var, In.weight, In.bias)
xIn = torch.instance_norm(x, In.weight, In.bias, In.running_mean, In.running_var, True, 0, In.eps, False)
yIn = torch.batch_norm(x.reshape(1, -1, *x.size()[2:]), In.weight, In.bias, In.running_mean, In.running_var, True, 0, In.eps, False).reshape_as(x)
print(torch.allclose(xIn, yIn), torch.mean(xIn - yIn), xIn.mean(), yIn.mean(), '\n')

Because the reproduction goes through batch_norm, which calls the same C++ interface, the results are exactly identical:

InstanceNorm:
None None None None
True tensor(0., device='cuda:0') tensor(2.2808e-10, device='cuda:0') tensor(2.2808e-10, device='cuda:0')
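
We can also bypass batch_norm entirely: for IN the statistics are simply taken per sample and per channel, i.e. over the H and W axes. A sketch of my own (the default InstanceNorm2d has no affine parameters, so no weight/bias term is needed):

yIn_manual = (x - x.mean(dim=(2, 3), keepdim=True)) / torch.sqrt(x.var(dim=(2, 3), unbiased=False, keepdim=True) + In.eps)
print(torch.allclose(xIn, yIn_manual, atol=1e-6))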

WN

The WN Class

The official implementation of Weight Normalization in PyTorch is torch.nn.utils.weight_norm(module, name, dim).

For now, here is the source; a more detailed analysis will be added later.

class WeightNorm(object):
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))


T_module = TypeVar('T_module', bound=Module)

def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
        \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module


def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))

Source Code Analysis

For now, here is the source from https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten/src/ATen/native/WeightNorm.cpp; more commentary will be added later.

Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim)
{
  // I assume tensor.contiguous(), view(), norm(), etc. here will dispatch through VariableType.
  if (dim == -1) {
    return v.norm(pow);
  } else if (dim == 0) {
    std::vector<int64_t> output_size(v.dim(), 1);
    output_size[0] = v.size(0);
    return v.contiguous().view({v.size(0), -1}).norm(pow, 1).view(output_size);
  } else if (dim == v.dim() - 1) {
    std::vector<int64_t> output_size(v.dim(), 1);
    output_size[v.dim() - 1] = v.size(v.dim() - 1);
    return v.contiguous().view({-1, v.size(v.dim() - 1)}).norm(pow, 0).view(output_size);
  } else {
    // To consider: at::native::norm_except_dim is probably fine as well,
    // and would avoid an additional dynamic dispatch.
    return at::norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim); // optimize?
  }
}

Tensor _weight_norm
  (const Tensor & v_in,
   const Tensor & g_in,
   int64_t dim)
{

  TORCH_CHECK(
    v_in.device() == g_in.device(),
    "weight_norm: expected v_in and g_in to be on the same device, but v_in is "
    "on ", v_in.device(), " and g_in is on ", g_in.device());

  auto v = v_in.contiguous();
  auto g = g_in.contiguous();

  bool can_use_fused = v.is_cuda() && (dim == 0 || dim == v.dim() - 1);

  if (can_use_fused) {
    // weight_norm does not have a derivative defined for it, so this will route back through
    // VariableType.cpp, and construct a WeightNormFusedBackward object in the autograd graph.
    return std::get<0>(at::_weight_norm_cuda_interface(v, g, dim));
  } else {
    // Double-differentiable primitive ops
    // at::native::norm_except_dim would probably be fine as well.
    return v*(g/at::norm_except_dim(v, 2, dim));
  }
}

// Differentiable backward path, an alternative to weight_norm_cuda_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
  (const Tensor & grad_w,
   const Tensor & saved_v,
   const Tensor & saved_g,
   const Tensor & saved_norms,
   int64_t dim)
{
  // In Functions.cpp, the HardshrinkBackward object supplies "grad.contiguous()"
  // as the first argument, so grad_w should be contiguous here.
  // All these checks should succeed:
  TORCH_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");

  int64_t last_dim = saved_v.dim() - 1;
  int64_t last_size = saved_v.size(last_dim);

  // Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
  // through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.size(-1)
  TORCH_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");

  // saved_g and saved_norms are already shaped to broadcast over the correct dimensions

  // ...but saved_norms might be Float when saved_g and saved_v are half.
  // To consider: saved_norms.to(..., True /*non_blocking*/);
  auto norms = saved_norms.to(saved_g.scalar_type());

  std::vector<int64_t> bcast_size(saved_v.dim(), 1);

  // Analytic backward path using differentiable primitive ops
  if (dim == 0) {
    bcast_size[0] = saved_v.size(0);
    auto per_dim_sums = (grad_w*saved_v).view({saved_v.size(0), -1}).sum(1).view(bcast_size);
    auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
    auto grad_g = per_dim_sums/norms;
    return std::tuple<Tensor, Tensor>{grad_v, grad_g};
  } else { // dim == last_dim
    bcast_size[last_dim] = last_size;
    auto per_dim_sums = (grad_w*saved_v).view({-1, last_size}).sum(0).view(bcast_size);
    auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
    auto grad_g = per_dim_sums/norms;
    return std::tuple<Tensor, Tensor>{grad_v, grad_g};
  }
}
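
Reproduction

Even without untangling the C++ any further, the reparameterization itself is easy to verify: with dim=0 the norm of v is taken over all remaining dimensions, one value per output channel, and the weight is rebuilt as \(w = g \dfrac{v}{\|v\|}\). Below is a sketch of my own for a Linear layer (the atol is there only because the fused CUDA kernel may differ from the manual computation by floating-point rounding):

from torch.nn.utils import weight_norm

print("WeightNorm:")
lin = weight_norm(nn.Linear(20, 40), name='weight', dim=0).cuda()
g, v = lin.weight_g, lin.weight_v
print(g.size(), v.size())  # torch.Size([40, 1]) torch.Size([40, 20])
w_manual = g * v / v.norm(p=2, dim=1, keepdim=True)  # w = g * v / ||v||, per output row
lin(torch.randn(8, 20).cuda())  # the forward pre-hook recomputes lin.weight
print(torch.allclose(lin.weight, w_manual, atol=1e-6))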

