import math
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn, Tensor

def makeDivisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Round v to the nearest multiple of divisor, staying at or above min_value
    and within 10% of the original value."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not remove more than 10% of v.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
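# Worked examples (illustrative values):
#   makeDivisible(30, 8)  -> 32   (30 rounds to the nearest multiple of 8)
#   makeDivisible(39, 16) -> 48   (32 would drop more than 10% of 39, so round up)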
def callMethod(self, ElementName):
    # Fetch an attribute by name (e.g. a pooling layer registered via setMethod).
    return getattr(self, ElementName)


def setMethod(self, ElementName, ElementValue):
    # Attach an attribute by name; used to register pooling layers dynamically.
    return setattr(self, ElementName, ElementValue)
def shuffleTensor(Feature: Tensor, Mode: int = 1) -> list:
    """Randomly permute the spatial positions of one or more feature maps.

    Mode 1 shuffles the flattened H*W positions jointly; any other mode shuffles
    rows and columns with two independent permutations. The same permutation is
    reused for every tensor in the input list. Always returns a list of tensors.
    """
    if isinstance(Feature, Tensor):
        Feature = [Feature]

    Indexs = None
    Output = []
    for f in Feature:
        B, C, H, W = f.shape
        if Mode == 1:
            # Shuffle all H*W positions at once.
            f = f.flatten(2)
            if Indexs is None:
                Indexs = torch.randperm(f.shape[-1], device=f.device)
            f = f[:, :, Indexs.to(f.device)]
            f = f.reshape(B, C, H, W)
        else:
            # Shuffle rows and columns independently.
            if Indexs is None:
                Indexs = [torch.randperm(H, device=f.device),
                          torch.randperm(W, device=f.device)]
            f = f[:, :, Indexs[0].to(f.device)]
            f = f[:, :, :, Indexs[1].to(f.device)]
        Output.append(f)
    return Output
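# Usage sketch (shapes assumed for illustration): shuffleTensor always returns a
# list, so single-tensor callers index the first element, e.g.
#   x = torch.randn(2, 8, 4, 4)
#   x_shuffled = shuffleTensor(x)[0]   # same shape, spatial positions permuted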
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    def __init__(self, output_size: Union[int, tuple] = 1):
        super(AdaptiveAvgPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        # Profiling hook: pooling is counted as cost-free (returns output, 0.0, 0.0).
        Output = self.forward(Input)
        return Output, 0.0, 0.0


class AdaptiveMaxPool2d(nn.AdaptiveMaxPool2d):
    def __init__(self, output_size: Union[int, tuple] = 1):
        super(AdaptiveMaxPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        # Profiling hook: pooling is counted as cost-free (returns output, 0.0, 0.0).
        Output = self.forward(Input)
        return Output, 0.0, 0.0
class BaseConv2d(nn.Module):
    """Conv2d followed by optional batch normalization and an optional activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: Optional[int] = 1,
        padding: Optional[int] = None,
        groups: Optional[int] = 1,
        bias: Optional[bool] = None,
        BNorm: bool = False,
        ActLayer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
        Momentum: Optional[float] = 0.1,
        **kwargs: Any
    ) -> None:
        super(BaseConv2d, self).__init__()
        if padding is None:
            # "Same"-style padding for odd kernel sizes, accounting for dilation.
            padding = int((kernel_size - 1) // 2 * dilation)
        if bias is None:
            # A bias term is redundant when batch normalization follows the convolution.
            bias = not BNorm

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias

        self.Conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, dilation, groups, bias, **kwargs)
        self.Bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=Momentum) if BNorm else nn.Identity()

        if ActLayer is not None:
            # nn.Sigmoid does not accept the inplace keyword, so instantiate it bare.
            if isinstance(ActLayer(), nn.Sigmoid):
                self.Act = ActLayer()
            else:
                self.Act = ActLayer(inplace=True)
        else:
            self.Act = None

        self.apply(initWeight)

    def forward(self, x: Tensor) -> Tensor:
        x = self.Conv(x)
        x = self.Bn(x)
        if self.Act is not None:
            x = self.Act(x)
        return x
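# Usage sketch (illustrative; channel sizes assumed):
#   conv = BaseConv2d(32, 64, kernel_size=3, BNorm=True, ActLayer=nn.ReLU)
#   y = conv(torch.randn(1, 32, 56, 56))   # -> (1, 64, 56, 56)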
# Normalization layers recognized by initWeight.
NormLayerTuple = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.SyncBatchNorm,
    nn.LayerNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.GroupNorm,
    nn.BatchNorm3d,
)
def initWeight(Module):
    """Kaiming-uniform initialization for conv/linear layers, ones/zeros for norm layers."""
    if Module is None:
        return
    elif isinstance(Module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, NormLayerTuple):
        if Module.weight is not None:
            nn.init.ones_(Module.weight)
        if Module.bias is not None:
            nn.init.zeros_(Module.bias)
    elif isinstance(Module, nn.Linear):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, (nn.Sequential, nn.ModuleList)):
        for m in Module:
            initWeight(m)
    elif list(Module.children()):
        # Recurse into any remaining container module.
        for m in Module.children():
            initWeight(m)
class Attention(nn.Module):
    """Squeeze-and-excitation style channel attention whose pooling resolution is
    sampled at random during training."""

    def __init__(
        self,
        InChannels: int,
        HidChannels: Optional[int] = None,
        SqueezeFactor: int = 4,
        PoolRes: list = [1, 2, 3],
        Act: Callable[..., nn.Module] = nn.ReLU,
        ScaleAct: Callable[..., nn.Module] = nn.Sigmoid,
        MoCOrder: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if HidChannels is None:
            HidChannels = max(makeDivisible(InChannels // SqueezeFactor, 8), 32)

        # Register one adaptive pooling layer per candidate resolution.
        AllPoolRes = PoolRes + [1] if 1 not in PoolRes else PoolRes
        for k in AllPoolRes:
            Pooling = AdaptiveAvgPool2d(k)
            setMethod(self, 'Pool%d' % k, Pooling)

        self.SELayer = nn.Sequential(
            BaseConv2d(InChannels, HidChannels, 1, ActLayer=Act),
            BaseConv2d(HidChannels, InChannels, 1, ActLayer=ScaleAct),
        )

        self.PoolRes = PoolRes
        self.MoCOrder = MoCOrder

    def RandomSample(self, x: Tensor) -> Tensor:
        if self.training:
            # Pick a pooling resolution at random, optionally shuffling the spatial
            # order first, then keep a single random position as the attention map.
            PoolKeep = np.random.choice(self.PoolRes)
            x1 = shuffleTensor(x)[0] if self.MoCOrder else x
            AttnMap: Tensor = callMethod(self, 'Pool%d' % PoolKeep)(x1)
            if AttnMap.shape[-1] > 1:
                AttnMap = AttnMap.flatten(2)
                AttnMap = AttnMap[:, :, torch.randperm(AttnMap.shape[-1])[0]]
                AttnMap = AttnMap[:, :, None, None]
        else:
            # At inference time always use global (1x1) average pooling.
            AttnMap: Tensor = callMethod(self, 'Pool%d' % 1)(x)

        return AttnMap

    def forward(self, x: Tensor) -> Tensor:
        AttnMap = self.RandomSample(x)
        return x * self.SELayer(AttnMap)
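# Shape sketch (sizes assumed for illustration): with InChannels=64 on a
# (B, 64, 32, 32) input, RandomSample yields a (B, 64, 1, 1) map, the SE layers
# turn it into per-channel gates, and the gated output keeps the input shape.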
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    # Interleave channels across groups (as in ShuffleNet): reshape to
    # (B, groups, C/groups, H, W), swap the two channel axes, then flatten back.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x
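# Worked example (illustrative): with 4 channels [0, 1, 2, 3] and groups=2, the
# shuffle reorders them to [0, 2, 1, 3], mixing information across the two groups.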
class GLFA(nn.Module):
    """GLFA block: four parallel 3x3 convolutions with dilations 1-4, channel
    shuffle, 1x1 fusion, attention re-weighting, and a residual connection."""

    def __init__(self, in_channels):
        super(GLFA, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=1, kernel_size=3, dilation=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=2, kernel_size=3, dilation=2),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=3, kernel_size=3, dilation=3),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_4 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=4, kernel_size=3, dilation=4),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.mca = Attention(InChannels=in_channels, HidChannels=16)

    def forward(self, x):
        d = x
        # Multi-scale branches with increasing dilation; padding keeps the spatial size.
        c1 = self.conv_1(x)
        c2 = self.conv_2(x)
        c3 = self.conv_3(x)
        c4 = self.conv_4(x)
        cat = torch.cat([c1, c2, c3, c4], dim=1)
        cat = channel_shuffle(cat, groups=4)
        M = self.fuse(cat)
        O = self.mca(M)
        return O + d
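# Minimal smoke test (a sketch; input shapes are assumed for illustration).
if __name__ == "__main__":
    block = GLFA(in_channels=64)
    block.train()
    x = torch.randn(2, 64, 32, 32)
    y = block(x)
    print(y.shape)  # expected: torch.Size([2, 64, 32, 32])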