import torch
import torch.nn as nn
from collections import Counter

class BeastTokenizer:
    """Word-level tokenizer that maps the most frequent words to integer ids."""

    def __init__(self, texts=None, vocab_size=5000):
        # Avoid a mutable default argument; ids 0 and 1 are reserved for padding and unknown words.
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        if texts:
            counter = Counter(word for text in texts for word in text.split())
            # Two vocabulary slots are already taken by the special tokens above.
            common = counter.most_common(vocab_size - 2)
            self.word2idx.update({word: idx + 2 for idx, (word, _) in enumerate(common)})

    def encode(self, text, max_len=100):
        # Unknown words map to <UNK> (1); truncate to max_len, then right-pad with <PAD> (0).
        tokens = [self.word2idx.get(word, 1) for word in text.split()]
        return tokens[:max_len] + [0] * (max_len - len(tokens))

class BeastSpamModel(nn.Module):
    """CNN + BiLSTM binary classifier over padded token id sequences."""

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=64):
        super().__init__()
        # padding_idx=0 keeps the <PAD> embedding fixed at zeros.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # kernel_size=5 with padding=2 preserves the sequence length.
        self.conv = nn.Conv1d(embed_dim, 128, kernel_size=5, padding=2)
        # Bidirectional LSTM, so downstream features have size hidden_dim * 2.
        self.lstm = nn.LSTM(128, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x)              # (batch, seq_len, embed_dim)
        x = x.permute(0, 2, 1)             # Conv1d expects (batch, channels, seq_len)
        x = self.conv(x).permute(0, 2, 1)  # back to (batch, seq_len, 128) for the LSTM
        lstm_out, _ = self.lstm(x)         # (batch, seq_len, hidden_dim * 2)
        out = self.fc(lstm_out[:, -1, :])  # classify from the final timestep
        return self.sigmoid(out).squeeze(1)  # (batch,) spam probabilities in [0, 1]
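
A minimal usage sketch under assumed inputs (the sample texts and message below are hypothetical, and no training loop is shown): build the vocabulary, encode a message into a fixed-length id sequence, and run a forward pass to get a spam probability.

```python
# Hypothetical texts used only to illustrate the API; a real corpus and training loop are needed.
texts = ["win a free prize now", "meeting rescheduled to friday"]
tokenizer = BeastTokenizer(texts, vocab_size=5000)
model = BeastSpamModel(vocab_size=len(tokenizer.word2idx))

# Encode one message to a fixed-length id sequence and add a batch dimension.
ids = torch.tensor([tokenizer.encode("win a free prize")], dtype=torch.long)

model.eval()
with torch.no_grad():
    prob = model(ids)  # shape (1,): spam probability for the single message
print(prob.item())
```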