|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
import random |
|
|
|
|
|
|
|
|
def _init_layer(layer):
    """Initialise one layer in place if its type is recognised; otherwise do nothing.

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weight (``fan_in``, ReLU gain)
      and zero bias (when a bias exists).
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight set to 1 and bias set to 0
      (each only when present).
    """
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers constructed with affine=False have weight = bias = None,
        # so the weight needs the same guard as the bias.
        if layer.weight is not None:
            nn.init.ones_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)


def weight_init(module):
    """Initialise the direct children of ``module`` in place.

    Every direct child's name is printed as ``'initialize: <name>'``. Conv2d,
    Linear, BatchNorm2d and GroupNorm children are initialised (see
    ``_init_layer``). For an ``nn.Sequential`` child, its own direct children
    are printed and initialised the same way (one level of nesting only).
    Children of any other type, including custom modules, are printed but
    left untouched.

    Args:
        module: an ``nn.Module`` whose children should be initialised.
    """
    for name, child in module.named_children():
        print('initialize: ' + name)
        if isinstance(child, nn.Sequential):
            for sub_name, sub_child in child.named_children():
                print('initialize: ' + sub_name)
                _init_layer(sub_child)
        else:
            _init_layer(child)
|
|
|
|
|
|
|
|
def init_seed(seed):
    """Seed every random number generator the training code may draw from.

    The same ``seed`` is passed to PyTorch's global generator, the current
    CUDA device's generator (silently ignored when CUDA is unavailable),
    Python's ``random`` module and NumPy's legacy global generator.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed)
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
|
|
|
|
def BCEDiceLoss(inputs, targets):
    """Return binary cross-entropy plus a (1 - Dice) overlap penalty.

    ``inputs`` must already be probabilities in [0, 1], as required by
    ``F.binary_cross_entropy``. The Dice coefficient is computed from sums
    over *all* elements of the tensors (not per sample), smoothed by 1e-5
    in numerator and denominator so empty masks give Dice = 1.
    """
    smooth = 1e-5
    bce_term = F.binary_cross_entropy(inputs, targets)
    overlap = torch.sum(inputs * targets)
    total_mass = torch.sum(inputs) + torch.sum(targets)
    dice_coeff = (2 * overlap + smooth) / (total_mass + smooth)
    return bce_term + 1 - dice_coeff
|
|
|
|
|
|
|
|
def BCE(inputs, targets):
    """Return the mean binary cross-entropy between ``inputs`` and ``targets``.

    Thin wrapper around ``F.binary_cross_entropy`` with its default
    (mean) reduction; ``inputs`` must be probabilities in [0, 1].
    """
    return F.binary_cross_entropy(inputs, targets)
|
|
|
|
|
|
|
|
def adjust_learning_rate(args, optimizer, epoch, iter, max_batches, lr_factor=1,
                         warmup_iters=200):
    """Compute the learning rate for this step and write it into ``optimizer``.

    Schedules, selected by ``args.lr_mode``:
      * ``'step'``: ``args.lr * 0.1 ** (epoch // args.step_loss)``
      * ``'poly'``: ``args.lr * (1 - iter / (max_batches * args.max_epochs)) ** 0.9``.
        The base is clamped at 0, so the lr bottoms out at 0 once ``iter`` passes
        the end of the schedule. Without the clamp, Python would produce a
        complex number.

    When ``epoch == 0`` and ``iter < warmup_iters``, the scheduled value is
    replaced by a linear warm-up. It ramps from about ``0.1 * args.lr`` up to
    ``args.lr`` at iteration ``warmup_iters - 1``. The result is then
    multiplied by ``lr_factor`` and assigned to every param group.

    Args:
        args: object exposing ``lr`` and ``lr_mode``, plus ``step_loss``
            ('step' mode) or ``max_epochs`` ('poly' mode).
        optimizer: object with a ``param_groups`` list of dicts, e.g. a
            ``torch.optim.Optimizer``.
        epoch: current epoch; warm-up only applies when it equals 0.
        iter: iteration counter. In 'poly' mode it is compared against
            ``max_batches * args.max_epochs`` — presumably a global count;
            confirm with the caller.
        max_batches: presumably the number of batches per epoch.
        lr_factor: multiplier applied after the schedule / warm-up.
        warmup_iters: warm-up length in iterations (default 200, the original
            hard-coded value); 0 disables warm-up.

    Returns:
        The learning rate that was set.

    Raises:
        ValueError: if ``args.lr_mode`` is neither ``'step'`` nor ``'poly'``.
    """
    if args.lr_mode == 'step':
        lr = args.lr * (0.1 ** (epoch // args.step_loss))
    elif args.lr_mode == 'poly':
        max_iter = max_batches * args.max_epochs
        # A negative base raised to 0.9 yields a complex number in Python;
        # clamp so training past max_iter keeps a (zero) real learning rate.
        remaining = max(0.0, 1 - iter * 1.0 / max_iter)
        lr = args.lr * remaining ** 0.9
    else:
        raise ValueError('Unknown lr mode {}'.format(args.lr_mode))

    if epoch == 0 and iter < warmup_iters:
        lr = args.lr * 0.9 * (iter + 1) / warmup_iters + 0.1 * args.lr

    lr *= lr_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
|
|