HA2F / model /utils.py
InPeerReview's picture
Upload 6 files
2ff0f4b verified
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import random
def _init_layer(layer):
    """Initialise a single supported layer in place.

    Conv2d and Linear weights get Kaiming-normal initialisation
    (fan_in, relu); BatchNorm2d and GroupNorm weights are set to one.
    Biases, when present, are set to zero.

    Returns:
        True if `layer` was one of the handled types, False otherwise.
    """
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        return True
    if isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers built with affine=False carry weight=None and bias=None,
        # so both parameters have to be checked before touching them.
        if layer.weight is not None:
            nn.init.ones_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        return True
    return False


def weight_init(module):
    """Initialise the direct children of `module` in place.

    Each child's name is printed as 'initialize: <name>'. Conv2d,
    BatchNorm2d, GroupNorm and Linear children are initialised directly.
    For an nn.Sequential child, its immediate members are printed and
    initialised the same way (one level only, no deeper recursion).
    Every other child type is left untouched.

    Args:
        module: an nn.Module whose `named_children()` are visited.
    """
    for name, child in module.named_children():
        print('initialize: '+name)
        if _init_layer(child):
            continue
        if isinstance(child, nn.Sequential):
            for sub_name, sub_child in child.named_children():
                print('initialize: ' + sub_name)
                _init_layer(sub_child)
def init_seed(seed):
    """Seed the PyTorch (CPU and CUDA), Python `random` and NumPy RNGs.

    Args:
        seed: value passed unchanged to every seeding function.
    """
    # Same call order as before: torch CPU, torch CUDA, random, numpy.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def BCEDiceLoss(inputs, targets):
    """Return binary cross-entropy plus (1 - Dice coefficient).

    The Dice coefficient is computed from sums over all elements of
    `inputs` and `targets`, smoothed by a small epsilon in numerator and
    denominator.

    Args:
        inputs: predicted values, passed to F.binary_cross_entropy
            (which expects probabilities in [0, 1]).
        targets: ground-truth tensor compatible with `inputs`.

    Returns:
        Scalar tensor: bce + 1 - dice.
    """
    cross_entropy = F.binary_cross_entropy(inputs, targets)

    smooth = 1e-5
    overlap = (inputs * targets).sum()
    numerator = 2 * overlap + smooth
    denominator = inputs.sum() + targets.sum() + smooth
    dice_score = numerator / denominator

    return cross_entropy + 1 - dice_score
def BCE(inputs, targets):
    """Return F.binary_cross_entropy(inputs, targets) with its defaults.

    Args:
        inputs: predicted values (F.binary_cross_entropy expects
            probabilities in [0, 1]).
        targets: ground-truth tensor compatible with `inputs`.
    """
    return F.binary_cross_entropy(inputs, targets)
def adjust_learning_rate(args, optimizer, epoch, iter, max_batches, lr_factor=1):
    """Compute the learning rate for this step and write it to the optimizer.

    Args:
        args: object providing `lr` (base rate), `lr_mode` ('step' or
            'poly'), `step_loss` (epochs per 10x decay, 'step' mode) and
            `max_epochs` ('poly' mode).
        optimizer: object exposing `param_groups`, a list of dicts; the
            'lr' entry of every group is overwritten.
        epoch: current epoch index; warm-up applies only when it is 0.
        iter: current iteration counter, used for 'poly' decay and
            warm-up. (The name shadows the builtin but is kept because it
            is part of the caller-visible signature.)
        max_batches: iterations per epoch; `max_batches * args.max_epochs`
            is the total used for 'poly' decay.
        lr_factor: multiplier applied to the final rate.

    Returns:
        The learning rate that was written to every param group.

    Raises:
        ValueError: if `args.lr_mode` is neither 'step' nor 'poly'.
    """
    if args.lr_mode == 'step':
        lr = args.lr * (0.1 ** (epoch // args.step_loss))
    elif args.lr_mode == 'poly':
        max_iter = max_batches * args.max_epochs
        # Clamp at 0: once iter exceeds max_iter the base goes negative, and
        # a negative float raised to 0.9 yields a complex number in Python.
        remaining = max(0.0, 1 - iter * 1.0 / max_iter)
        lr = args.lr * remaining ** 0.9
    else:
        raise ValueError('Unknown lr mode {}'.format(args.lr_mode))

    # Warm-up: for the first 200 iterations of epoch 0, override the schedule
    # with a linear ramp that reaches args.lr at iter == 199.
    if epoch == 0 and iter < 200:
        lr = args.lr * 0.9 * (iter + 1) / 200 + 0.1 * args.lr

    lr *= lr_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr