0%

『论文笔记』FDA - Fourier Domain Adaptation for Semantic Segmentation

Information

  • Title: FDA: Fourier Domain Adaptation for Semantic Segmentation
  • Author: Yanchao Yang
  • Institution: UCLA(加州大学洛杉矶分校)
  • Year: 2020
  • Journal:CVPR
  • Source: IEEE, arxiv, Github, Open access
  • Cite: Yanchao Yang, Stefano Soatto; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020, pp. 4085-4095
  • Idea: 用目标域的低频振幅替换源域做训练来近似域对齐
1
2
3
4
5
6
7
@InProceedings{Yang_2020_CVPR,
author = {Yang, Yanchao and Soatto, Stefano},
title = {FDA: Fourier Domain Adaptation for Semantic Segmentation},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020}
}

Abstract

提出了一种无监督域适应的方法,通过交换源域的低频成分来减少源域和目标域的分布差异。

Introduction

如图所示

illustrate_fftda

作者对源域和目标域图像做 FFT 并用目标域的低频部分替换源域的低频部分再使用 iFFT 获得重构图像用于训练。

有一个需要选择的超参数,即交换频谱区域的大小(图中绿色框),作者测试了各种大小以及一种多尺度方法。

作者提出这种方法的动机在于可以观察到振幅谱的低频部分可以显著变化而不会影响高级的语义信息,通过学习这种变化可以使得模型更加泛化。

Method

首先进行傅里叶变换,用 \(\mathcal{F}^{A}, \mathcal{F}^{P}: \mathbb{R}^{H \times W \times 3} \rightarrow \mathbb{R}^{H \times W \times 3}\) 表示傅里叶变换 \(\mathcal{F}\) 得到的振幅和相位,对于单通道图像 \(x\) 有: \[ \mathcal{F}(x)(m, n) = \sum_{h,w} x(h,w) e^{-j2\pi\left(\dfrac{h}{H}m + \dfrac{w}{W}n\right)}, j^2=-1 \]\(\mathcal{F}^{-1}\) 表示逆傅里叶变换用于将振幅和相位映射回空域,然后设置掩模矩阵 \(M_{\beta}\) 其除了中心区域外的位置都是 0,其中\(\beta \in (0, 1)\),假设图像的中心为 \((0,0)\)\[ M_{\beta}(h,w) = \mathbb{1}_{(h,w) \in [-\beta H: \beta H, -\beta W : \beta W]} \] \(\beta\) 不是基于像素进行选择的,所以不取决于图像大小和分辨率。

给定两个随机从源域和目标域采样的图像 \(x^s \sim D^s, x^t \sim D^t\),作者提出的 FDA 可以表示为: \[ x^{s\to t}=\mathcal{F}^{-1}([M_{\beta} \circ \mathcal{F}^A(x^t) + (1-M_{\beta}) \circ \mathcal{F}^A(x^s), \mathcal{F}^P(x^s)]) \] 其中将原图像低频部分的相位 \(\mathcal{F}^A(x^s)\) 替换为了目标图像。

differentLB

上图展示了 \(\beta\) 选择的不同带来的影响,作者设置 \(\beta \leq 0.15\).

应用到语义分割,目标损失函数为: \[ \mathcal{L}_{ce}(\phi^w; D^{s\to t}) = -\sum_i \langle y^s_i, \log( \phi^w(x^{s\to t}_i) )\rangle. \] 添加了正则化使其更鲁棒(没看懂) \[ \mathcal{L}_{ent}(\phi^w; D^{t}) = \sum_i \rho( -\langle \phi^w(x^{t}_i), \log( \phi^w(x^{t}_i) )\rangle ) \] 其中 \(\rho(x) = (x^2 + 0.001^2)^{\eta}\) 是Charbonnier惩罚函数,如下图所示,其对于高熵预测的惩罚大于低熵预测的惩罚。

cbnorm

最后使用整体的损失训练语义分割网络 \(\phi^w\)\[ \mathcal{L}(\phi^w; D^{s\to t}, D^t) = \mathcal{L}_{ce}(\phi^w; D^{s\to t}) + \lambda_{ent} \mathcal{L}_{ent}(\phi^w; D^{t}) \] 还有一些关于自监督训练的说明,但不是很懂,有兴趣的同学可以去看看原文。

Experiment

实验设置:

  • 数据集:合成数据集GTA5 和 SYNTHIA,真实域数据集 CityScapes
  • 分割网络:ResNet101 和 VGG16
  • 训练设置:BatchSize 为 1,SGD优化器,学习率 2.5e-4,poly 学习率调整策略(0.9 比例),0.0005 权重衰减

GTA5 \(\rightarrow\) CityScapes 任务的消融实验:

image-20230413160718152

GTA5 \(\rightarrow\) CityScapes 任务的量化比较:

image-20230413160902771

SYNTHIA \(\rightarrow\) CityScapes 任务的量化比较:

image-20230413160951708

如果从头开始训练,较大的 \(\beta\) 泛化效果更好,但与自监督训练相结合时会产生更多偏差

explain

可视化实验

visuals

Conclusion

作者提出了一种不需要任何学习的简单域对齐方法,并且可以轻松集成到将无监督域适应转换为半监督学习的学习系统中。需要注意损失函数的适当正则化,为此作者提出了具有各向异性(Charbonnier)加权的熵正则化器。多频带传输方案解决了自我监督训练中的自我参照问题,该方案不需要对具有复杂模型选择的学生网络进行联合训练。

这表明,由于低级统计数据导致的一些分布失调(众所周知,这种失调会对跨不同领域的泛化造成严重破坏)可以通过快速傅里叶变换非常简单地捕捉到。此外,实数信号频谱的傅里叶逆变换保证是实数,因为可以很容易地证明虚部在给定被积函数的斜对称性的情况下被抵消了。

