SURREAL-GAN
约 2432 个字 预计阅读时间 8 分钟
方法简介
背景
- 神经疾病具有异质性(多样性),即使被诊断为同一种神经系统疾病,不同患者之间的病理生理学机制也存在显著的差异
- 引入深度学习后,很多研究将人体健康状态二分,根据医学影像的特征给出离散的输出,但实际上健康变化是个连续的过程
- 医学影像的特征应该同时和时间和空间有关,为了解决上述问题这篇论文提出了 Surreal-GAN (Semi- Supervised Representation Learning via GAN)
关键点
represent complex disease-related heterogeneity with low dimensional representations with each dimension indicating the severity of one relatively homogeneous imaging patterns
具体的实现见后续章节,这里只列举要实现些什么?
- 建模疾病为一个连续的过程,学习CN(reference group)到PT(terget patient group)的无数种转换,其中每种都代表着特定模式和严重程度的组合,即GAN中的生成器
- 控制单调性,即隐变量越大,意味着生成的对应值越大
- 提高捕获的不同疾病模式的隔离性,即尽量让生成值的位置不同
- 确保通过转换合成的模式被准确捕获,用于推断疾病模式的表示
变量定义
这个模型被用于区域总量图,核心思想是学习一个变换函数$f: X * Z\rightarrow Y \(,或者写成\)y'=f(x,z) $,其中:
- $x $是CN中的数据
- $y' $是生成的假PT数据
- $z \(是一个用于声明不同模式和严重程度组合的隐变量,隐变量空间满足\)Z={z:0\leq z_i\leq 1|\forall 1\leq i \leq M}\(,是一族\)M $维向量
此处定义四个不同的数据分布:
- $x\sim p_{cn}(x) $
- $y\sim p_{pt}(y) $
- $y'\sim p_{syn}(y') $
- $z\sim p_{lat}(z) \(:采样自多变量的均值分布\)U[0, 1]^M $
此外,还要引入一个判别器$D \(用于判断\)y \(和\)y' \(的差距,生成函数\)f \(的目标自然是要让判别器无法分辨真假。但是光是连续的隐变量并不能满足得到的\)f $满足关键点中提到的那些性质,会引发如下问题:
- 有很多函数都可以满足变换,这样无法保证得到的$f $和病理过程紧密联系
- 尽管隐变量通过正则化被强调了,但是不同的隐变量并不能保证生成不同的疾病模式以及模式强度和隐变量正相关
不妨假设存在一个真正的变换函数,记作$y=h(x, z) \(,我们的核心目标就变成了让\)f \(逼近\)h \(。为了限制\)z \(是真的在为识别图片模式做贡献,而不是在添加随机误差,对\)f $满足以下要求:
- 促使在转换后结果是稀疏的
- 强制函数 Lipschitz 连续
- 引入一个反函数$g $用于分解和重建
- 促进生成的模式正交
- 强制生成的模式单调,与隐变量正相关
损失函数
这里的损失函数由多部分加和,每个部分都对应一种限制。
对抗损失
普通GAN的损失函数
正则化损失
稀疏转换
假定疾病并不会剧烈改变大脑的解剖结构,并且在大部分时间都只影响局部,定义如下改变损失:
LIPSCHITIZ 连续
首先,$f \(是K1-LIPSCHITIZ 连续,满足对于固定的\)z=a \(,\)\forall x_1,x_2\in X,||f(x_1,a)-f(x_2,a)||_2\leq K_1||X_1-X_2||_2 \(,这样通过控制\)K_1 $可以限制变换后的距离不至于变化太多。
然后为了避免$z \(被忽略,引入一个反映射函数\)g:Y\rightarrow Z \(,是K2-LIPSCHITIZ 连续,这里用\)d(.,.) \(表示任意维满足三角不等式的距离,注意此处\)z \(是自变量,\)x \(是函数,所以可以推出:\)\forall z_1,z_2\sim p_{lat}(z),z_1\neq z_2,\overline{x}\sim p_{cn}(x) \(,\)d(f(\overline{x}, z_1), f(\overline{x}, z_2)) \(的下界是\)\frac{d(z_1,z_2)}{K_2}-\frac{1}{K_2}(d(g(f(\overline{x},z_1)),z_1)+d(g(f(\overline{x},z_2)),z_2)) \((这个要用三角不等式的扩展形式证明),因此可以通过最小化\)z \(和\)g(f(x,z)) \(来让\)f \(的结果有明显区别,哪怕\)x $是一样的
综上所述,这里的损失定义为:
$L_{recons}(f, g)=E_{x\sim p_{cn}, z\sim p_{lat}(z)}[||g(f(x,z)-z||_2] $
模式分解
一个简单的$g \(并不能并不能完全基于正向过程中\)z_i \(带来的改变重建出\)z_i \(,我们更关心的是对于真实PT的模式表示,然后反向函数\)g \(可以用于预测R-incidices,因此反向函数必须要准确地捕获模式,因此把\)g \(分解为\)g_1:Y\rightarrow R^{M*S} \((\)M \(是模式数,\)S \(是输入的维度)和\)g_2: R^S\rightarrow R \(,\)g_1 \(用于重建生成自\)z_i \(的改变。此处,定义\)q_i=f(x,a^i)-x \(,其中\)a_i \(是一个向量,满足\)a_i^i=z_i \(和\)a^i_j=0 \forall i\neq j\(,然后\)\hat{q}_{f(x,z)}=[q_1^T,\dots,q_M^T] $是所有生成的改变的集合,分解的损失就可以定义如下:
$L_{decom}(f,g_1)=E_{x\sim p_{cn}(x), z\sim p_{lat}(z)}[||g_1(f(x,z)-\hat{q}_{f(x,z)})||_2] $
$g_2 $用于进一步重建隐变量的每个成分,因此重建损失可以被写作:
$L_{recons}(f,g)=E_{x\sim p_{cn}(x),z\sim p_{lat}(z)}[||\hat{I}{g_2(g_1(f(x,z)))}-z||_2] \(,其中\)\hat{I}=g(f(x,z)) $
正交模式
隐变量并不会保证每一位对不同模式起作用,相反它们更倾向于在同一区域生成模式,从而导致累加严重程度,因此添加正交损失来促进不同部分生成隔离的模式,每个组成导致的改变$q_i \(在上一小节定义了,此处构建一个对角矩阵\)A_{f(x,z)} \(,其每一列为\)\frac{|q_i|}{||q_i||_2} $,这个矩阵应该尽量接近正交矩阵,故损失函数定义如下:
$L_{ortho}=E_{x\sim p_{cn}(x), z\sim p_{lat}(z)}[||A_{(x,z)}^TA_{(x,z)}-I||_F] $
单调性和正相关
假定一个模式变严重了,对应区域改变的程度要么不变,要么增加,即不严格单增,为此再采一个样\(z'\sim p_{sev}(z'|z),z_i'\geq z_i,\forall 1\leq i\leq M\),基于双采样,定义单调损失如下:
$L_{mono}=E_{x\sim p_{cn}(x),z\sim p_{lat}(z),z'\sim p_{sev}(z'|z)}[||\max(|f(x,z)-x)|-|f(x,z')-x|,0||_2] $
但是这个只惩罚了不单调,没有限制小的隐变量就只引起小改变,为了进一步限制正相关,引入cn 损失,当隐变量$z \(是接近于0的向量时,令\)p_{cn}(z)=U(0, 0.05)^M $,cn 损失定义如下:
$L_{cn}(f)=E_{x\sim p_{cn}(x), z^{cn}\sim p_{cn}(z)}[||f(x,z^{cn})-x||_1] $
总结
将以上7个损失加和得到最终的损失:$L(D,f,g_1,g_2)=L_{GAN}(D,f)+\gamma L_{change}(f)+\kappa L_{decom}(f,g_1)+\zeta L_{recon}(f, g_1, g_2)+\lambda L_{ortho}(f)+\mu L_{cn}(f) \(,希腊字母都是可调的权重参数,在训练过程中希望参数化函数\)f \(,\)g_1 \(,\)g_2 \(满足\)f,g_1,g_2=\arg\min_{fg}\max_D L(D,f,g_1,g_2) $
相关知识
GAN
GAN的本质是两个神经网络的“对抗训练”:
- 生成器(Generator):输入随机噪声,输出/生成模仿真实数据的“假数据”,目标是让假数据“以假乱真”。
- 判别器(Discriminator):输入数据(真实数据/生成器的假数据),输出“数据为真的概率”(0~1),目标是精准区分真假。它是一个二分类器
两者不断对抗、互相优化,最终达到平衡:判别器分辨不出真假,生成器也能稳定输出高质量假数据,即最大最小博弈(Minimax Game),GAN的目标函数可以表示为:$\min_G\max_DV(D,G)=E_{x\sim p_{data}x}[\log D(x)] + E_{z\sim p_z(z)}[\log (1-D(G(z)))] $
符号含义:
- $D(x) \(:判别器判断数据\)x $为真的概率
- $G(z) \(:生成器把随机噪声\)z $转化成假数据的概率
- $p_{data}x $:真实数据分布
- $p_z(z) $:噪声的分布
博弈逻辑:
- 判别器$D \(要最大化\)V(D, G) \(,让真实数据的\)\log(D(x)) \(尽量大,让假数据的\)\log(1-D(G(z))) $也尽可能大,即真数据尽量判为真,假数据尽量判为假。
- 生成器$G \(让假数据的\)\log(1-D(G(z))) $尽可能小,即假数据尽量判为真。
LIPSCHITIZ 连续
LIPSCHITIZ 连续是函数一致性的强形式,其实我在NA里面学过这个,LIPSCHITIZ连续能够保证常微分方程初值问题解的唯一性。LIPSCHITIZ连续的定义如下:
一个函数$f \(被称为在某个区间或集合\)S \(上是Lipschitz连续的,如果存在一个非负实数\)K \(,使得对于\)S \(中的任意两个点\)x_1 \(和\)x_2 \(,以下的不等式成立:\)|f(x_1)-f(x_2)|\leq K|x_1-x_2| $,从几何上看就是函数的斜率绝对值有限。