HA2F / model /trainer.py
InPeerReview's picture
Upload 6 files
2ff0f4b verified
import torch
import torch.nn as nn
from model.encoder import Encoder
from model.decoder import Decoder
from model.utils import weight_init
class Trainer(nn.Module):
    """Combine an ``Encoder`` and a ``Decoder`` into one end-to-end module.

    ``forward`` passes two inputs to the encoder in a single call and decodes
    the resulting feature pair into a prediction.

    Args:
        model_type: Encoder size; must be a key of ``_EMBED_DIMS``
            (``'tiny'`` or ``'small'``). It is passed unchanged to ``Encoder``
            and also selects the last entry of the decoder's ``in_dim``.

    Raises:
        ValueError: If ``model_type`` is not a supported size.
    """

    # Encoder size -> embedding width used as the decoder's final ``in_dim``.
    # The error message below calls this a "vit" type, and 192/384 are the
    # standard ViT-Tiny/ViT-Small widths -- confirm against model/encoder.py.
    _EMBED_DIMS = {'tiny': 192, 'small': 384}

    def __init__(self, model_type='small'):
        super().__init__()
        embed_dim = self._EMBED_DIMS.get(model_type)
        if embed_dim is None:
            # Explicit raise instead of ``assert False``: asserts are stripped
            # under ``python -O``, which would let construction continue and
            # fail later with an opaque NameError on ``embed_dim``.
            raise ValueError(
                f'Trainer: check the vit model type (got {model_type!r}, '
                f'expected one of {sorted(self._EMBED_DIMS)})'
            )
        self.encoder = Encoder(model_type)
        # The first three widths are fixed; only the last follows the encoder
        # size. Presumably these are per-stage feature widths of what Encoder
        # returns -- TODO confirm against model/decoder.py.
        self.decoder = Decoder(in_dim=[64, 128, 256, embed_dim])
        # Only the decoder is re-initialised here; the encoder keeps whatever
        # weights its own constructor gives it (possibly pretrained -- verify).
        weight_init(self.decoder)

    def forward(self, x, y):
        """Encode ``x`` and ``y`` together and decode them into a prediction.

        Args:
            x: First input, passed to the encoder as-is.
            y: Second input, passed to the encoder as-is.

        Returns:
            Whatever the decoder returns for the encoder's ``(fx, fy)`` pair.
        """
        fx, fy = self.encoder(x, y)
        pred = self.decoder(fx, fy)
        return pred