对影响图像域的有害可变性的鲁棒性仍然是机器学习中的一个难题,作者并不声称我们的方法是最终的解决方案。然而,在某些情况下,可能没有必要学习已经知道的东西,例如图像的低级统计数据可以有很大差异而不影响底层场景的语义。这种预处理可以替代复杂的架构或费力的数据扩充。

Others

这个想法在我的研究中也被想到了,但并没有很好的实验结果,这篇论文值得研究和学习为什么同样的 idea 我没有做出成果呢?

源码分析

PS:重要的实现在 __init__.py 文件中还挺隐蔽

作者写了 numpy 和 pytorch 的两种实现,实际上用 pytorch 的好一些,因为可以使用 GPU 加速计算。

最开始是提取相位和振幅的函数(有点奇怪)

1
2
3
4
5
6
def extract_ampl_phase(fft_im):
# fft_im: size should be bx3xhxwx2
fft_amp = fft_im[:,:,:,:,0]**2 + fft_im[:,:,:,:,1]**2
fft_amp = torch.sqrt(fft_amp)
fft_pha = torch.atan2( fft_im[:,:,:,:,1], fft_im[:,:,:,:,0] )
return fft_amp, fft_pha

接下来是个四个角上替换的函数

1
2
3
4
5
6
7
8
def low_freq_mutate( amp_src, amp_trg, L=0.1 ):
_, _, h, w = amp_src.size()
b = ( np.floor(np.amin((h,w))*L) ).astype(int) # get b
amp_src[:,:,0:b,0:b] = amp_trg[:,:,0:b,0:b] # top left
amp_src[:,:,0:b,w-b:w] = amp_trg[:,:,0:b,w-b:w] # top right
amp_src[:,:,h-b:h,0:b] = amp_trg[:,:,h-b:h,0:b] # bottom left
amp_src[:,:,h-b:h,w-b:w] = amp_trg[:,:,h-b:h,w-b:w] # bottom right
return amp_src

接下来这个函数应该是等价的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def low_freq_mutate_np( amp_src, amp_trg, L=0.1 ):
a_src = np.fft.fftshift( amp_src, axes=(-2, -1) )
a_trg = np.fft.fftshift( amp_trg, axes=(-2, -1) )

_, h, w = a_src.shape
b = ( np.floor(np.amin((h,w))*L) ).astype(int)
c_h = np.floor(h/2.0).astype(int)
c_w = np.floor(w/2.0).astype(int)

h1 = c_h-b
h2 = c_h+b+1
w1 = c_w-b
w2 = c_w+b+1

a_src[:,h1:h2,w1:w2] = a_trg[:,h1:h2,w1:w2]
a_src = np.fft.ifftshift( a_src, axes=(-2, -1) )
return a_src

然后是两种实现方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def FDA_source_to_target(src_img, trg_img, L=0.1):
# exchange magnitude
# input: src_img, trg_img

# get fft of both source and target
fft_src = torch.rfft( src_img.clone(), signal_ndim=2, onesided=False )
fft_trg = torch.rfft( trg_img.clone(), signal_ndim=2, onesided=False )

# extract amplitude and phase of both ffts
amp_src, pha_src = extract_ampl_phase( fft_src.clone())
amp_trg, pha_trg = extract_ampl_phase( fft_trg.clone())

# replace the low frequency amplitude part of source with that from target
amp_src_ = low_freq_mutate( amp_src.clone(), amp_trg.clone(), L=L )

# recompose fft of source
fft_src_ = torch.zeros( fft_src.size(), dtype=torch.float )
fft_src_[:,:,:,:,0] = torch.cos(pha_src.clone()) * amp_src_.clone()
fft_src_[:,:,:,:,1] = torch.sin(pha_src.clone()) * amp_src_.clone()

# get the recomposed image: source content, target style
_, _, imgH, imgW = src_img.size()
src_in_trg = torch.irfft( fft_src_, signal_ndim=2, onesided=False, signal_sizes=[imgH,imgW] )

return src_in_trg

def FDA_source_to_target_np( src_img, trg_img, L=0.1 ):
# exchange magnitude
# input: src_img, trg_img

src_img_np = src_img #.cpu().numpy()
trg_img_np = trg_img #.cpu().numpy()

# get fft of both source and target
fft_src_np = np.fft.fft2( src_img_np, axes=(-2, -1) )
fft_trg_np = np.fft.fft2( trg_img_np, axes=(-2, -1) )

# extract amplitude and phase of both ffts
amp_src, pha_src = np.abs(fft_src_np), np.angle(fft_src_np)
amp_trg, pha_trg = np.abs(fft_trg_np), np.angle(fft_trg_np)

# mutate the amplitude part of source with target
amp_src_ = low_freq_mutate_np( amp_src, amp_trg, L=L )

# mutated fft of source
fft_src_ = amp_src_ * np.exp( 1j * pha_src )

# get the mutated image
src_in_trg = np.fft.ifft2( fft_src_, axes=(-2, -1) )
src_in_trg = np.real(src_in_trg)

return src_in_trg

实际上上面的代码会报错,应该是作者使用的是旧版 pytorch 的缘故,新版本的 pytorch 将 fft 集成到一个包 torch.fft 里面了而且并不会输出虚数实数分离的情况,所以若要运行需要安装合适的版本或者稍微进行修改


如有错漏,欢迎指正!如果对你有帮助的话,请给我点个赞吧~

欢迎前往 我的博客 查看更多笔记

--- ♥ end ♥ ---

欢迎关注我呀~