# Provenance: uploaded by InPeerReview ("Upload 4 files", commit 3aafbf3, verified).
import torch
import torch.nn as nn
from typing import Optional, Callable, Union, Tuple, Any
import torch
from torch import nn, Tensor
import numpy as np
from typing import Optional
import math
from torch import nn
def makeDivisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Round ``v`` to a multiple of ``divisor``.

    For non-negative ``v`` the rounding is to the nearest multiple, with halves
    going up. The result is clamped below by ``min_value``, which defaults to
    ``divisor``. If that rounding ends up more than 10% below ``v``, one extra
    ``divisor`` is added.
    """
    floor_value = divisor if min_value is None else min_value
    nearest = int(v + divisor / 2) // divisor * divisor
    candidate = max(floor_value, nearest)
    # Never shrink by more than 10%: compensate with one more step.
    return candidate + divisor if candidate < 0.9 * v else candidate
def callMethod(self, ElementName):
    """Look up the attribute named ``ElementName`` on ``self`` and return it.

    Despite the name, nothing is invoked. Callers call the returned object
    themselves. A missing attribute raises ``AttributeError``.
    """
    element = getattr(self, ElementName)
    return element
def setMethod(self, ElementName, ElementValue):
    """Bind ``ElementValue`` to the attribute ``ElementName`` of ``self``.

    This goes through ``setattr``, so on an ``nn.Module`` a module value is
    registered as a submodule. Always returns ``None``.
    """
    setattr(self, ElementName, ElementValue)
    return None
def shuffleTensor(Feature: Union[Tensor, list, tuple], Mode: int = 1) -> list:
    """Spatially shuffle one or more tensors using a single shared permutation.

    Args:
        Feature: a tensor, or a list/tuple of tensors. Each tensor is unpacked
            as (B, C, H, W), so it must be 4-D. The permutation is drawn once,
            from the first tensor, and reused for every other tensor.
        Mode: ``1`` permutes all H*W positions jointly. Any other value
            permutes the rows (H) and then the columns (W) independently.

    Returns:
        A list of shuffled tensors in input order. A single tensor still
        yields a one-element list. Indexing creates new tensors, so the
        inputs are not modified.

    Raises:
        ValueError: if a later tensor's spatial extent does not match the
            permutation drawn for the first one. For Mode 1 this means the
            same H*W; otherwise it means the same (H, W).
    """
    if isinstance(Feature, Tensor):
        Feature = [Feature]

    Indexs = None
    RefKey = None
    Output = []
    for f in Feature:
        B, C, H, W = f.shape
        # What the shared permutation has to fit: the number of positions in
        # Mode 1, or the row/column counts otherwise. Without this check a
        # larger later tensor would be silently cropped in Mode 2.
        Key = H * W if Mode == 1 else (H, W)
        if RefKey is None:
            RefKey = Key
        elif Key != RefKey:
            raise ValueError(
                'shuffleTensor: all tensors must share the spatial size of the '
                'first one (%s), got %s' % (RefKey, Key))

        if Mode == 1:
            # Fully shuffle the flattened spatial positions.
            f = f.flatten(2)
            if Indexs is None:
                Indexs = torch.randperm(f.shape[-1], device=f.device)
            f = f[:, :, Indexs.to(f.device)].reshape(B, C, H, W)
        else:
            # Shuffle along y, then along x.
            if Indexs is None:
                Indexs = [torch.randperm(H, device=f.device),
                          torch.randperm(W, device=f.device)]
            f = f[:, :, Indexs[0].to(f.device)]
            f = f[:, :, :, Indexs[1].to(f.device)]
        Output.append(f)
    return Output
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` with a default output size of 1 and a
    ``profileModule`` hook.

    ``output_size`` is an int or an (H, W) tuple. Per the PyTorch docs, a
    ``None`` entry keeps that input dimension. (The previous annotation
    ``int or tuple`` evaluated to just ``int``.)
    """

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]] = 1):
        super().__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        """Run the pooling and return ``(output, 0.0, 0.0)``.

        NOTE(review): the two zeros presumably report parameter count and
        MACs/FLOPs (pooling has no weights). Confirm against the consumer of
        ``profileModule``.
        """
        Output = self.forward(Input)
        return Output, 0.0, 0.0
