yourusername committed on
Commit
e39e45c
1 Parent(s): 5fff95e

:construction: wip

Browse files
Files changed (3) hide show
  1. config.json +1 -0
  2. model.py +23 -8
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"input_dim": 784, "hidden_dims": [256, 64, 16, 4, 2]}
model.py CHANGED
@@ -1,8 +1,11 @@
 
1
  from pathlib import Path
2
- from typing import Tuple
 
3
  import torch
4
- from torch import nn
5
  import torch.nn.functional as F
 
 
6
 
7
 
8
  class Dense(nn.Module):
@@ -42,14 +45,21 @@ class Decoder(nn.Module):
42
  return self.layers(x)
43
 
44
 
45
- class Autoencoder(nn.Module):
 
 
 
 
46
 
47
- def __init__(self, input_dim: int = 784, hidden_dims: Tuple[int] = (256, 64, 16, 4, 2)):
 
 
48
  super().__init__()
49
- self.encoder = Encoder(input_dim, *hidden_dims)
50
- self.decoder = Decoder(input_dim, *reversed(hidden_dims))
51
- self.input_dim = input_dim
52
- self.hidden_dims = hidden_dims
 
53
 
54
  def forward(self, x):
55
  x = x.flatten(1)
@@ -58,6 +68,11 @@ class Autoencoder(nn.Module):
58
  loss = F.mse_loss(recon, x)
59
  return recon, latent, loss
60
 
 
 
 
 
 
61
 
62
  class MessageModel:
63
 
1
+ from dataclasses import dataclass
2
  from pathlib import Path
3
+ from typing import Union, List, Tuple
4
+
5
  import torch
 
6
  import torch.nn.functional as F
7
+ from huggingface_hub import ModelHubMixin
8
+ from torch import nn
9
 
10
 
11
  class Dense(nn.Module):
45
  return self.layers(x)
46
 
47
 
48
@dataclass
class AutoencoderConfig:
    """Configuration for the Autoencoder model.

    Attributes:
        input_dim: Size of the flattened input (784 = 28x28, e.g. MNIST).
        hidden_dims: Sizes of successive encoder layers; the decoder mirrors
            them in reverse. A tuple or list of ints (the original annotated
            these as ``str``, which contradicted the integer defaults and the
            ``Encoder(input_dim, *hidden_dims)`` call sites).
    """
    input_dim: int = 784
    # Variable-length tuple of ints (Tuple[int, ...]), or a list — e.g. when
    # round-tripped through config.json, JSON deserializes it as a list.
    hidden_dims: Union[Tuple[int, ...], List[int]] = (256, 64, 16, 4, 2)
52
+
53
 
54
+ class Autoencoder(nn.Module, ModelHubMixin):
55
+
56
+ def __init__(self, config: Union[dict, AutoencoderConfig] = AutoencoderConfig(), **kwargs):
57
  super().__init__()
58
+ self.config = AutoencoderConfig(**config) if isinstance(config, dict) else config
59
+ self.config.__dict__.update(**kwargs)
60
+
61
+ self.encoder = Encoder(self.config.input_dim, *self.config.hidden_dims)
62
+ self.decoder = Decoder(self.config.input_dim, *reversed(self.config.hidden_dims))
63
 
64
  def forward(self, x):
65
  x = x.flatten(1)
68
  loss = F.mse_loss(recon, x)
69
  return recon, latent, loss
70
 
71
+ def save_pretrained(self, save_directory, **kwargs):
72
+ assert 'config' not in kwargs, \
73
+ "save_pretrained handles passing model config for you, please dont pass it"
74
+ super().save_pretrained(save_directory, config=self.config.__dict__, **kwargs)
75
+
76
 
77
  class MessageModel:
78
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e550b61b21a7ffdae47ee0f22740aadf94053d00d129af9e7764a0e547b9f17a
3
+ size 1758947