import torch
import torch.nn as nn

from model.encoder import Encoder
from model.decoder import Decoder
from model.utils import weight_init

class Trainer(nn.Module):
    def __init__(self, model_type='small'):
        super().__init__()
        # Embedding dimension of the ViT backbone selected by model_type.
        if model_type == 'tiny':
            embed_dim = 192
        elif model_type == 'small':
            embed_dim = 384
        else:
            raise ValueError(f"Trainer: unsupported ViT model type '{model_type}'")
        self.encoder = Encoder(model_type)
        # Decoder input channel sizes; the deepest stage matches the ViT embedding width.
        self.decoder = Decoder(in_dim=[64, 128, 256, embed_dim])
        weight_init(self.decoder)

    def forward(self, x, y):
        # Encode both inputs jointly, then decode their features into the prediction.
        fx, fy = self.encoder(x, y)
        pred = self.decoder(fx, fy)
        return pred
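

# A minimal usage sketch, not part of the original module: it assumes the Encoder
# accepts two image batches of shape (B, 3, H, W) and that the Decoder returns a
# single prediction tensor. The input resolution (224x224) is an assumption; adjust
# it to whatever model.encoder actually expects.
if __name__ == '__main__':
    trainer = Trainer(model_type='small')
    x = torch.randn(2, 3, 224, 224)  # assumed RGB input batch
    y = torch.randn(2, 3, 224, 224)  # assumed paired RGB input batch
    with torch.no_grad():
        pred = trainer(x, y)
    print(pred.shape)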