
[Paper Notes] Switchable Normalization

Information

  • Title: Differentiable Learning-to-Normalize via Switchable Normalization
  • Author: Ping Luo and Jiamin Ren and Zhanglin Peng and Ruimao Zhang and Jingyu Li
  • Institution: Multimedia Laboratory, The Chinese University of Hong Kong (CUHK) and SenseTime Research
  • Year: 2019
  • Journal: ICLR
  • Source: arXiv, Code, the author's explanation in Chinese (on GitHub)
  • Idea: use BN, IN, and LN simultaneously and learn an adaptive weight over them, so that the normalizer adapts to different networks and tasks
@article{luo2018differentiable,
title={Differentiable Learning-to-Normalize via Switchable Normalization},
author={Ping Luo and Jiamin Ren and Zhanglin Peng and Ruimao Zhang and Jingyu Li},
journal={International Conference on Learning Representation (ICLR)},
year={2019}
}

Abstract

Switchable Normalization (SN) computes statistics along the channel, layer, and minibatch dimensions and learns their weights end to end. SN has three advantages:

  1. It adapts to a variety of network architectures and tasks
  2. It is robust over a wide range of batch sizes
  3. It introduces no sensitive hyperparameters

Introduction

SN adaptively learns the weights of BN, IN, and LN, so it performs well across many tasks.

[Figure: hist_task_v3 — performance of SN across many tasks]

The contributions of the paper are:

  1. SN is proposed and shown to be useful across many machine learning tasks
  2. It is interpretable: the learned weights reveal which normalizer each layer prefers
  3. It allows the various normalization methods to play their role inside deep networks

Method

BN, IN, and LN follow the same basic recipe and differ only in the dimensions of the target tensor over which the statistics are computed, so I will not explain them individually.
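For concreteness, here is a quick sketch (my own illustration, not from the paper) of which axes each normalizer reduces over for an NCHW feature map:

import torch

x = torch.randn(8, 16, 32, 32)      # (N, C, H, W)
mu_bn = x.mean(dim=(0, 2, 3))       # BN: per channel, over the batch and spatial dims
mu_in = x.mean(dim=(2, 3))          # IN: per sample and channel, over the spatial dims
mu_ln = x.mean(dim=(1, 2, 3))       # LN: per sample, over channels and spatial dims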

The SN transform can be written as
\[ \hat{h}_{ncij}=\gamma\frac{h_{ncij}-\Sigma_{k\in\Omega} w_k\mu_k} {\sqrt{\Sigma_{k\in\Omega}w_k^\prime\sigma_k^2+\epsilon}}+\beta \]
It has the same form as BN, IN, and LN; the only difference is that the statistics are a weighted combination over the three estimators, with \(\Omega=\{\mathrm{in},\mathrm{ln},\mathrm{bn}\}\). Written out separately, the three sets of statistics are
\[ \begin{eqnarray} \mu_{\mathrm{in}}&=&\frac{1}{HW}\sum_{i,j}^{H,W}h_{ncij}, ~~~~~~\sigma^2_{\mathrm{in}}=\frac{1}{HW}\sum_{i,j}^{H,W}(h_{ncij}-\mu_{\mathrm{in}})^2,\nonumber\\ \mu_{\mathrm{ln}}&=&\frac{1}{C}\sum_{c=1}^C\mu_{\mathrm{in}},~~~~~~\sigma^2_{\mathrm{ln}}= \frac{1}{C}\sum_{c=1}^C(\sigma^2_{\mathrm{in}}+\mu_{\mathrm{in}}^2)-\mu_{\mathrm{ln}}^2,\nonumber\\ \mu_{\mathrm{bn}}&=&\frac{1}{N}\sum_{n=1}^N\mu_{\mathrm{in}},~~~~~~\sigma^2_{\mathrm{bn}}= \frac{1}{N}\sum_{n=1}^N(\sigma^2_{\mathrm{in}}+\mu_{\mathrm{in}}^2)-\mu_{\mathrm{bn}}^2, \end{eqnarray} \]
These equations make the operating dimensions of the three normalizers explicit. They also show that the LN and BN statistics can be computed from the IN statistics, which reduces the amount of computation. In total there are six weights to learn (three \(w_k\) for the means and three \(w_k^\prime\) for the variances), and each group of weights must sum exactly to 1, which is enforced with a softmax:
\[ w_k=\frac{e^{\lambda_k}}{\Sigma_{z\in\{\mathrm{in},\mathrm{ln},\mathrm{bn}\}}e^{\lambda_z}}~ ~\mathrm{and}~~k\in\{\mathrm{in},\mathrm{ln},\mathrm{bn}\}. \]
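To make the reuse of the IN statistics concrete, here is a minimal PyTorch sketch (my own illustration, not the authors' implementation; the full module is in the Others section below). It computes the IN statistics once, derives the LN and BN statistics from them via \(\mathrm{E}[x^2]-\mathrm{E}[x]^2\), and mixes them with softmax weights:

import torch
import torch.nn.functional as F

x = torch.randn(8, 16, 32, 32)                       # (N, C, H, W)
lambda_mean = torch.zeros(3)                         # learnable logits for the (in, ln, bn) mean weights
lambda_var = torch.zeros(3)                          # learnable logits for the (in, ln, bn) variance weights

xr = x.view(x.size(0), x.size(1), -1)                # (N, C, H*W)
mu_in = xr.mean(-1, keepdim=True)                    # (N, C, 1)
var_in = xr.var(-1, unbiased=False, keepdim=True)    # (N, C, 1)

# LN and BN statistics derived from the IN statistics (E[x^2] - E[x]^2 trick)
second_moment = var_in + mu_in ** 2
mu_ln = mu_in.mean(1, keepdim=True)
var_ln = second_moment.mean(1, keepdim=True) - mu_ln ** 2
mu_bn = mu_in.mean(0, keepdim=True)
var_bn = second_moment.mean(0, keepdim=True) - mu_bn ** 2

# Softmax keeps each group of three weights positive and summing to 1
w_mean = F.softmax(lambda_mean, dim=0)
w_var = F.softmax(lambda_var, dim=0)
mean = w_mean[0] * mu_in + w_mean[1] * mu_ln + w_mean[2] * mu_bn
var = w_var[0] * var_in + w_var[1] * var_ln + w_var[2] * var_bn
x_hat = (xr - mean) / (var + 1e-5).sqrt()            # then scale by gamma and shift by beta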

Comparison with other methods

[Figure: comparison of SN with other normalization methods]

Detail

The author uses WN (Weight Normalization) to relate SN to BN, IN, and LN.

This confused me at first: why does WN come into the picture?

Based on the later text and my own understanding, the normalization of the features is projected onto the weights (i.e., WN), and then two different weight vectors \(\mathbf{w}_1, \mathbf{w}_2\) are compared.

Another paper by the author, Towards Understanding Regularization in Batch Normalization (ICLR 2019), points out that WN is a special form of BN.

BN, IN, and LN can be rewritten in WN form as:

\[ \hat{h}_{\mathrm{in}}=\gamma\frac{\mathbf{w}_i^T\mathbf{x}}{\|\mathbf{w}_i\|_2}+\beta \\ \hat{h}_{\mathrm{bn}}=\gamma\frac{\mathbf{w}_i^T\mathbf{x}}{\|\mathbf{w}_i\|_2}+\beta,~\mathrm{s.t.}~\gamma\leq v \\ \hat{h}_{\mathrm{ln}}=\gamma\frac{\mathbf{w}_i^T\mathbf{x}}{\|\mathbf{w}_i\|_2+\sum_{j\neq i}^C\|\mathbf{w}_j\|_2}+\beta \]

and SN is then expressed as:

\[ \hat{h}_{\mathrm{sn}}=w_{\mathrm{in}}\hat{h}_{\mathrm{in}}+w_{\mathrm{bn}}\hat{h}_{\mathrm{bn}}+w_{\mathrm{ln}}\hat{h}_{\mathrm{ln}}=\gamma\frac{\mathbf{w}_i^T\mathbf{x}}{\|\mathbf{w}_i\|_2+w_{\mathrm{ln}}\sum_{j\neq i}^C\|\mathbf{w}_j\|_2}+\beta,~\mathrm{s.t.}~w_{\mathrm{bn}}\gamma\leq v \]

The figure below gives the geometric comparison with WN. The author's argument: IN behaves much like WN — both normalize the weight norm to 1 and then rescale it to \(\gamma\); BN yields a smaller effective length, enlarges the angle between \(\mathbf{w}_1\) and \(\mathbf{w}_2\), and tends to synchronize them; LN imposes a weaker constraint and allows \(\gamma > v\); SN inherits all of these behaviors and strikes a balance among them.

[Figure: geometry — geometric comparison of IN, BN, LN, and SN under the WN view]

I find this part somewhat confusing; it is not entirely clear to me how these conclusions are derived.

Variants

  • Sparse SN: the softmax weights are made sparse so that a single normalizer is selected
  • Group SN: the channels are divided into groups and each group selects its own normalizer; whether this uses the sparse variant is not stated in the paper (a rough sketch of the idea follows below)
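The paper only sketches Group SN, so the following is a hypothetical illustration of my own, reusing the SwitchNorm2d module defined in the code block in the Others section:

import torch
import torch.nn as nn

class GroupSwitchNorm2d(nn.Module):
    """Hypothetical sketch of Group SN: split the channels into groups and give
    each group its own SwitchNorm2d, so every group learns its own switching
    weights. Assumes SwitchNorm2d (defined in the code below) is in scope."""

    def __init__(self, num_features, num_groups=4):
        super().__init__()
        assert num_features % num_groups == 0
        self.norms = nn.ModuleList(
            [SwitchNorm2d(num_features // num_groups) for _ in range(num_groups)]
        )

    def forward(self, x):
        chunks = torch.chunk(x, len(self.norms), dim=1)   # split along the channel axis
        return torch.cat([n(c) for n, c in zip(self.norms, chunks)], dim=1)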

One implementation detail worth noting is the moving average kept for the BN statistics.
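During training the BN mean and variance are tracked with an exponential moving average and reused at inference. A stripped-down sketch of that update, mirroring the training branch of the code further down (the tensors here are hypothetical placeholders):

import torch

# Note the convention: `momentum` is the decay kept on the running estimate
# (0.997 or 0.9 in the code below), which is the opposite of nn.BatchNorm2d,
# where `momentum` weights the *new* batch statistics (default 0.1).
momentum = 0.997
running_mean = torch.zeros(1, 16, 1)
running_var = torch.zeros(1, 16, 1)

mean_bn = torch.randn(1, 16, 1)     # BN statistics from the current minibatch
var_bn = torch.rand(1, 16, 1)

running_mean.mul_(momentum).add_((1 - momentum) * mean_bn)
running_var.mul_(momentum).add_((1 - momentum) * var_bn)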

Experiment

The paper reports a large number of experiments — the figure in the Introduction already gives a good overview. Most of them target small batch sizes spread across multiple GPUs, so I will not go through them here; see the figures.

[Figures: experiment result tables from the paper]

Conclusion

Different normalization layers in a network can learn different switching weights (and thus perform different operations), which makes SN broadly applicable. Follow-up work shows that SN can balance learning and generalization.

Overall, SN is a task- and data-driven normalization method: it strengthens or suppresses the influence of each normalization technique on training in order to reach the best performance. Through an in-depth analysis of the various normalization methods, the authors' team found that the weights learned by SN accurately reflect the theoretical properties of these techniques under different training conditions. As this line of research continues, these mathematical properties should be described more precisely.

Others

Code: the reference implementation below can be used out of the box.

import torch
import torch.nn as nn


class SwitchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True):
        super(SwitchNorm1d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))
        # Switching logits for the (LN, BN) statistics; softmaxed in forward()
        self.mean_weight = nn.Parameter(torch.ones(2))
        self.var_weight = nn.Parameter(torch.ones(2))
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.zeros(1, num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        mean_ln = x.mean(1, keepdim=True)
        var_ln = x.var(1, keepdim=True)

        if self.training:
            mean_bn = x.mean(0, keepdim=True)
            var_bn = x.var(0, keepdim=True)
            if self.using_moving_average:
                # Exponential moving average of the BN statistics
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.data)
            else:
                self.running_mean.add_(mean_bn.data)
                self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
        else:
            # At inference the running estimates replace the batch statistics
            # (torch.autograd.Variable is a legacy no-op wrapper in modern PyTorch)
            mean_bn = torch.autograd.Variable(self.running_mean)
            var_bn = torch.autograd.Variable(self.running_var)

        softmax = nn.Softmax(0)
        mean_weight = softmax(self.mean_weight)
        var_weight = softmax(self.var_weight)

        mean = mean_weight[0] * mean_ln + mean_weight[1] * mean_bn
        var = var_weight[0] * var_ln + var_weight[1] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias


class SwitchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.9, using_moving_average=True, using_bn=True,
                 last_gamma=False):
        super(SwitchNorm2d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.using_bn = using_bn
        self.last_gamma = last_gamma
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # Switching logits: (IN, LN, BN) if BN is used, otherwise (IN, LN)
        if self.using_bn:
            self.mean_weight = nn.Parameter(torch.ones(3))
            self.var_weight = nn.Parameter(torch.ones(3))
        else:
            self.mean_weight = nn.Parameter(torch.ones(2))
            self.var_weight = nn.Parameter(torch.ones(2))
        if self.using_bn:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
            self.register_buffer('running_var', torch.zeros(1, num_features, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.using_bn:
            self.running_mean.zero_()
            self.running_var.zero_()
        if self.last_gamma:
            self.weight.data.fill_(0)
        else:
            self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        N, C, H, W = x.size()
        x = x.view(N, C, -1)
        # IN statistics, computed once and reused for LN and BN
        mean_in = x.mean(-1, keepdim=True)
        var_in = x.var(-1, keepdim=True)

        mean_ln = mean_in.mean(1, keepdim=True)
        temp = var_in + mean_in ** 2
        var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2

        if self.using_bn:
            if self.training:
                mean_bn = mean_in.mean(0, keepdim=True)
                var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
                if self.using_moving_average:
                    self.running_mean.mul_(self.momentum)
                    self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                    self.running_var.mul_(self.momentum)
                    self.running_var.add_((1 - self.momentum) * var_bn.data)
                else:
                    self.running_mean.add_(mean_bn.data)
                    self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
            else:
                mean_bn = torch.autograd.Variable(self.running_mean)
                var_bn = torch.autograd.Variable(self.running_var)

        softmax = nn.Softmax(0)
        mean_weight = softmax(self.mean_weight)
        var_weight = softmax(self.var_weight)

        if self.using_bn:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
            var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
        else:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
            var = var_weight[0] * var_in + var_weight[1] * var_ln

        x = (x - mean) / (var + self.eps).sqrt()
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias


class SwitchNorm3d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True, using_bn=True,
                 last_gamma=False):
        super(SwitchNorm3d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.using_bn = using_bn
        self.last_gamma = last_gamma
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1, 1))
        if self.using_bn:
            self.mean_weight = nn.Parameter(torch.ones(3))
            self.var_weight = nn.Parameter(torch.ones(3))
        else:
            self.mean_weight = nn.Parameter(torch.ones(2))
            self.var_weight = nn.Parameter(torch.ones(2))
        if self.using_bn:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
            self.register_buffer('running_var', torch.zeros(1, num_features, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.using_bn:
            self.running_mean.zero_()
            self.running_var.zero_()
        if self.last_gamma:
            self.weight.data.fill_(0)
        else:
            self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        N, C, D, H, W = x.size()
        x = x.view(N, C, -1)
        mean_in = x.mean(-1, keepdim=True)
        var_in = x.var(-1, keepdim=True)

        mean_ln = mean_in.mean(1, keepdim=True)
        temp = var_in + mean_in ** 2
        var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2

        if self.using_bn:
            if self.training:
                mean_bn = mean_in.mean(0, keepdim=True)
                var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
                if self.using_moving_average:
                    self.running_mean.mul_(self.momentum)
                    self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                    self.running_var.mul_(self.momentum)
                    self.running_var.add_((1 - self.momentum) * var_bn.data)
                else:
                    self.running_mean.add_(mean_bn.data)
                    self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
            else:
                mean_bn = torch.autograd.Variable(self.running_mean)
                var_bn = torch.autograd.Variable(self.running_var)

        softmax = nn.Softmax(0)
        mean_weight = softmax(self.mean_weight)
        var_weight = softmax(self.var_weight)

        if self.using_bn:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
            var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
        else:
            mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
            var = var_weight[0] * var_in + var_weight[1] * var_ln

        x = (x - mean) / (var + self.eps).sqrt()
        x = x.view(N, C, D, H, W)
        return x * self.weight + self.bias
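As a quick sanity check, here is a usage sketch of my own (assuming the definitions above are in scope):

# Usage sketch: SwitchNorm2d is a drop-in replacement for nn.BatchNorm2d.
sn = SwitchNorm2d(num_features=16)
x = torch.randn(4, 16, 32, 32)

sn.train()
y = sn(x)                                  # training: BN statistics from the current batch
print(y.shape)                             # torch.Size([4, 16, 32, 32])

sn.eval()
y_eval = sn(x)                             # eval: BN statistics from the running buffers
print(torch.softmax(sn.mean_weight, 0))    # learned switching weights for the means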
