通俗易懂理解小波池化/WaveCNet

重要说明:本文从网上资料整理而来,仅记录博主学习相关知识点的过程,侵删。

一、参考资料

github代码:WaveCNet

通俗易懂理解小波变换(Wavelet Transform)

二、相关介绍

关于小波变换的详细介绍,请参考另一篇博客:通俗易懂理解小波变换(Wavelet Transform)

1. DWT和IDWT原理

小波变换是可逆的,小波变换可以通过小波分解和重构,恢复原始图像详细。

在这里插入图片描述

对于输入图像

I

I

I,进行两级小波变换,可以得到:

L

L

2

,

(

L

H

2

,

H

L

2

,

H

H

2

)

,

(

L

H

1

,

H

L

1

,

H

H

1

)

=

D

W

T

(

D

W

T

(

I

)

)

LL2, (LH2, HL2, HH2), (LH1, HL1, HH1) = DWT(DWT(I))

LL2,(LH2,HL2,HH2),(LH1,HL1,HH1)=DWT(DWT(I))
舍弃最高频的子带LH1, HL1和HH1,保留相对低频的LL2, (LH2, HL2, HH2)。最后对保留的二级小波系数进行逆变换,重构图像:

I

=

I

D

W

T

(

L

L

2

,

L

H

2

,

H

L

2

,

H

H

2

)

I^{prime}= IDWT(LL2, LH2, HL2, HH2)

I′=IDWT(LL2,LH2,HL2,HH2)

三、小波池化

Wavelet Pooling小波池化的思考
小波变换和曲波变换用于池化层

以文献[1-2]为例,详细介绍小波池化。

1. 引言

池化是舍弃信息来实现正则化的效果。传统的 Max PoolingAverage Pooling都有一些局限性。Max pooling 是一个有效的池化方法,但可能过于简单;Average Pooling会产生模糊。当主要的特征幅度值低于不重要的特征时,重要的特征在max pooling中就丢失了。而Average Pooling接收了幅值大的特征和幅值小的特征,会稀释幅值大的特征。具体如下图所示:

在这里插入图片描述

并且,Average Pooling或Max Pooling是不可逆的。一旦进行平均池化或者最大池化,新的特征空间无法保留原先特征空间的所有信息。而小波池化是可逆的,能恢复所有的原始特征。

2. DWT与IDWT网络层

对于DWT和IDWT网络层的关键在于数据的前向传播(forward propagations)和后向传播(backward propagations)。本章节以1D正交小波和1D数据为例,分析DWT和IDWT。同理,可以推广到其他小波和2D/3D数据,只有细微的变化。

2.1 前向传播(Forward propagation)

对于1D数据

x

=

{

s

j

}

j

Z

mathbf{x}={s_{j}}_{jinmathbb{Z}}

x={sj?}j∈Z?,通过DWT的低通滤波(low-pass filters)分解为低频成分

x

l

o

w

=

{

x

k

(

l

o

w

)

}

k

Z

mathbf{x}_mathrm{low}={mathbf x_{k}^{(mathrm{low})}}_{kinmathbb{Z}}

xlow?={xk(low)?}k∈Z?,通过DWT的高通滤波(high-pass filters)分解为高频成分

x

h

i

g

h

=

{

x

k

(

h

i

g

h

)

}

k

Z

mathbf{x}_mathrm{high}={mathbf x_{k}^{(mathrm{high})}}_{kinmathbb{Z}}

xhigh?={xk(high)?}k∈Z?。

