Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # Copyright (c) 2022 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # Based on fairseq code bases | |
| # https://github.com/facebookresearch/fairseq | |
| # -------------------------------------------------------- | |
| import torch | |
class RelativePositionalEncoding(torch.nn.Module):
    """Learned embedding tables indexed by clipped relative positions.

    Relative offsets are clipped to ``[-maxlen, maxlen - 1]`` and shifted by
    ``maxlen`` so they index a ``torch.nn.Embedding`` table with
    ``2 * maxlen`` rows.

    Args:
        d_model: Embedding dimension of each table row.
        maxlen: Clipping radius for relative offsets; each table has
            ``2 * maxlen`` rows.
        embed_v: If True, also create a second table ``pe_v`` whose lookup
            is returned alongside the ``pe_k`` lookup.
    """

    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super().__init__()
        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        if embed_v:
            # Only created on request; forward() checks self.embed_v
            # before touching it.
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq, incremental_state=None):
        """Look up embeddings for a tensor of relative positions.

        Args:
            pos_seq: Integer tensor of relative offsets (presumably
                ``(tgt_len, src_len)`` -- confirm against callers). The
                tensor is not modified.
            incremental_state: Only compared against ``None``. When it is
                provided, only the last entry along dim 0 (``[-1:]``) is
                embedded.

        Returns:
            Tuple ``(k, v)``. ``k`` is ``pe_k(idx)`` with shape
            ``idx.shape + (d_model,)``. ``v`` is ``pe_v(idx)`` when
            ``embed_v`` is True, otherwise ``None``.
        """
        # Clamp out of place so the caller's pos_seq is left untouched.
        idx = torch.clamp(pos_seq, min=-self.maxlen, max=self.maxlen - 1)
        # Shift into [0, 2 * maxlen - 1], the valid row range of the tables.
        idx = idx + self.maxlen
        if incremental_state is not None:
            # Incremental decoding: keep only the newest step.
            idx = idx[-1:]
        if self.embed_v:
            return self.pe_k(idx), self.pe_v(idx)
        return self.pe_k(idx), None