import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile

from model.auxiliary import VSSM
# Star import kept as-is: GLFA (used by DCCS) is not defined in this file and
# presumably comes from here — TODO replace with explicit names once confirmed.
from model.LaSEA import *


class ChannelAttention(nn.Module):
    """Channel attention gate.

    Global average- and max-pooled descriptors are passed through a shared
    two-layer 1x1-conv bottleneck, summed, and squashed with a sigmoid.

    Args:
        in_planes: Number of input (and output) channels.
        ratio: Channel reduction factor of the bottleneck. The hidden width
            is ``in_planes // ratio``, clamped to at least 1.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour ``ratio`` (it used to be ignored in favour of a literal 16)
        # and never build a zero-channel conv when in_planes < ratio.
        # NOTE: a non-default ratio changes fc1/fc2 weight shapes, so
        # checkpoints saved with the old hard-coded 16 only load at ratio=16.
        hidden_planes = max(in_planes // ratio, 1)
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled):
        """Apply the shared fc1 -> ReLU -> fc2 bottleneck to a pooled map."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-channel weights in (0, 1) with 1x1 spatial size."""
        avg_out = self._bottleneck(self.avg_pool(x))
        max_out = self._bottleneck(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Spatial attention gate.

    The channel-wise mean and max maps are stacked into a 2-channel tensor,
    convolved down to a single channel and squashed with a sigmoid.

    Args:
        kernel_size: Size of the convolution kernel; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported kernel sizes.
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a single-channel map of weights in (0, 1)."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        stacked = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(stacked))


class ResNet(nn.Module):
    """Residual block (two 3x3 conv + BN) gated by channel and spatial attention.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # A 1x1 projection is only required when the identity path would not
        # match the main path in channel count or spatial size.
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = None

        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block and return the activated residual sum."""
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Channel gate first, then spatial gate on the re-weighted features.
        out = self.ca(out) * out
        out = self.sa(out) * out

        out += residual
        out = self.relu(out)
        return out


class DCCS(nn.Module):
    """U-shaped encoder/decoder that fuses VSSM features at every encoder scale.

    At each of the four encoder scales the convolutional features are
    concatenated with the matching VSSM output and reduced back to the
    encoder width by a 1x1 convolution. The bottleneck is refined by GLFA
    before decoding with skip connections.

    Args:
        input_channels: Number of channels of the input image tensor.
        block: Block class used to build every stage; called as
            ``block(in_channels, out_channels)``.
    """

    def __init__(self, input_channels, block=ResNet):
        super().__init__()
        param_channels = [16, 32, 64, 128, 256]
        param_blocks = [2, 2, 2, 2]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear',
                              align_corners=True)
        self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear',
                                align_corners=True)
        self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear',
                                align_corners=True)
        # NOTE(review): named up_16 but built with scale_factor=4, and it is
        # not used in forward(). Left unchanged — confirm the intended factor.
        self.up_16 = nn.Upsample(scale_factor=4, mode='bilinear',
                                 align_corners=True)

        self.conv_init = nn.Conv2d(input_channels, param_channels[0], 1, 1)

        self.encoder_0 = self._make_layer(param_channels[0], param_channels[0],
                                          block)
        self.encoder_1 = self._make_layer(param_channels[0], param_channels[1],
                                          block, param_blocks[0])
        self.encoder_2 = self._make_layer(param_channels[1], param_channels[2],
                                          block, param_blocks[1])
        self.encoder_3 = self._make_layer(param_channels[2], param_channels[3],
                                          block, param_blocks[2])

        self.middle_layer = self._make_layer(param_channels[3],
                                             param_channels[4], block,
                                             param_blocks[3])

        # Each decoder consumes its skip connection concatenated with the
        # upsampled output of the stage below it.
        self.decoder_3 = self._make_layer(
            param_channels[3] + param_channels[4], param_channels[3], block,
            param_blocks[2])
        self.decoder_2 = self._make_layer(
            param_channels[2] + param_channels[3], param_channels[2], block,
            param_blocks[1])
        self.decoder_1 = self._make_layer(
            param_channels[1] + param_channels[2], param_channels[1], block,
            param_blocks[0])
        self.decoder_0 = self._make_layer(
            param_channels[0] + param_channels[1], param_channels[0], block)

        # Per-scale single-channel heads used for the multi-scale outputs.
        self.output_0 = nn.Conv2d(param_channels[0], 1, 1)
        self.output_1 = nn.Conv2d(param_channels[1], 1, 1)
        self.output_2 = nn.Conv2d(param_channels[2], 1, 1)
        self.output_3 = nn.Conv2d(param_channels[3], 1, 1)

        # Fuses the four (upsampled) single-channel masks into one.
        self.final = nn.Conv2d(4, 1, 3, 1, 1)

        self.VSSM = VSSM()

        # 1x1 convs reducing [encoder features, VSSM features] back to the
        # encoder width. The doubled input width assumes each VSSM output has
        # the same channel count as its encoder stage — TODO confirm.
        self.post_fuse3 = nn.Conv2d(param_channels[3] * 2, param_channels[3],
                                    kernel_size=1)
        self.post_fuse2 = nn.Conv2d(param_channels[2] * 2, param_channels[2],
                                    kernel_size=1)
        self.post_fuse1 = nn.Conv2d(param_channels[1] * 2, param_channels[1],
                                    kernel_size=1)
        self.post_fuse0 = nn.Conv2d(param_channels[0] * 2, param_channels[0],
                                    kernel_size=1)

        # Width of the bottleneck produced by middle_layer (256).
        self.GLFA = GLFA(in_channels=param_channels[4])

    def _make_layer(self, in_channels, out_channels, block, block_num=1):
        """Stack ``block_num`` blocks; only the first one changes the width."""
        layers = [block(in_channels, out_channels)]
        layers.extend(block(out_channels, out_channels)
                      for _ in range(block_num - 1))
        return nn.Sequential(*layers)

    def forward(self, x, warm_flag):
        """Run the network.

        The input is max-pooled by 2 four times and upsampled by 2 four
        times, so its spatial size should be divisible by 16 for the skip
        concatenations to line up.

        Args:
            x: Input tensor fed to both ``conv_init`` and ``VSSM``.
            warm_flag: If truthy, compute the per-scale masks and return
                their fusion; otherwise return only the full-resolution head.

        Returns:
            ``([mask0, mask1, mask2, mask3], output)`` when ``warm_flag`` is
            truthy, else ``([], output)``.
        """
        outputs = self.VSSM(x)
        # The permute moves the last axis to position 1; presumably the VSSM
        # outputs are channels-last (B, H, W, C) — TODO confirm.
        x_e0f, x_e1f, x_e2f, x_e3f = (
            outputs[i].permute(0, 3, 1, 2).contiguous() for i in range(4)
        )

        # Encoder: conv features are fused with VSSM features at every scale
        # before being pooled and passed down.
        x_e0z = self.encoder_0(self.conv_init(x))
        x_e0z = self.post_fuse0(torch.cat([x_e0z, x_e0f], dim=1))

        x_e1z = self.encoder_1(self.pool(x_e0z))
        x_e1z = self.post_fuse1(torch.cat([x_e1z, x_e1f], dim=1))

        x_e2z = self.encoder_2(self.pool(x_e1z))
        x_e2z = self.post_fuse2(torch.cat([x_e2z, x_e2f], dim=1))

        x_e3z = self.encoder_3(self.pool(x_e2z))
        x_e3z = self.post_fuse3(torch.cat([x_e3z, x_e3f], dim=1))

        # Bottleneck.
        x_m = self.middle_layer(self.pool(x_e3z))
        x_m = self.GLFA(x_m)

        # Decoder with skip connections from the fused encoder features.
        x_d3 = self.decoder_3(torch.cat([x_e3z, self.up(x_m)], 1))
        x_d2 = self.decoder_2(torch.cat([x_e2z, self.up(x_d3)], 1))
        x_d1 = self.decoder_1(torch.cat([x_e1z, self.up(x_d2)], 1))
        x_d0 = self.decoder_0(torch.cat([x_e0z, self.up(x_d1)], 1))

        if not warm_flag:
            return [], self.output_0(x_d0)

        mask0 = self.output_0(x_d0)
        mask1 = self.output_1(x_d1)
        mask2 = self.output_2(x_d2)
        mask3 = self.output_3(x_d3)
        # Bring every mask to mask0's resolution (x2, x4, x8) before fusing.
        fused_masks = torch.cat(
            [mask0, self.up(mask1), self.up_4(mask2), self.up_8(mask3)],
            dim=1)
        output = self.final(fused_masks)
        return [mask0, mask1, mask2, mask3], output