{

x

k

(

l

o

w

)

=

j

l

j

?

2

k

x

j

,

x

k

(

h

i

g

h

)

=

j

h

j

?

2

k

x

j

,

(

1

)

left.left{egin{array}{c}mathbf x_{k}^{(mathrm{low})}=sum_jl_{j-2k}x_j,\mathbf x_{k}^{(mathrm{high})}=sum_jh_{j-2k}x_j,end{array}
ight.
ight. quad (1)

{xk(low)?=∑j?lj?2k?xj?,xk(high)?=∑j?hj?2k?xj?,?(1)
其中,

l

=

{

l

k

}

k

Z

oldsymbol{l}={l_{k}}_{kinmathbb{Z}}

l={lk?}k∈Z? ,

h

=

{

h

k

}

k

Z

oldsymbol{h}={h_{k}}_{kinmathbb{Z}}

h={hk?}k∈Z? 分别表示正交小波(orthogonal wavelet)的低通滤波(low-pass filters)和高通滤波(high-pass filters)。由

公式

(

1

)

公式(1)

公式(1) 可知,DWT包含两个过程:过滤下采样

IDWT使用

s

l

o

w

,

s

h

i

g

h

mathbf{s}_mathrm{low},mathbf{s}_mathrm{high}

slow?,shigh? 重构

s

mathbf{s}

s,公式表达如下:

x

j

=

k

(

l

j

?

2

k

x

k

(

l

o

w

)

+

h

j

?

2

k

x

k

(

h

i

g

h

)

)

.

(

2

)

x_j=sum_kleft(l_{j-2k}mathbf x_{k}^{(mathrm{low})}+h_{j-2k}mathbf x_{k}^{(mathrm{high})}
ight). quad (2)

xj?=k∑?(lj?2k?xk(low)?+hj?2k?xk(high)?).(2)
用矩阵和向量表示,

公式

(

1

)

公式(1)

公式(1) 和

公式

(

2

)

公式(2)

公式(2) 可以重写为:

x

l

o

w

=

L

x

,

x

h

i

g

h

=

H

x

,

(

3

)

x

=

L

T

x

l

o

w

+

H

T

x

h

i

g

h

,

(

4

)

egin{aligned}mathbf{x}_mathrm{low}&=mathbf{L}mathbf{x},quadmathbf{x}_mathrm{high}=mathbf{H}mathbf{x},quad&(3)\mathbf{x}&=mathbf{L}^Tmathbf{x}_mathrm{low}+mathbf{H}^Tmathbf{x}_mathrm{high},quad&(4)end{aligned}

xlow?x?=Lx,xhigh?=Hx,=LTxlow?+HTxhigh?,?(3)(4)?
其中

L

=

(

?

?

?

?

l

?

1

l

0

l

1

?

?

l

?

1

l

0

l

1

?

?

?

)

,

(

5

)

left.mathbf{L}=left(egin{array}{ccccccc}cdots&cdots&cdots&&&&\cdots&l_{-1}&l_0&l_1&cdots&&\&cdots&l_{-1}&l_0&l_1&cdots\&&&&cdots&cdotsend{array}
ight.
ight),quad(5)

L=
?????l?1????l0?l?1??l1?l0???l1??????
?,(5)

H

=

(

?

?

?

?

h

?

1

h

0

h

1

?

?

h

?

1

h

0

h

1

?

?

?

)

.

(

6

)

left.mathbf{H}=left(egin{array}{ccccccc}cdots&cdots&cdots&&&\cdots&h_{-1}&h_0&h_1&cdots&\&&cdots&h_{-1}&h_0&h_1&cdots\&&&&cdots&cdotsend{array}
ight.
ight).(6)

H=
?????h?1???h0???h1?h?1???h0???h1?????
?.(6)

对于2D数据

X

mathbf{X}

X,2D DWT通常对其每一行(row) 和每一列(column)进行1D DWT操作,也就是:

X

l

l

=

L

X

L

T

,

(

7

)

X

l

h

=

H

X

L

T

,

(

8

)

X

h

l

=

L

X

H

T

,

(

9

)

X

h

h

=

H

X

H

T

,

(

10

)

egin{gathered} mathbf{X}_{ll} =mathbf{L}mathbf{X}mathbf{L}^{T}, left(7
ight) \ mathbf{X}_{lh} =mathbf{HXL}^{T}, left(8
ight) \ mathbf{X}_{hl} =mathbf{LXH}^{T}, left(9
ight) \ mathbf{X}_{hh} =mathbf{HXH}^{T}, left(10
ight) end{gathered}

Xll?=LXLT,(7)Xlh?=HXLT,(8)Xhl?=LXHT,(9)Xhh?=HXHT,(10)?
对于2D IDWT操作,公式表达如下:

X

=

L

T

X

l

l

L

+

H

T

X

l

h

L

+

L

T

X

h

l

H

+

H

T

X

h

h

H

.

(

11

)

mathbf{X}=mathbf{L}^Tmathbf{X}_{ll}mathbf{L}+mathbf{H}^Tmathbf{X}_{lh}mathbf{L}+mathbf{L}^Tmathbf{X}_{hl}mathbf{H}+mathbf{H}^Tmathbf{X}_{hh}mathbf{H}.quad(11)

X=LTXll?L+HTXlh?L+LTXhl?H+HTXhh?H.(11)

在 2D DWT的输出中,

X

l

l

mathbf{X}_{ll}

Xll? 是输入

X

mathbf{X}

X的低频成分,代表主要的信息,包括目标的基本结构;对应的

X

l

h

,

X

h

l

,

X

h

h

mathbf{X}_{lh}, mathbf{X}_{hl}, mathbf{X}_{hh}

Xlh?,Xhl?,Xhh? 是三个高频成分,其保存了输入

X

mathbf{X}

X水平(horizontal)、垂直(vertical)、对角线(diagonal)的细节信息。

2.2 反向传播(Backward propagation)

公式

(

3

)

?

(

4

)

(3)-(4)

(3)?(4) 表示1D DWT和IDWT的前向传播。1D DWT的反向传播与梯度

?

x

l

o

w

?

x

frac{partial mathbf{x}_mathrm{low}}{partial mathbf{x}}

?x?xlow??和

?

x

h

i

g

h

?

x

frac{partial mathbf{x}_mathrm{high}}{partial mathbf{x}}

?x?xhigh??密切相关,可以从公式(3) 中推导出:

?

x

l

o

w

?

x

=

L

T

,

?

x

h

i

g

h

?

x

=

H

T

.

(

12

)

frac{partial mathbf{x}_mathrm{low}}{partial mathbf{x}} = mathbf{L}^T, frac{partial mathbf{x}_mathrm{high}}{partial mathbf{x}} = mathbf{H}^T.quad(12)

?x?xlow??=LT,?x?xhigh??=HT.(12)

类似的,1D IDWT的反向传播与梯度

?

x

l

o

w

?

x

frac{partial mathbf{x}_mathrm{low}}{partial mathbf{x}}

?x?xlow??和

?

x

h

i

g

h

?

x

frac{partial mathbf{x}_mathrm{high}}{partial mathbf{x}}

?x?xhigh??密切相关,可以从公式(4) 中推导出:

?

x

?

x

l

o

w

=

L

,

?

x

?

x

h

i

g

h

=

H

.

(

13

)

frac{partial mathbf{x}}{partial mathbf{x}_mathrm{low}} = mathbf{L}, frac{partial mathbf{x}}{partial mathbf{x}_mathrm{high}} = mathbf{H}.quad(13)

?xlow??x?=L,?xhigh??x?=H.(13)

对于2D DWT的反向传播可以通过梯度

?

X

l

l

?

X

(

G

)

frac{partial X_{ll}}{partial X}(G)

?X?Xll??(G),

?

X

l

h

?

X

(

G

)

frac{partial X_{lh}}{partial X}(G)

?X?Xlh??(G),

?

X

h

l

?

X

(

G

)

frac{partial X_{hl}}{partial X}(G)

?X?Xhl??(G),

?

X

h

h

?

X

(

G

)

frac{partial X_{hh}}{partial X}(G)

?X?Xhh??(G)实现,公式表达如下:

?

X

l

l

?

X

(

G

)

=

L

T

G

L

,

(

14

)

?

X

l

h

?

X

(

G

)

=

H

T

G

L

,

(

15

)

?

X

h

l

?

X

(

G

)

=

L

T

G

H

,

(

16

)

?

X

h

h

?

X

(

G

)

=

H

T

G

H

,

(

17

)

egin{gathered} frac{partial X_{ll}}{partial X}(G) ={L}^{T}G{L}, quad (14) \ frac{partialoldsymbol{X}_{lh}}{partialoldsymbol{X}}(G) ={H}^{T}G{L}, quad (15) \ frac{partialoldsymbol{X}_{hl}}{partialoldsymbol{X}}(G) ={L}^{T}G{H}, quad (16) \ frac{partial X_{hh}}{partial X}(G) ={H}^{T}G{H}, quad (17) end{gathered}

?X?Xll??(G)=LTGL,(14)?X?Xlh??(G)=HTGL,(15)?X?Xhl??(G)=LTGH,(16)?X?Xhh??(G)=HTGH,(17)?
其中,

G

G

G 是2D DWT之后的层的反向传播输出。

类似的,对于2D IDWT的反向传播可以通过

?

X

?

X

l

l

(

G

)

frac{partial X}{partial X_{ll}}(G)

?Xll??X?(G),

?

X

?

X

l

h

(

G

)

frac{partial X}{partial X_{lh}}(G)

?Xlh??X?(G),

?

X

?

X

h

l

(

G

)

frac{partial X}{partial X_{hl}}(G)

?Xhl??X?(G),

?

X

?

X

h

h

(

G

)

frac{partial X}{partial X_{hh}}(G)

?Xhh??X?(G)实现,公式表达如下:

?

X

?

X

l

l

(

G

)

=

L

G

L

T

,

(

18

)

?

X

?

X

l

h

(

G

)

=

H

G

L

T

,

(

19

)

?

X

?

X

h

l

(

G

)

=

L

G

H

T

,

(

20

)

?

X

?

X

h

h

(

G

)

=

H

G

H

T

,

(

21

)

egin{gathered} frac{partialoldsymbol{X}}{partialoldsymbol{X}_{ll}}(G)={L}G{L}^{T}, quad (18) \ frac{partialoldsymbol{X}}{partialoldsymbol{X}_{lh}}(G)={H}G{L}^{T}, quad (19) \ frac{partialoldsymbol{X}}{partialoldsymbol{X}_{hl}}(G)={L}G{H}^{T}, quad (20) \ frac{partialoldsymbol{X}}{partialoldsymbol{X}_{hh}}(G)={H}G{H}^{T}, quad (21) end{gathered}

?Xll??X?(G)=LGLT,(18)?Xlh??X?(G)=HGLT,(19)?Xhl??X?(G)=LGHT,(20)?Xhh??X?(G)=HGHT,(21)?
其中,

G

G

G 是2D IDWT之后的层的反向传播输出。

3D DWT和IDWT的反向传播过程稍微复杂一点,但与1D/2D DWT和IDWT类似。本文使用有限滤波器,例如Haar小波,它的低通滤波和高通滤波可以表示为:

l

=

1

2

{

1

,

1

}

mathbf{l}=frac{1}{sqrt{2}}{1,1}

l=2
?1?{1,1},

h

=

1

2

{

1

,

?

1

}

mathbf{h}=frac{1}{sqrt{2}}{1,-1}

h=2
?1?{1,?1}。

在网络层中,对于多通道数据进行逐通道的DWT和IDWT操作。

3. WaveCNets网络模型

3.1 基于小波的通用去噪方法

给定一个2D的噪声数据

X

mathbf{X}

X,随机噪声主要表现在其高频成分中。如下图所示,基于小波的通用去噪方法[3],包括三个步骤:

  1. 使用DWT将带噪声的数据分解为低频成分

    X

    l

    l

    mathbf{X}_{ll}

    Xll? 和高频成分

    X

    l

    h

    ,

    X

    h

    l

    ,

    X

    h

    h

    mathbf{X}_{lh}, mathbf{X}_{hl}, mathbf{X}_{hh}

    Xlh?,Xhl?,Xhh?;

  2. 使用滤波器对高频成分进行过滤;
  3. 使用IDWT对处理后的高频和低频成分进行重构。

在这里插入图片描述

3.2 最简单的基于去噪方法的小波

本文选择最简单的基于去噪方法的小波,也就是丢弃高频成分,如下图所示:

在这里插入图片描述

其中,

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 表示将特征图映射到低频成分的转换。

3.3 基于小波的下采样方法

本文通过用

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 替换传统的下采样,设计出WaveCNets网络模型。如下图所示,(a) 表示传统的下采样方法,(b) 表示基于小波的下采样方法。

在这里插入图片描述

在WaveCNets网络中,将max-poolingaverage-pooling 直接替换为

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 。同时,将 strided-convolution卷积替换为步长为1的卷积,也就是:

MaxPool

s

=

2

DWT

l

l

,

(

14

)

Conv

s

=

2

DWT

l

l

°

Conv

s

=

1

,

(

15

)

AvgPool

s

=

2

DWT

l

l

,

(

16

)

egin{aligned} ext{MaxPool}_{s=2}& o ext{DWT}_{ll},quad&(14)\ ext{Conv}_{s=2}& o ext{DWT}_{ll}circ ext{Conv}_{s=1},quad&(15)\ ext{AvgPool}_{s=2}& o ext{DWT}_{ll},quad&(16)end{aligned}

MaxPools=2?Convs=2?AvgPools=2??→DWTll?,→DWTll?°Convs=1?,→DWTll?,?(14)(15)(16)?
其中

M

a

x

p

o

o

l

s

mathrm {Maxpool_s}

Maxpools?、

C

o

n

v

s

mathrm {Conv_s}

Convs?、

A

v

g

P

o

o

l

s

mathrm {AvgPool_s}

AvgPools? 分别表示 max-poolingstrided-convolutionaverage-poolings表示步长(stride)。

3.4 WaveCNets模型的优势

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 对特征图进行去噪,移除高频成分,特征图尺寸减半。

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 输出的低频成分,保存了特征图的主要信息,并提取出可识别的特征。在WaveCNets下采样过程中,

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 可以抵抗噪声的传播,有利于维持特征图中目标的基本结构。因此,

D

W

T

l

l

mathrm{DWT}_{ll}

DWTll? 可以加快深度网络的训练,有利于更好的噪声鲁棒性提高分类模型的精度

4. (TensorFlow)代码实现

Tensorflow实现小波池化层

四、参考文献

[1] Li Q, Shen L, Guo S, et al. Wavelet integrated CNNs for noise-robust image classification[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 7245-7254.

[2] Li Q, Shen L, Guo S, et al. Wavecnet: Wavelet integrated cnns to suppress aliasing effect for noise-robust image classification[J]. IEEE Transactions on Image Processing, 2021, 30: 7074-7089.

[3] Donoho D L. De-noising by soft-thresholding[J]. IEEE transactions on information theory, 1995, 41(3): 613-627.