LYT-NET:一个超级轻量的低光照图像增强Transformer网络

张开发
2026/5/7 12:43:11 15 分钟阅读
LYT-NET:一个超级轻量的低光照图像增强Transformer网络
Paper:https://arxiv.org/pdf/2401.15204v6Code: https://github.com/albrateanu/LYT-Net目录0. 摘要1. 引言2. 方法2.1. CWD模块2.2. MHSA模块2.3. MSEF模块2.4. 损失函数3. 结果与讨论4. 消融实验5. 结论附【网络结构的Pytorch代码】0. 摘要本文提出了LYT-Net这是一个新颖的、轻量的、transformer-based的低光照图像增强模型它由几个层和可拆卸的块组成包括我们的新块——Channel-Wise Denoiser (CWD)和Multi-Stage Squeeze Excite Fusion (MSEF)——以及传统的Transformer块Multi-Headed Self-Attention (MHSA)。我们采用双路径方法将色度通道 U 和 V 和亮度通道 Y 视为单独的实体以帮助模型更好地处理光照调整和损坏恢复。我们对已建立的LLIE数据集的综合评估表明尽管它的复杂性较低但我们的模型优于最近的LLIE方法。1. 引言低光照图像增强(LLIE)是计算成像中一项重要且具有挑战性的任务。当图像在弱光条件下捕获时它们的质量通常会恶化从而导致细节和对比度的损失。这不仅使图像在视觉上不吸引人而且会影响许多成像系统的性能。LLIE 的目标是提高这些图像的清晰度和对比度同时还纠正黑暗环境中经常出现的失真所有这些都不引入不需要的伪影或导致颜色不平衡。早期的LLIE方法[1]主要依靠频率分解[2]、[3]、[4]、直方图均衡化[5]、[6]、[7]和Retinex理论[8]、[9]、[10]、[11]、[12]。随着深度学习的快速发展各种CNN架构[13][14][15][16][17][18][19][20][21][22]已被证明优于传统的LLIE技术。基于Retinex理论Retinex-Net[13]将Retinex分解与原始CNN架构相结合而DiffRetinex[16]提出了一个生成框架以进一步解决弱光引起的内容丢失和颜色偏差。生成对抗网络(GAN)[23]的发展为LLIE提供了新的视角其中弱光图像作为输入生成正常光对应图像。例如启蒙gan [24] 使用单个生成器模型直接将低光图像转换为正常光版本有效地在转换过程中同时使用全局和局部鉴别器。生成对抗网络(GAN)[23]的发展为LLIE提供了新的视角其中弱光图像作为输入生成正常光对应图像。例如启蒙gan [24] 使用单个生成器模型直接将低光图像转换为正常光版本有效地在转换过程中同时使用全局和局部鉴别器。最近视觉transformer(ViTs)[25]在各种计算机视觉任务中表现出了显著的有效性[26][27][28][29][30]这主要是归功于自注意力 (SA) 机制。尽管取得了这些进展但vit在低级low-level视觉任务中的应用仍未得到充分探索。在最近的文献[31]、[32]、[33]中只引入了一些基于LLIE(微光图像增强)-VIT的策略。Uformer [31] 基于经典的 UNet 架构其中卷积层被替换为 Transformer 块同时保持分层编码器-解码器结构和跳过连接。另一方面Restormer [33] 引入了多 Dconv 头转置注意力 (MDTA) 块取代了普通的多头自注意力。我们提出了一种新的轻量级基于transformer的方法称为LYT-Net。与现有的基于transformer的方法不同我们的方法专注于计算效率同时仍然产生最先进的(SOTA)结果。具体来说我们首先使用 YUV 颜色空间将色度与亮度分开。色度信息(通道U和V)最初通过专门的通道去噪器(CWD)块进行处理在保持精细细节的同时减少了噪声。为了降低计算复杂度亮度通道 Y 经历卷积和池化来提取特征随后由传统的多头自注意力 (MHSA) 块增强。然后通过一种新颖的多级SE融合 (MSEF) 块重新组合和处理增强通道。最后色度通道U和V通道与亮度Y通道连接并通过最后一组卷积层来生成恢复的图像。我们的方法对已建立的LLIE数据集进行了广泛的测试。定性和定量评估都表明我们的方法取得了极具竞争力的结果。图1展示了使用LOL数据集[13]评估的SOTA方法之间性能相对于复杂性的比较分析。可以看出尽管它的设计轻量级但我们的方法产生的结果不仅与最近更复杂的深度学习LLIE 技术的结果相当而且通常效果更好。2. 方法在图 2 中我们说明了 LYTNet 的整体架构它由几层和可拆卸块组成包括我们的新块——Channel-Wise Denoiser (CWD) 和多阶段 Squeeze Excite Fusion (MSEF)——以及传统的 ViT 块Multi-Headed Self-Attention (MHSA)。我们采用双路径方法将色度和亮度视为单独的实体以帮助模型更好地处理光照调整和损坏恢复。亮度通道Y经过卷积和池化提取特征然后由MHSAblock增强。通过CWD块处理色度通道U和V以减少噪声同时保留细节。然后通过MSEF块重新组合和处理增强的色度通道。最后将色度 U、V 和亮度 Y 通道连接起来并通过最后一组卷积层来生成输出从而产生高质量、增强的图像。2.1. CWD模块CWD块采用u型网络和MHSA作为瓶颈集成了卷积和注意力机制。它包括多个具有不同步幅和跳过连接的 conv3×3 层促进了详细的特征捕获和去噪。它由一系列四个 conv3×3 层组成。第一个conv3×3 在特征提取方面步幅为 1。其他三个 conv3×3 层的步长为 2有助于捕获不同尺度的特征。注意力瓶颈的集成使模型能够捕获长期依赖关系然后是上采样层和跳过连接来重建并促进空间分辨率的恢复。这种方法允许我们在空间维度降低的特征图上应用MHSA显著提高计算效率。此外使用基于插值的上采样而不是转置卷积将 CWD 中的参数数量减少了一半以上同时保留了性能。2.2. MHSA模块在我们简化版的transformer架构中输入特征首先通过无偏置全连接层线性投影到查询(Q)、键(K)和值(V)分量。线性投影保持原始输入维度。接下来这些投影特征被分成 k 个头其中每个头都以维度 d_k 独立运行。自注意力机制应用于每个头部定义如下最后将所有头部的注意力输出连接起来组合输出通过线性层将其投影回原始嵌入大小。输出标记被重新整形回原始空间维度以形成输出特征。2.3. MSEF模块MSEF块增强了的空间和通道特征。最初经历层归一化然后是全局平均池化来捕获全局空间上下文和具有 ReLU 激活的缩减全连接层产生减少的描述符如公式 (4)。然后该描述符通过另一个具有 Tanh 激活的全连接层扩展到原始维度从而产生, 如公式 (5)。在融合输出中加入残差连接生成最终的输出特征图如式(6)所示。2.4. 损失函数在我们的方法中混合损失函数在有效地训练我们的模型方面起着关键作用。混合损失L如式(7)所示其中α1到α5是用于平衡每个组成损失函数的超参数。我们模型中的混合损失结合了几个组件来提高图像质量和感知。平滑 L1 损失 LS 通过基于预测值和真实值之间的差异应用二次或线性惩罚来处理异常值。感知损失 LPerc 通过比较 VGG 提取的特征图来保持特征一致性。直方图损失LHist对齐预测图像和真实图像之间的像素强度分布。PSNR损失LPSNR通过惩罚均方误差来降低噪声而颜色损失LColor通过最小化通道平均值的差异来确保颜色保真度。最后多尺度SSIM损失LMS-SSIM通过在多个尺度上评估相似性来保持结构的完整性。总之这些损失形成了一个综合策略解决了图像增强的各个方面。3. 结果与讨论实现细节LYT-Net 的实现利用了 TensorFlow 框架。ADAM 优化器 (β1 0.9 and β2 0.999) 用于超过 1000 个 epoch 的训练。初始学习率设置为 2×10−4并在余弦退火计划后逐渐衰减到 1 × 10−6有助于优化收敛并避免局部最小值。混合损失函数的超参数设置为α10.06、α20.05、α30.5、α40.0083 和 α50.25。 LYT-Net 在 LOL 数据集的三个版本上进行训练和评估LOL-v1、LOL-v2-real 和 LOL-v2-synthetic。LOLv1、LOL-v2-real 的相应训练/测试拆分为 485 : 15LOL-v2-real 为 689 : 100LOL-v2-synthetic 为 900 : 100。在训练期间图像对进行随机增强包括随机裁剪到 256 × 256 和随机翻转/旋转以防止过度拟合。训练以 1 的批大小进行。评估指标包括 PSNR 和 SSIM 进行性能评估。PS官方的Gtihub代码里面同时也包含了Pytorch版本。定量结果将所提出的方法与 SOTA LLIE 技术进行比较如表 I 所示重点关注两个方面LOL 数据集LOLV1、LOL-v2-real、LOL-v2-synthetic和模型复杂度的定量性能。如表 I 所示LYT-Net 在 PSNR 和 SSIM 方面在所有版本的 LOL 数据集上始终优于当前的 SOTA 方法。此外LYTNet 非常高效只需要 3.49G FLOPS 且仅使用 0.045M 参数这使得它比其他通常更复杂的 SOTA 方法具有显着优势。唯一的例外是 3DLUT[35]它在复杂性方面与我们的方法相当。然而LYT-Net 在 PSNR 和 SSIM 中明显优于 3DLUT 方法。这种强大的性能和低复杂度的组合突出了 LYT-Net 的整体有效性。定性结果:LOL数据集上的LYT-Net与SOTALLIE技术的定性评估如图3所示LIME[38]上的图4所示。以前的方法如KiND[17]和Restormer[33]表现出颜色失真问题如图3所示。此外几种算法(如MIRNet[20]和SNR-Net[22])往往会产生过度曝光或曝光不足的区域在增强亮度的同时损害图像对比度。同样图 4 表明 SRIE [39]、DeHz [40] 和 NPE [41] 导致对比度损失。一般来说我们的LYT-Net在提高能见度和增强低对比度或光线较差的区域方面非常有效同时在不引入斑点或伪影的情况下有效地消除噪声。4. 消融实验消融研究是在LOLV1数据集上进行的使用PSNR作为定量指标并评估CWD和MSEF块的影响。在 YUV 分解中将 CWD 应用于 Y 通道用作照明图会导致照明伪影的保留导致与池化操作和基于插值的上采样相比性能下降从而平滑照明以获得更好的结果。然而CWD增强了色度通道(U和V)在不引入噪声的情况下保留细节。此外MSEF 块始终在所有 CWD 组合中提高性能PSNR 分别提高了 0.16、0.24 和 0.26 dB同时将参数计数提高了 546。5. 结论我们引入了LYT-Net这是一种创新的基于transformer的轻量级模型用于增强低光照图像。我们的方法利用双路径框架分别处理色度和亮度以提高模型管理光照调整和恢复损坏区域的能力。LYT-Net 集成了多层和模块化块包括两个独特的组件——ChannelWise Denoiser (CWD) 和多阶段 Squeeze Excite Fusion (MSEF)——以及具有多头自注意力 (MHSA) 的传统视觉transformer (ViT) 块。全面的定性和定量分析表明LYT-Net 在 PSNR 和 SSIM 方面在所有版本的 LOL 数据集上始终优于 SOTA 方法同时保持了较高的计算效率。附LYT-Net 网络结构的Pytorch代码来自官方实现import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init class LayerNormalization(nn.Module): def __init__(self, dim): super(LayerNormalization, self).__init__() self.norm nn.LayerNorm(dim) def forward(self, x): # Rearrange the tensor for LayerNorm (B, C, H, W) to (B, H, W, C) x x.permute(0, 2, 3, 1) x self.norm(x) # Rearrange back to (B, C, H, W) return x.permute(0, 3, 1, 2) class SEBlock(nn.Module): def __init__(self, input_channels, reduction_ratio16): super(SEBlock, self).__init__() self.pool nn.AdaptiveAvgPool2d(1) self.fc1 nn.Linear(input_channels, input_channels // reduction_ratio) self.fc2 nn.Linear(input_channels // reduction_ratio, input_channels) self._init_weights() def forward(self, x): batch_size, num_channels, _, _ x.size() y self.pool(x).reshape(batch_size, num_channels) y F.relu(self.fc1(y)) y torch.tanh(self.fc2(y)) y y.reshape(batch_size, num_channels, 1, 1) return x * y def _init_weights(self): init.kaiming_uniform_(self.fc1.weight, a0, modefan_in, nonlinearityrelu) init.kaiming_uniform_(self.fc2.weight, a0, modefan_in, nonlinearityrelu) init.constant_(self.fc1.bias, 0) init.constant_(self.fc2.bias, 0) class MSEFBlock(nn.Module): def __init__(self, filters): super(MSEFBlock, self).__init__() self.layer_norm LayerNormalization(filters) self.depthwise_conv nn.Conv2d(filters, filters, kernel_size3, padding1, groupsfilters) self.se_attn SEBlock(filters) self._init_weights() def forward(self, x): x_norm self.layer_norm(x) x1 self.depthwise_conv(x_norm) x2 self.se_attn(x_norm) x_fused x1 * x2 x_out x_fused x return x_out def _init_weights(self): init.kaiming_uniform_(self.depthwise_conv.weight, a0, modefan_in, nonlinearityrelu) init.constant_(self.depthwise_conv.bias, 0) class MultiHeadSelfAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadSelfAttention, self).__init__() self.embed_size embed_size self.num_heads num_heads assert embed_size % num_heads 0 self.head_dim embed_size // num_heads self.query_dense nn.Linear(embed_size, embed_size) self.key_dense nn.Linear(embed_size, embed_size) self.value_dense nn.Linear(embed_size, embed_size) self.combine_heads nn.Linear(embed_size, embed_size) self._init_weights() def split_heads(self, x, batch_size): x x.reshape(batch_size, -1, self.num_heads, self.head_dim) return x.permute(0, 2, 1, 3) def forward(self, x): batch_size, _, height, width x.size() x x.reshape(batch_size, height * width, -1) query self.split_heads(self.query_dense(x), batch_size) key self.split_heads(self.key_dense(x), batch_size) value self.split_heads(self.value_dense(x), batch_size) attention_weights F.softmax(torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5), dim-1) attention torch.matmul(attention_weights, value) attention attention.permute(0, 2, 1, 3).contiguous().reshape(batch_size, -1, self.embed_size) output self.combine_heads(attention) return output.reshape(batch_size, height, width, self.embed_size).permute(0, 3, 1, 2) def _init_weights(self): init.xavier_uniform_(self.query_dense.weight) init.xavier_uniform_(self.key_dense.weight) init.xavier_uniform_(self.value_dense.weight) init.xavier_uniform_(self.combine_heads.weight) init.constant_(self.query_dense.bias, 0) init.constant_(self.key_dense.bias, 0) init.constant_(self.value_dense.bias, 0) init.constant_(self.combine_heads.bias, 0) class Denoiser(nn.Module): def __init__(self, num_filters, kernel_size3, activationrelu): super(Denoiser, self).__init__() self.conv1 nn.Conv2d(1, num_filters, kernel_sizekernel_size, padding1) self.conv2 nn.Conv2d(num_filters, num_filters, kernel_sizekernel_size, stride2, padding1) self.conv3 nn.Conv2d(num_filters, num_filters, kernel_sizekernel_size, stride2, padding1) self.conv4 nn.Conv2d(num_filters, num_filters, kernel_sizekernel_size, stride2, padding1) self.bottleneck MultiHeadSelfAttention(embed_sizenum_filters, num_heads4) self.up2 nn.Upsample(scale_factor2, modenearest) self.up3 nn.Upsample(scale_factor2, modenearest) self.up4 nn.Upsample(scale_factor2, modenearest) self.output_layer nn.Conv2d(1, 1, kernel_sizekernel_size, padding1) self.res_layer nn.Conv2d(num_filters, 1, kernel_sizekernel_size, padding1) self.activation getattr(F, activation) self._init_weights() def forward(self, x): x1 self.activation(self.conv1(x)) x2 self.activation(self.conv2(x1)) x3 self.activation(self.conv3(x2)) x4 self.activation(self.conv4(x3)) x self.bottleneck(x4) x self.up4(x) x self.up3(x x3) x self.up2(x x2) x x x1 x self.res_layer(x) return torch.tanh(self.output_layer(x x)) def _init_weights(self): for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.output_layer, self.res_layer]: init.kaiming_uniform_(layer.weight, a0, modefan_in, nonlinearityrelu) if layer.bias is not None: init.constant_(layer.bias, 0) class LYT(nn.Module): def __init__(self, filters32): super(LYT, self).__init__() self.process_y self._create_processing_layers(filters) self.process_cb self._create_processing_layers(filters) self.process_cr self._create_processing_layers(filters) self.denoiser_cb Denoiser(filters // 2) self.denoiser_cr Denoiser(filters // 2) self.lum_pool nn.MaxPool2d(8) self.lum_mhsa MultiHeadSelfAttention(embed_sizefilters, num_heads4) self.lum_up nn.Upsample(scale_factor8, modenearest) self.lum_conv nn.Conv2d(filters, filters, kernel_size1, padding0) self.ref_conv nn.Conv2d(filters * 2, filters, kernel_size1, padding0) self.msef MSEFBlock(filters) self.recombine nn.Conv2d(filters * 2, filters, kernel_size3, padding1) self.final_adjustments nn.Conv2d(filters, 3, kernel_size3, padding1) self._init_weights() def _create_processing_layers(self, filters): return nn.Sequential( nn.Conv2d(1, filters, kernel_size3, padding1), nn.ReLU(inplaceTrue) ) def _rgb_to_ycbcr(self, image): r, g, b image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :] y 0.299 * r 0.587 * g 0.114 * b u -0.14713 * r - 0.28886 * g 0.436 * b 0.5 v 0.615 * r - 0.51499 * g - 0.10001 * b 0.5 yuv torch.stack((y, u, v), dim1) return yuv def forward(self, inputs): ycbcr self._rgb_to_ycbcr(inputs) y, cb, cr torch.split(ycbcr, 1, dim1) cb self.denoiser_cb(cb) cb cr self.denoiser_cr(cr) cr y_processed self.process_y(y) cb_processed self.process_cb(cb) cr_processed self.process_cr(cr) ref torch.cat([cb_processed, cr_processed], dim1) lum y_processed lum_1 self.lum_pool(lum) lum_1 self.lum_mhsa(lum_1) lum_1 self.lum_up(lum_1) lum lum lum_1 ref self.ref_conv(ref) shortcut ref ref ref 0.2 * self.lum_conv(lum) ref self.msef(ref) ref ref shortcut recombined self.recombine(torch.cat([ref, lum], dim1)) output self.final_adjustments(recombined) return torch.sigmoid(output) def _init_weights(self): for module in self.children(): if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): init.kaiming_uniform_(module.weight, a0, modefan_in, nonlinearityrelu) if module.bias is not None: init.constant_(module.bias, 0)

更多文章