Kevin Li commited on
Commit
a13d3e2
·
verified ·
1 Parent(s): a872549

Upload folder using huggingface_hub

Browse files
__pycache__/models.cpython-310.pyc ADDED
Binary file (2 kB). View file
 
decoder.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73c42a41e9a9bf115ea236aa5dfe7690d07123270c362217d0adc4d18ce2a4e8
3
+ size 802465833
decoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc47dc8c28c257501ee1e843d76772fca1b803c25d196b122d65b549967c22c9
3
+ size 2406851328
decoder_earlystop_55.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db2cfe59b08c6b05178d213901e71c07e8d843318caaf198ccba1661283372f9
3
+ size 802465833
decoder_earlystop_82.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9656b7cb181a8a6cc3af9fe9381a5137fe6249e2e03cf27abfce79bc7062db9c
3
+ size 802465833
models.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Decoder(nn.Module):
5
+ def __init__(self, input_dim, hidden_dim, gamma=0.1):
6
+ super().__init__()
7
+ self.input_dim = input_dim
8
+ self.hidden_dim = hidden_dim
9
+ self.gamma = gamma
10
+ self.float()
11
+
12
+ #should be 512, 1024
13
+ self.fc = nn.Sequential(
14
+ nn.Linear(input_dim, hidden_dim),
15
+ nn.BatchNorm1d(hidden_dim),
16
+ nn.ReLU(),
17
+ nn.Linear(hidden_dim, hidden_dim * 2),
18
+ nn.BatchNorm1d(hidden_dim * 2),
19
+ nn.ReLU(),
20
+ nn.Linear(hidden_dim * 2, hidden_dim * 4),
21
+ nn.BatchNorm1d(hidden_dim * 4),
22
+ nn.ReLU(),
23
+ nn.Linear(hidden_dim * 4, hidden_dim * 8),
24
+ nn.BatchNorm1d(hidden_dim * 8),
25
+ nn.ReLU(),
26
+ nn.Linear(hidden_dim * 8, hidden_dim * 4 * 4),
27
+ nn.BatchNorm1d(hidden_dim * 4 * 4),
28
+ nn.ReLU()
29
+ )
30
+
31
+ self.decoder = nn.Sequential(
32
+ nn.ConvTranspose2d(1024, 768, kernel_size=4, stride=2, padding=1),
33
+ nn.BatchNorm2d(768),
34
+ nn.ReLU(),
35
+ nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=1),
36
+ nn.BatchNorm2d(512),
37
+ nn.ReLU(),
38
+ nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
39
+ nn.BatchNorm2d(256),
40
+ nn.ReLU(),
41
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
42
+ nn.BatchNorm2d(128),
43
+ nn.ReLU(),
44
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
45
+ nn.BatchNorm2d(64),
46
+ nn.ReLU(),
47
+ nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
48
+ nn.BatchNorm2d(32),
49
+ nn.ReLU(),
50
+ nn.Conv2d(32, 3, kernel_size=3, padding=1),
51
+ nn.Sigmoid()
52
+ )
53
+
54
+ def forward(self, z):
55
+ batch_size = z.shape[0]
56
+ # adding noise to inputs
57
+ gamma = 0.05
58
+ z = z + self.gamma * torch.randn_like(z)
59
+ z = self.fc(z)
60
+ z = z.view(batch_size, 1024, 4, 4)
61
+ return self.decoder(z)
62
+
63
+ def get_loss(self, emb, x):
64
+ x_hat = self.forward(emb)
65
+ l = nn.MSELoss(reduction="mean")
66
+ loss = l(x_hat, x)
67
+ return loss
68
+
69
+ @torch.no_grad()
70
+ def sample(self, samples, device):
71
+ samples = samples.to(device)
72
+ x_hat = self.forward(samples)
73
+
74
+ return x_hat