class AdaptiveMaxPool2d(nn.AdaptiveMaxPool2d):
    """``nn.AdaptiveMaxPool2d`` with a default output size of 1 and a
    ``profileModule`` hook.

    ``output_size`` is an int or an (H, W) tuple. Per the PyTorch docs, a
    ``None`` entry keeps that input dimension. (The previous annotation
    ``int or tuple`` evaluated to just ``int``.)
    """

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]] = 1):
        super().__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        """Run the pooling and return ``(output, 0.0, 0.0)``.

        NOTE(review): the two zeros presumably report parameter count and
        MACs/FLOPs (pooling has no weights). Confirm against the consumer of
        ``profileModule``.
        """
        Output = self.forward(Input)
        return Output, 0.0, 0.0
class BaseConv2d(nn.Module):
    """``nn.Conv2d`` followed by an optional ``BatchNorm2d`` and an optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, groups, dilation:
            forwarded to ``nn.Conv2d``.
        padding: defaults to ``(kernel_size - 1) // 2 * dilation``. For odd
            kernels at stride 1 this keeps the spatial size unchanged.
        bias: defaults to ``not BNorm``, because BatchNorm's learned shift
            makes a conv bias redundant.
        BNorm: when True, append ``nn.BatchNorm2d(eps=0.001, momentum=Momentum)``.
            Otherwise an ``nn.Identity`` is used.
        ActLayer: activation factory (e.g. ``nn.ReLU``) or ``None``. The
            factory is called with ``inplace=True``, except for ``nn.Sigmoid``
            and for factories that reject the ``inplace`` keyword (e.g.
            ``nn.GELU``, ``nn.Tanh``, ``nn.PReLU``). Those are called with no
            arguments.
        Momentum: BatchNorm momentum.
        **kwargs: extra keyword arguments forwarded to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: Optional[int] = 1,
        padding: Optional[int] = None,
        groups: Optional[int] = 1,
        bias: Optional[bool] = None,
        BNorm: bool = False,
        ActLayer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
        Momentum: Optional[float] = 0.1,
        **kwargs: Any
    ) -> None:
        super(BaseConv2d, self).__init__()
        if padding is None:
            padding = int((kernel_size - 1) // 2 * dilation)
        if bias is None:
            bias = not BNorm

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias  # bool flag, not the bias tensor

        self.Conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, dilation, groups, bias, **kwargs)
        self.Bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=Momentum) if BNorm else nn.Identity()
        self.Act = self._makeAct(ActLayer)

        # NOTE: ``apply`` visits every submodule and ``initWeight`` also
        # recurses into composite modules, so the conv is initialised more
        # than once. Kept as-is so seeded initialisation stays unchanged.
        self.apply(initWeight)

    @staticmethod
    def _makeAct(ActLayer: Optional[Callable[..., nn.Module]]) -> Optional[nn.Module]:
        """Instantiate ``ActLayer``, preferring an in-place instance when supported."""
        if ActLayer is None:
            return None
        Probe = ActLayer()
        if isinstance(Probe, nn.Sigmoid):
            return Probe
        try:
            return ActLayer(inplace=True)
        except TypeError:
            # The factory has no ``inplace`` keyword (e.g. nn.GELU / nn.Tanh /
            # nn.PReLU). Use the plain instance instead of crashing.
            return Probe

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv, then norm (or identity), then the activation if one is set."""
        x = self.Conv(x)
        x = self.Bn(x)
        if self.Act is not None:
            x = self.Act(x)
        return x
# Normalization layer types whose affine parameters initWeight resets to the
# identity transform (weight = 1, bias = 0).
NormLayerTuple = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.SyncBatchNorm,
    nn.LayerNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.GroupNorm,
    nn.BatchNorm3d,
)


