InPeerReview committed on
Commit 2ff0f4b · verified · 1 Parent(s): 5243f5d

Upload 6 files

Files changed (6)
  1. model/decoder.py +309 -0
  2. model/encoder.py +391 -0
  3. model/metric_tool.py +131 -0
  4. model/resnet.py +213 -0
  5. model/trainer.py +30 -0
  6. model/utils.py +81 -0
model/decoder.py ADDED
@@ -0,0 +1,309 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from model.utils import weight_init


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(self, dim1, dim2, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim1 // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim1, dim1, bias=qkv_bias)
        self.kv = nn.Linear(dim2, dim1 * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim1, dim1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        B1, N1, C1 = x.shape
        B2, N2, C2 = y.shape

        q = self.q(x).reshape(B1, N1, self.num_heads, C1 // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(y).reshape(B2, N2, 2, self.num_heads, C1 // self.num_heads).permute(2, 0, 3, 1, 4)

        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B1, N1, C1)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):
    def __init__(self, dim1, dim2, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim1)
        self.norm2 = norm_layer(dim2)
        self.attn = CrossAttention(dim1, dim2, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm3 = norm_layer(dim1)
        mlp_hidden_dim = int(dim1 * mlp_ratio)
        self.mlp = Mlp(in_features=dim1, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, y):
        x = x + self.drop_path(self.attn(self.norm1(x), self.norm2(y)))
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x


class ContentAwareAggregation(nn.Module):
    def __init__(self, low_dim, high_dim):
        super().__init__()
        self.project = nn.Sequential(
            nn.Conv2d(high_dim, low_dim, kernel_size=1),
            nn.BatchNorm2d(low_dim),
            nn.ReLU(inplace=True)
        )

        self.attn_gen = nn.Sequential(
            nn.Conv2d(low_dim, low_dim, kernel_size=3, padding=1, groups=low_dim),
            nn.BatchNorm2d(low_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(low_dim, low_dim, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, low_feat, high_feat):
        high_feat = F.interpolate(high_feat, size=low_feat.shape[2:], mode='bilinear', align_corners=False)
        high_feat = self.project(high_feat)
        attn = self.attn_gen(low_feat + high_feat)
        out = attn * low_feat + high_feat
        return out


class FeatureInjector(nn.Module):
    def __init__(self, dim1=384, dim2=[64, 128, 256], num_heads=8, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
        super().__init__()

        self.c2_c5 = Block(dim1, dim2[0], num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)
        self.c3_c5 = Block(dim1, dim2[1], num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)
        self.c4_c5 = Block(dim1, dim2[2], num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)

        self.fuse = nn.Conv2d(dim1 * 3, dim1, 1, bias=False)
        self.caa = ContentAwareAggregation(dim1, dim1)

        weight_init(self)

    def base_forward(self, c2, c3, c4, c5):
        H, W = c5.shape[2:]

        c2 = rearrange(c2, 'b c h w -> b (h w) c')
        c3 = rearrange(c3, 'b c h w -> b (h w) c')
        c4 = rearrange(c4, 'b c h w -> b (h w) c')
        c5 = rearrange(c5, 'b c h w -> b (h w) c')

        _c2 = self.c2_c5(c5, c2)
        _c2 = rearrange(_c2, 'b (h w) c -> b c h w', h=H, w=W)

        _c3 = self.c3_c5(c5, c3)
        _c3 = rearrange(_c3, 'b (h w) c -> b c h w', h=H, w=W)

        _c4 = self.c4_c5(c5, c4)
        _c4 = rearrange(_c4, 'b (h w) c -> b c h w', h=H, w=W)

        _c5 = self.fuse(torch.cat([_c2, _c3, _c4], dim=1))

        return _c5

    def forward(self, fx, fy):
        _c5x = self.base_forward(fx[0], fx[1], fx[2], fx[3])
        _c5y = self.base_forward(fy[0], fy[1], fy[2], fy[3])

        _c5x = self.caa(_c5x, _c5y)
        _c5y = self.caa(_c5y, _c5x)

        return _c5x, _c5y


class DualAttentionGate(nn.Module):
    def __init__(self, channels, ratio=8):
        super().__init__()
        # Channel-attention branch
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                                # [B,C,1,1]
            nn.Conv2d(channels, channels // ratio, 1, bias=False),  # [B,C/8,1,1]
            nn.ReLU(),
            nn.Conv2d(channels // ratio, channels, 1, bias=False),  # [B,C,1,1]
            nn.Sigmoid()
        )

        # Spatial-attention branch
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3, bias=False),  # 2 input channels (mean + std)
            nn.Sigmoid()                                # output [B,1,H,W]
        )

    def forward(self, x):
        """
        Input:  x [B,C,H,W]
        Output: enhanced features [B,C,H,W]
        """
        # Channel attention
        c_att = self.channel_att(x)  # [B,C,1,1]

        # Spatial attention
        mean = torch.mean(x, dim=1, keepdim=True)                # [B,1,H,W]
        std = torch.std(x, dim=1, keepdim=True)                  # [B,1,H,W]
        s_att = self.spatial_att(torch.cat([mean, std], dim=1))  # [B,1,H,W]

        # Fuse the two attention maps
        return x * c_att * s_att  # element-wise product


class SimplifiedFGFM(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.flow_make = nn.Conv2d(out_channels * 2, 4, 3, padding=1, bias=False)
        self.dual_att = DualAttentionGate(out_channels)

    def flow_warp(self, input, flow, size):
        # Keep the original flow-warping implementation
        out_h, out_w = size
        n, c, h, w = input.size()

        norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device)
        grid = torch.meshgrid(
            torch.linspace(-1.0, 1.0, out_h),
            torch.linspace(-1.0, 1.0, out_w),
            indexing='ij'
        )
        grid = torch.stack((grid[1], grid[0]), 2).repeat(n, 1, 1, 1).type_as(input)
        grid = grid + flow.permute(0, 2, 3, 1) / norm

        return F.grid_sample(input, grid, align_corners=True)

    def forward(self, lowres_feature, highres_feature):
        # 1. Flow-based alignment
        l_feature = self.down(lowres_feature)
        l_feature_up = F.interpolate(l_feature, size=highres_feature.shape[2:], mode='bilinear', align_corners=True)

        flow = self.flow_make(torch.cat([l_feature_up, highres_feature], dim=1))
        flow_l, flow_h = flow[:, :2, :, :], flow[:, 2:, :, :]

        l_warp = self.flow_warp(l_feature, flow_l, highres_feature.shape[2:])
        h_warp = self.flow_warp(highres_feature, flow_h, highres_feature.shape[2:])

        # 2. Dual-attention fusion
        fused = self.dual_att(l_warp + h_warp)
        return fused


# Decoder module
class Decoder(nn.Module):
    def __init__(self, in_dim=[64, 128, 256, 384], decay=4, num_class=1):
        super().__init__()
        c2_channel, c3_channel, c4_channel, c5_channel = in_dim

        self.structure_enhance = FeatureInjector(dim1=c5_channel)

        # Use the improved SimplifiedFGFM modules instead of plain upsampling
        self.fgfm_c4 = SimplifiedFGFM(in_channels=c5_channel, out_channels=c4_channel)
        self.fgfm_c3 = SimplifiedFGFM(in_channels=c4_channel, out_channels=c3_channel)
        self.fgfm_c2 = SimplifiedFGFM(in_channels=c3_channel, out_channels=c2_channel)

        # Final classifier
        self.classfier = nn.Sequential(
            nn.ConvTranspose2d(c2_channel, c2_channel, kernel_size=4, stride=2, padding=1),
            nn.Conv2d(c2_channel, num_class, 3, 1, padding=1, bias=False)
        )

        # Per-level difference-modeling modules (MLPs)
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(dim * 3, dim // decay, 1, bias=False),
                nn.BatchNorm2d(dim // decay),
                nn.ReLU(),
                nn.Conv2d(dim // decay, dim // decay, 3, 1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(dim // decay, dim // decay, 3, 1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(dim // decay, dim, 3, 1, padding=1, bias=False)
            ) for dim in in_dim
        ])

    def difference_modeling(self, x, y, block):
        f = torch.cat([x, y, torch.abs(x - y)], dim=1)
        return block(f)

    def forward(self, fx, fy):
        c2x, c3x, c4x = fx[:-1]
        c2y, c3y, c4y = fy[:-1]

        # Fused high-level semantic features (c5)
        c5x, c5y = self.structure_enhance(fx, fy)

        # Difference modeling at each feature level
        c2 = self.difference_modeling(c2x, c2y, self.mlp[0])
        c3 = self.difference_modeling(c3x, c3y, self.mlp[1])
        c4 = self.difference_modeling(c4x, c4y, self.mlp[2])
        c5 = self.difference_modeling(c5x, c5y, self.mlp[3])

        # Flow-guided feature fusion with the improved FGFM
        c4f = self.fgfm_c4(c5, c4)
        c3f = self.fgfm_c3(c4f, c3)
        c2f = self.fgfm_c2(c3f, c2)

        # Output change mask
        pred = self.classfier(c2f)
        pred_mask = torch.sigmoid(pred)

        return pred_mask
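For reference, a minimal shape check of Decoder (an illustrative sketch, not part of the commit; it assumes 256×256 inputs so the c2–c5 feature maps arrive at strides 2/4/8/16 with the default in_dim=[64, 128, 256, 384]):

    # Illustrative only: random per-level features shaped like the encoder outputs.
    import torch
    from model.decoder import Decoder

    dec = Decoder(in_dim=[64, 128, 256, 384])
    fx = [torch.rand(1, 64, 128, 128), torch.rand(1, 128, 64, 64),
          torch.rand(1, 256, 32, 32), torch.rand(1, 384, 16, 16)]
    fy = [torch.rand(1, 64, 128, 128), torch.rand(1, 128, 64, 64),
          torch.rand(1, 256, 32, 32), torch.rand(1, 384, 16, 16)]
    mask = dec(fx, fy)
    print(mask.shape)  # expected: torch.Size([1, 1, 256, 256])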
model/encoder.py ADDED
@@ -0,0 +1,391 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from einops import rearrange

from model.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from model.resnet import resnet18


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks (int): split block sequence into block_chunks units for FSDP wrap
            num_register_tokens (int): number of extra cls tokens (so-called "registers")
            interpolate_antialias (bool): flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset (float): work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            print("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            print("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            print("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i: i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1]
        if npatch == N and w == h:
            return self.pos_embed
        patch_pos_embed = self.pos_embed.float()
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed.to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1: self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return x_norm

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


class Encoder(nn.Module):
    def __init__(self, model_type='small'):
        super().__init__()
        if model_type == 'tiny':
            self.vit = DinoVisionTransformer(
                img_size=256,
                patch_size=16,
                embed_dim=192,
                depth=12,
                num_heads=6,
                mlp_ratio=4,
                block_fn=partial(Block, attn_class=MemEffAttention),
                num_register_tokens=0
            )
            path = "checkpoint/deit_tiny_patch16_224-a1311bcf.pth"

        elif model_type == 'small':
            self.vit = DinoVisionTransformer(
                img_size=256,
                patch_size=16,
                embed_dim=384,
                depth=12,
                num_heads=6,
                mlp_ratio=4,
                block_fn=partial(Block, attn_class=MemEffAttention),
                num_register_tokens=0
            )
            path = "checkpoint/dinov2_vits14_pretrain.pth"

        else:
            assert False, r'Encoder: check the vit model type'

        state_dict = torch.load(path, map_location='cpu')['model'] \
            if model_type == 'tiny' else torch.load(path, map_location='cpu')

        for k in ['pos_embed', 'patch_embed.proj.weight']:
            del state_dict[k]
        msg = self.vit.load_state_dict(state_dict, strict=False)
        print(' missing_keys:{},\n unexpected_keys:{}'.format(msg.missing_keys, msg.unexpected_keys))
        print('model_type: {},\n checkpoint_path: {}'.format(model_type, path))

        self.resnet = resnet18(pretrained=True)
        self.drop = nn.Dropout(p=0.01)

        # Newly added feature-fusion module
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512 + 384, 384, kernel_size=1),  # assumes ViT embed_dim = 384
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
        )

    def detail_capture(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)

        x2 = self.drop(self.resnet.layer1(x))
        x3 = self.resnet.layer2(x2)
        x4 = self.resnet.layer3(x3)
        x5 = self.resnet.layer4(x4)
        return [x2, x3, x4, x5]

    def forward(self, x, y):

        v_x = self.vit(x)
        v_y = self.vit(y)

        v_x = rearrange(v_x, 'b (h w) c -> b c h w', h=16, w=16)
        v_y = rearrange(v_y, 'b (h w) c -> b c h w', h=16, w=16)

        c_x = self.detail_capture(x)
        c_y = self.detail_capture(y)

        fused_v_x = self.fusion_conv(torch.cat([c_x[-1], v_x], dim=1))
        fused_v_y = self.fusion_conv(torch.cat([c_y[-1], v_y], dim=1))
        return c_x[:-1] + [fused_v_x], c_y[:-1] + [fused_v_y]
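A hedged smoke test for Encoder (illustrative, not part of the commit; it assumes the checkpoints under checkpoint/ and the model/layers.py helpers referenced above are available):

    import torch
    from model.encoder import Encoder

    enc = Encoder(model_type='small').eval()
    t1 = torch.rand(1, 3, 256, 256)   # image at time 1
    t2 = torch.rand(1, 3, 256, 256)   # image at time 2
    with torch.no_grad():
        fx, fy = enc(t1, t2)
    print([tuple(f.shape) for f in fx])
    # roughly: [(1, 64, 128, 128), (1, 128, 64, 64), (1, 256, 32, 32), (1, 384, 16, 16)]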
model/metric_tool.py ADDED
@@ -0,0 +1,131 @@
import numpy as np


################### metrics ###################
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def initialize(self, val, weight):
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        if not self.initialized:
            self.initialize(val, weight)
        else:
            self.add(val, weight)

    def add(self, val, weight):
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        return self.val

    def average(self):
        return self.avg

    def get_scores(self):
        scores_dict = cm2score(self.sum)
        return scores_dict

    def clear(self):
        self.initialized = False


################### cm metrics ###################
class ConfuseMatrixMeter(AverageMeter):
    """Computes and stores the average and current value"""

    def __init__(self, n_class):
        super(ConfuseMatrixMeter, self).__init__()
        self.n_class = n_class

    def update_cm(self, pr, gt, weight=1):
        """Compute the confusion matrix for the current batch, update the running matrix, and return the batch F1 score."""
        val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
        self.update(val, weight)
        current_score = cm2F1(val)
        return current_score

    def get_scores(self):
        scores_dict = cm2score(self.sum)
        return scores_dict


def harmonic_mean(xs):
    harmonic_mean = len(xs) / sum((x + 1e-6) ** -1 for x in xs)
    return harmonic_mean


def cm2F1(confusion_matrix):
    hist = confusion_matrix
    tp = hist[1, 1]
    fn = hist[1, 0]
    fp = hist[0, 1]
    tn = hist[0, 0]
    # recall
    recall = tp / (tp + fn + np.finfo(np.float32).eps)
    # precision
    precision = tp / (tp + fp + np.finfo(np.float32).eps)
    # F1 score
    f1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
    return f1


def cm2score(confusion_matrix):
    hist = confusion_matrix
    tp = hist[1, 1]
    fn = hist[1, 0]
    fp = hist[0, 1]
    tn = hist[0, 0]
    # overall accuracy
    oa = (tp + tn) / (tp + fn + fp + tn + np.finfo(np.float32).eps)
    # recall
    recall = tp / (tp + fn + np.finfo(np.float32).eps)
    # precision
    precision = tp / (tp + fp + np.finfo(np.float32).eps)
    # F1 score
    f1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
    # IoU
    iou = tp / (tp + fp + fn + np.finfo(np.float32).eps)
    # expected agreement (for kappa)
    pre = ((tp + fn) * (tp + fp) + (tn + fp) * (tn + fn)) / (tp + fp + tn + fn) ** 2
    # kappa
    kappa = (oa - pre) / (1 - pre)
    score_dict = {'Kappa': kappa, 'IoU': iou, 'F1': f1, 'OA': oa, 'recall': recall, 'precision': precision, 'Pre': pre}
    return score_dict


def get_confuse_matrix(num_classes, label_gts, label_preds):
    """Compute the confusion matrix for a set of predictions."""

    def __fast_hist(label_gt, label_pred):
        """
        Collect values for Confusion Matrix
        For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
        :param label_gt: <np.array> ground-truth
        :param label_pred: <np.array> prediction
        :return: <np.ndarray> values for confusion matrix
        """
        mask = (label_gt >= 0) & (label_gt < num_classes)
        hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask],
                           minlength=num_classes ** 2).reshape(num_classes, num_classes)
        return hist

    confusion_matrix = np.zeros((num_classes, num_classes))
    for lt, lp in zip(label_gts, label_preds):
        confusion_matrix += __fast_hist(lt.flatten(), lp.flatten())
    return confusion_matrix
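Typical usage of ConfuseMatrixMeter for binary change detection (illustrative sketch with random labels):

    import numpy as np
    from model.metric_tool import ConfuseMatrixMeter

    meter = ConfuseMatrixMeter(n_class=2)
    gt = np.random.randint(0, 2, (4, 256, 256))      # ground-truth masks
    pred = np.random.randint(0, 2, (4, 256, 256))    # binarized predictions
    batch_f1 = meter.update_cm(pr=pred, gt=gt)       # F1 of this batch
    print(meter.get_scores())                        # cumulative Kappa / IoU / F1 / OA / recall / precision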
model/resnet.py ADDED
@@ -0,0 +1,213 @@
import torch.nn as nn
import math
import torch
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
from einops import rearrange


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model


if __name__ == '__main__':
    # Smoke test: this ResNet takes a single image tensor and returns classification logits.
    # (The previous test passed a vit_dim kwarg and a second forward argument, which this class does not accept.)
    m = resnet18(pretrained=False)
    x = torch.rand(1, 3, 224, 224)
    y = m(x)
    print(y.shape)  # torch.Size([1, 1000])
model/trainer.py ADDED
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn

from model.encoder import Encoder
from model.decoder import Decoder

from model.utils import weight_init


class Trainer(nn.Module):
    def __init__(self, model_type='small'):
        super().__init__()
        if model_type == 'tiny':
            embed_dim = 192
        elif model_type == 'small':
            embed_dim = 384
        else:
            assert False, r'Trainer: check the vit model type'

        self.encoder = Encoder(model_type)

        self.decoder = Decoder(in_dim=[64, 128, 256, embed_dim])
        weight_init(self.decoder)

    def forward(self, x, y):
        fx, fy = self.encoder(x, y)
        pred = self.decoder(fx, fy)

        return pred
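A hedged single training step with Trainer (illustrative sketch; it assumes 256×256 bi-temporal image pairs, binary change masks, and that the encoder checkpoints referenced above are present):

    import torch
    from model.trainer import Trainer
    from model.utils import BCEDiceLoss

    model = Trainer(model_type='small')
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    t1 = torch.rand(2, 3, 256, 256)                         # images at time 1
    t2 = torch.rand(2, 3, 256, 256)                         # images at time 2
    label = torch.randint(0, 2, (2, 1, 256, 256)).float()   # change masks

    pred = model(t1, t2)               # sigmoid probabilities, [2, 1, 256, 256]
    loss = BCEDiceLoss(pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()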
model/utils.py ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import random


def weight_init(module):
    for n, m in module.named_children():
        print('initialize: '+n)
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Sequential):
            for f, g in m.named_children():
                print('initialize: ' + f)
                if isinstance(g, nn.Conv2d):
                    nn.init.kaiming_normal_(g.weight, mode='fan_in', nonlinearity='relu')
                    if g.bias is not None:
                        nn.init.zeros_(g.bias)
                elif isinstance(g, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.ones_(g.weight)
                    if g.bias is not None:
                        nn.init.zeros_(g.bias)
                elif isinstance(g, nn.Linear):
                    nn.init.kaiming_normal_(g.weight, mode='fan_in', nonlinearity='relu')
                    if g.bias is not None:
                        nn.init.zeros_(g.bias)
        elif isinstance(m, nn.AdaptiveAvgPool2d) or isinstance(m, nn.AdaptiveMaxPool2d) or isinstance(m, nn.ModuleList) or isinstance(m, nn.BCELoss):
            pass
        else:
            pass


def init_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def BCEDiceLoss(inputs, targets):
    # print(inputs.shape, targets.shape)
    bce = F.binary_cross_entropy(inputs, targets)
    inter = (inputs * targets).sum()
    eps = 1e-5
    dice = (2 * inter + eps) / (inputs.sum() + targets.sum() + eps)
    # print(bce.item(), inter.item(), inputs.sum().item(), dice.item())
    return bce + 1 - dice


def BCE(inputs, targets):
    # print(inputs.shape, targets.shape)
    bce = F.binary_cross_entropy(inputs, targets)
    return bce


def adjust_learning_rate(args, optimizer, epoch, iter, max_batches, lr_factor=1):
    if args.lr_mode == 'step':
        lr = args.lr * (0.1 ** (epoch // args.step_loss))
    elif args.lr_mode == 'poly':
        cur_iter = iter
        max_iter = max_batches * args.max_epochs
        lr = args.lr * (1 - cur_iter * 1.0 / max_iter) ** 0.9
    else:
        raise ValueError('Unknown lr mode {}'.format(args.lr_mode))
    if epoch == 0 and iter < 200:
        lr = args.lr * 0.9 * (iter + 1) / 200 + 0.1 * args.lr  # warm_up
    lr *= lr_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
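An illustrative call to adjust_learning_rate with the poly schedule (the attribute names on args are assumptions inferred from the function body; iter is the global iteration index, i.e. epoch * max_batches + batch index):

    import argparse
    import torch
    import torch.nn as nn
    from model.utils import adjust_learning_rate

    args = argparse.Namespace(lr=1e-3, lr_mode='poly', step_loss=30, max_epochs=100)
    optimizer = torch.optim.SGD(nn.Linear(4, 2).parameters(), lr=args.lr)

    max_batches = 500                        # batches per epoch
    global_iter = 5 * max_batches + 10       # e.g. epoch 5, batch 10
    lr = adjust_learning_rate(args, optimizer, epoch=5, iter=global_iter, max_batches=max_batches)
    print(lr)                                # decayed learning rate written into the optimizer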