# Source: uploaded by InPeerReview ("Upload 4 files", commit 3aafbf3, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile
from model.auxiliary import VSSM
import torch
from model.LaSEA import *
import torch
import time
from thop import profile
class ChannelAttention(nn.Module):
    """Channel attention gate (CBAM-style): per-channel weights in (0, 1).

    Average-pooled and max-pooled channel descriptors are passed through a
    shared two-layer 1x1-conv bottleneck, summed, and squashed by a sigmoid.

    Args:
        in_planes: number of input channels ``C``.
        ratio: bottleneck reduction factor. The hidden width is
            ``max(1, in_planes // ratio)``.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Honour ``ratio`` (it used to be ignored in favour of a literal 16).
        # Clamp to 1 so that in_planes < ratio still gives a valid bottleneck
        # instead of a zero-channel convolution.
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid channel weights; (N, C, H, W) input gives (N, C, 1, 1)."""
        # The same fc1/fc2 weights are shared by both pooled descriptors.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """Single-channel spatial attention map built from channel statistics.

    The input is reduced along dim 1 into two maps (channel mean and channel
    max). These are stacked, convolved down to one channel, and passed
    through a sigmoid.

    Args:
        kernel_size: convolution kernel size; only 3 or 7 are accepted.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported kernels (7 -> 3, 3 -> 1).
        self.conv1 = nn.Conv2d(
            2, 1, kernel_size,
            padding=(3 if kernel_size == 7 else 1),
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a sigmoid attention map with one channel and x's spatial size."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))
class ResNet(nn.Module):
    """Residual block with channel and spatial attention.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The result is
    reweighted by ``ChannelAttention`` and then by ``SpatialAttention``,
    added to the (optionally projected) input, and passed through a ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv and of the 1x1 projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Project the identity path with a 1x1 conv whenever its resolution
        # or channel count would differ from the main path.
        needs_projection = stride != 1 or out_channels != in_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )
        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Apply the attention-augmented residual block to ``x``."""
        identity = x if self.shortcut is None else self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y * self.ca(y)  # per-channel reweighting
        y = y * self.sa(y)  # per-location reweighting
        return self.relu(y + identity)
class DCCS(nn.Module):
    """U-Net-style segmentation network fusing a CNN encoder with VSSM features.

    At each of the first four encoder resolutions, the CNN stage output is
    concatenated with the matching feature map returned by ``VSSM`` and mixed
    back to the stage width by a 1x1 conv. The bottleneck is refined by
    ``GLFA`` and decoded with skip connections.

    Args:
        input_channels: number of channels of the input image.
        block: block class used for every stage; it is called as
            ``block(in_channels, out_channels)``.

    Note:
        The encoder uses four 2x2 max-pools. The decoder then upsamples by 2
        and concatenates with the skip at each level. So the input height and
        width must be divisible by 16 (VSSM may impose further constraints).
    """

    def __init__(self, input_channels, block=ResNet):
        super().__init__()
        param_channels = [16, 32, 64, 128, 256]
        param_blocks = [2, 2, 2, 2]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        # 16x upsampler, consistent with up_4 / up_8 (it previously used
        # scale_factor=4). Not used in forward.
        self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
        self.conv_init = nn.Conv2d(input_channels, param_channels[0], 1, 1)
        self.encoder_0 = self._make_layer(param_channels[0], param_channels[0], block)
        self.encoder_1 = self._make_layer(param_channels[0], param_channels[1], block, param_blocks[0])
        self.encoder_2 = self._make_layer(param_channels[1], param_channels[2], block, param_blocks[1])
        self.encoder_3 = self._make_layer(param_channels[2], param_channels[3], block, param_blocks[2])
        self.middle_layer = self._make_layer(param_channels[3], param_channels[4], block, param_blocks[3])
        # Each decoder consumes [skip, upsampled deeper features] concatenated on channels.
        self.decoder_3 = self._make_layer(param_channels[3] + param_channels[4], param_channels[3], block,
                                          param_blocks[2])
        self.decoder_2 = self._make_layer(param_channels[2] + param_channels[3], param_channels[2], block,
                                          param_blocks[1])
        self.decoder_1 = self._make_layer(param_channels[1] + param_channels[2], param_channels[1], block,
                                          param_blocks[0])
        self.decoder_0 = self._make_layer(param_channels[0] + param_channels[1], param_channels[0], block)
        # 1-channel side outputs, one per decoder level.
        self.output_0 = nn.Conv2d(param_channels[0], 1, 1)
        self.output_1 = nn.Conv2d(param_channels[1], 1, 1)
        self.output_2 = nn.Conv2d(param_channels[2], 1, 1)
        self.output_3 = nn.Conv2d(param_channels[3], 1, 1)
        # Fuses the four upsampled side outputs into the final prediction.
        self.final = nn.Conv2d(4, 1, 3, 1, 1)
        self.VSSM = VSSM()
        # 1x1 convs mixing [CNN stage, VSSM stage] back to the stage width.
        # This implies VSSM stage i must carry param_channels[i] channels.
        self.post_fuse3 = nn.Conv2d(param_channels[3] * 2, param_channels[3], kernel_size=1)
        self.post_fuse2 = nn.Conv2d(param_channels[2] * 2, param_channels[2], kernel_size=1)
        self.post_fuse1 = nn.Conv2d(param_channels[1] * 2, param_channels[1], kernel_size=1)
        self.post_fuse0 = nn.Conv2d(param_channels[0] * 2, param_channels[0], kernel_size=1)
        # Bottleneck refinement. Its output must keep param_channels[4]
        # channels, because decoder_3 expects that width.
        self.GLFA = GLFA(in_channels=param_channels[4])

    def _make_layer(self, in_channels, out_channels, block, block_num=1):
        """Stack ``block_num`` blocks; only the first one changes the channel count."""
        layer = [block(in_channels, out_channels)]
        layer.extend(block(out_channels, out_channels) for _ in range(block_num - 1))
        return nn.Sequential(*layer)

    @staticmethod
    def _to_channels_first(t):
        """Permute dims (0, 3, 1, 2) and make contiguous (channels-last -> channels-first)."""
        return t.permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _fuse(local_feat, global_feat, fuse_conv):
        """Concatenate two feature maps on dim 1 and mix them with ``fuse_conv``."""
        return fuse_conv(torch.cat([local_feat, global_feat], dim=1))

    def forward(self, x, warm_flag):
        """Run the network.

        Args:
            x: input batch, assumed (B, input_channels, H, W) with H and W
                divisible by 16.
            warm_flag: if truthy, compute a side output at every decoder level
                and fuse them into the final prediction.

        Returns:
            If ``warm_flag`` is truthy: ``([mask0, mask1, mask2, mask3], output)``.
            ``maskK`` is the 1-channel side output of decoder level K, which
            is at 1/2**K of the input resolution with the default block.
            ``output`` is the fused prediction at full resolution.
            Otherwise: ``([], output)``, where ``output`` is the side output
            of the finest decoder level.
        """
        # NOTE(review): VSSM outputs are permuted (0, 3, 1, 2), i.e. they are
        # presumably channels-last (B, H, W, C) — confirm against VSSM.
        vssm_feats = self.VSSM(x)
        x_e0f = self._to_channels_first(vssm_feats[0])
        x_e1f = self._to_channels_first(vssm_feats[1])
        x_e2f = self._to_channels_first(vssm_feats[2])
        x_e3f = self._to_channels_first(vssm_feats[3])

        # Encoder: each CNN stage is fused with the same-resolution VSSM stage.
        x_e0z = self._fuse(self.encoder_0(self.conv_init(x)), x_e0f, self.post_fuse0)
        x_e1z = self._fuse(self.encoder_1(self.pool(x_e0z)), x_e1f, self.post_fuse1)
        x_e2z = self._fuse(self.encoder_2(self.pool(x_e1z)), x_e2f, self.post_fuse2)
        x_e3z = self._fuse(self.encoder_3(self.pool(x_e2z)), x_e3f, self.post_fuse3)

        x_m = self.GLFA(self.middle_layer(self.pool(x_e3z)))

        # Decoder with skip connections from the fused encoder stages.
        x_d3 = self.decoder_3(torch.cat([x_e3z, self.up(x_m)], 1))
        x_d2 = self.decoder_2(torch.cat([x_e2z, self.up(x_d3)], 1))
        x_d1 = self.decoder_1(torch.cat([x_e1z, self.up(x_d2)], 1))
        x_d0 = self.decoder_0(torch.cat([x_e0z, self.up(x_d1)], 1))

        if not warm_flag:
            return [], self.output_0(x_d0)

        mask0 = self.output_0(x_d0)
        mask1 = self.output_1(x_d1)
        mask2 = self.output_2(x_d2)
        mask3 = self.output_3(x_d3)
        # Bring every side output to full resolution before fusing.
        output = self.final(torch.cat([mask0, self.up(mask1), self.up_4(mask2), self.up_8(mask3)], dim=1))
        return [mask0, mask1, mask2, mask3], output