def initWeight(Module):
    """Re-initialise the parameters of ``Module`` in place. Returns ``None``.

    - ``Conv2d`` / ``Conv3d`` / ``ConvTranspose2d``: weight gets
      ``kaiming_uniform_(a=sqrt(5))``. A bias (if any) gets
      U(-1/sqrt(fan_in), 1/sqrt(fan_in)); it is left untouched when fan_in is 0.
    - Types in ``NormLayerTuple``: weight set to 1 and bias set to 0, when present.
    - ``Linear``: same as the conv case, except that a fan_in of 0 yields a
      zero bias.
    - ``Sequential`` / ``ModuleList``: recurse into each element.
    - Any other module with children: recurse into each direct child.
    ``None`` is accepted and ignored.
    """
    if Module is None:
        return

    if isinstance(Module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is None:
            return
        fanIn, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
        if fanIn == 0:
            return
        limit = 1 / math.sqrt(fanIn)
        nn.init.uniform_(Module.bias, -limit, limit)
        return

    if isinstance(Module, NormLayerTuple):
        if Module.weight is not None:
            nn.init.ones_(Module.weight)
        if Module.bias is not None:
            nn.init.zeros_(Module.bias)
        return

    if isinstance(Module, nn.Linear):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is None:
            return
        fanIn, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
        limit = 1 / math.sqrt(fanIn) if fanIn > 0 else 0
        nn.init.uniform_(Module.bias, -limit, limit)
        return

    # Containers are iterated directly. Anything else recurses through
    # children(). A leaf with no children gets no visits.
    if isinstance(Module, (nn.Sequential, nn.ModuleList)):
        subModules = list(Module)
    else:
        subModules = list(Module.children())
    for sub in subModules:
        initWeight(sub)
class Attention(nn.Module):
    """Channel attention with randomly sampled pooling contexts during training.

    ``forward(x)`` returns ``x * gate``. ``gate`` comes from a two-layer 1x1-conv
    bottleneck (``SELayer``: ``Act`` and then ``ScaleAct``) applied to a pooled
    descriptor produced by ``RandomSample``.

    Args:
        InChannels: number of channels of ``x``.
        HidChannels: bottleneck width. It defaults to
            ``max(makeDivisible(InChannels // SqueezeFactor, 8), 32)``.
        SqueezeFactor: channel reduction used for the default ``HidChannels``.
        PoolRes: pooling resolutions sampled in training. Defaults to
            ``[1, 2, 3]``. The value is copied, so each instance owns its list.
        Act: activation factory after the first 1x1 conv.
        ScaleAct: activation factory producing the gate.
        MoCOrder: when True, spatially shuffle ``x`` before pooling in training.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        InChannels: int,
        HidChannels: Optional[int] = None,
        SqueezeFactor: int = 4,
        PoolRes: Optional[list] = None,
        Act: Callable[..., nn.Module] = nn.ReLU,
        ScaleAct: Callable[..., nn.Module] = nn.Sigmoid,
        MoCOrder: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Build a fresh per-instance list rather than sharing one mutable
        # default list across every instance (and with the caller).
        PoolRes = [1, 2, 3] if PoolRes is None else list(PoolRes)
        if HidChannels is None:
            HidChannels = max(makeDivisible(InChannels // SqueezeFactor, 8), 32)

        # Pool1 (global average) is always built, because eval mode relies on it.
        AllPoolRes = PoolRes if 1 in PoolRes else PoolRes + [1]
        for k in AllPoolRes:
            setMethod(self, 'Pool%d' % k, AdaptiveAvgPool2d(k))

        self.SELayer = nn.Sequential(
            BaseConv2d(InChannels, HidChannels, 1, ActLayer=Act),
            BaseConv2d(HidChannels, InChannels, 1, ActLayer=ScaleAct),
        )
        self.PoolRes = PoolRes
        self.MoCOrder = MoCOrder

    def RandomSample(self, x: Tensor) -> Tensor:
        """Pool ``x`` into a per-channel descriptor.

        Eval: global average pooling via ``Pool1``.
        Training:
        - Draw a resolution ``k`` from ``self.PoolRes`` using numpy's global RNG.
        - If ``MoCOrder`` is set, spatially shuffle ``x`` first.
        - Adaptive-average-pool to k x k.
        - When k > 1, keep one randomly chosen cell. The same cell is used for
          the whole batch.

        Assumes ``x`` is (B, C, H, W), as ``shuffleTensor`` requires. The
        result is then (B, C, 1, 1).
        """
        if not self.training:
            return callMethod(self, 'Pool%d' % 1)(x)

        PoolKeep = np.random.choice(self.PoolRes)
        x1 = shuffleTensor(x)[0] if self.MoCOrder else x
        AttnMap: Tensor = callMethod(self, 'Pool%d' % PoolKeep)(x1)
        if AttnMap.shape[-1] > 1:
            AttnMap = AttnMap.flatten(2)
            AttnMap = AttnMap[:, :, torch.randperm(AttnMap.shape[-1])[0]]
            AttnMap = AttnMap[:, :, None, None]  # restore the two spatial dims
        return AttnMap

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` channel-wise by the gate computed from its pooled descriptor."""
        AttnMap = self.RandomSample(x)
        return x * self.SELayer(AttnMap)
def channel_shuffle(x, groups):
    """Interleave the channels (dim 1) of ``x`` across ``groups`` groups.

    The channel axis is viewed as ``(groups, C // groups)``, the two factors
    are swapped, and the result is flattened back. Channel ``j`` of group ``g``
    therefore lands at position ``j * groups + g``. Any trailing dims are
    carried through unchanged, so ``(N, C, H, W)`` and other ``(N, C, ...)``
    inputs work. An empty batch is also supported.

    Raises:
        ValueError: if ``groups`` is not positive or does not divide C.
    """
    batchsize, num_channels = x.shape[0], x.shape[1]
    trailing = x.shape[2:]
    if groups <= 0 or num_channels % groups != 0:
        raise ValueError(
            'channel_shuffle: groups=%d must be positive and divide the %d input channels'
            % (groups, num_channels))
    channels_per_group = num_channels // groups

    x = x.view(batchsize, groups, channels_per_group, *trailing)
    x = torch.transpose(x, 1, 2).contiguous()
    # Explicit channel count instead of -1, so a zero-size batch still reshapes.
    return x.view(batchsize, num_channels, *trailing)
class GLFA(nn.Module):
    """Multi-dilation context block with channel attention and a residual.

    Forward pass:
    1. Each branch ``conv_i`` applies 3x3 conv (dilation d_i, padding d_i), then
       BatchNorm, then ReLU. Padding equal to the dilation keeps the spatial size.
    2. The branch outputs are concatenated and channel-shuffled across branches.
    3. The result is fused back to ``in_channels`` with 1x1 conv, BatchNorm and ReLU.
    4. ``Attention`` is applied, and the input is added back.

    Args:
        in_channels: input and output channel count.
        dilations: one dilated branch per entry. The default ``(1, 2, 3, 4)``
            reproduces the original four branches.
        hid_channels: bottleneck width of the ``Attention`` gate (default 16).

    Raises:
        ValueError: if ``dilations`` is empty.
    """

    def __init__(self, in_channels, dilations: Tuple[int, ...] = (1, 2, 3, 4),
                 hid_channels: int = 16):
        super(GLFA, self).__init__()
        dilations = tuple(dilations)
        if not dilations:
            raise ValueError('GLFA: dilations must contain at least one entry')
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.dilations = dilations

        # Registered as conv_1, conv_2, ... in order, so state_dict keys and
        # parameter-creation order match the hand-written layout.
        for i, d in enumerate(dilations, start=1):
            setattr(self, 'conv_%d' % i, self._dilated_branch(in_channels, d))

        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * len(dilations), in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.mca = Attention(InChannels=in_channels, HidChannels=hid_channels)

    @staticmethod
    def _dilated_branch(channels, dilation):
        """Build a size-preserving 3x3 dilated conv, followed by BatchNorm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, padding=dilation, kernel_size=3, dilation=dilation),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        residual = x
        branches = [getattr(self, 'conv_%d' % i)(x)
                    for i in range(1, len(self.dilations) + 1)]
        cat = torch.cat(branches, dim=1)
        cat = channel_shuffle(cat, groups=len(self.dilations))
        fused = self.fuse(cat)
        out = self.mca(fused)
        return out + residual