himanshu-skid19 zombie-596 commited on
Commit
7446b5a
·
1 Parent(s): d659f17

Update app.py (#4)

Browse files

- Update app.py (61c9a91583e8d3519a0717074031cfbb747073fc)


Co-authored-by: Saptarshi Mukherjee <zombie-596@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +118 -4
app.py CHANGED
@@ -1,5 +1,119 @@
1
- pip install transformers
2
- model_id = "himanshu-skid19/new_linear_model_1090.pt"
 
 
 
 
 
3
 
4
- tokenizer = AutoTokenizer.from_pretrained(model_id)
5
- model = AutoModel.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image, ImageOps
3
+ import torch
4
+ from matplotlib.image import imread
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ import math
8
 
9
+ class Block(nn.Module):
10
+ def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
11
+ super().__init__()
12
+ self.time_mlp = nn.Linear(time_emb_dim, out_ch)
13
+ if up:
14
+ self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
15
+ self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
16
+ self.Upsample = nn.Upsample(scale_factor = 2, mode ='bilinear')
17
+
18
+ else:
19
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
20
+ self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
21
+ self.maxpool = nn.MaxPool2d(4, 2, 1)
22
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
23
+ self.bnorm1 = nn.BatchNorm2d(out_ch)
24
+ self.bnorm2 = nn.BatchNorm2d(out_ch)
25
+ self.silu = nn.SiLU()
26
+ self.relu = nn.ReLU()
27
+
28
+ def forward(self, x, t, ):
29
+ # First Conv
30
+ h = (self.silu(self.bnorm1(self.conv1(x))))
31
+ # Time embedding
32
+ time_emb = self.relu(self.time_mlp(t))
33
+ # Extend last 2 dimensions
34
+ time_emb = time_emb[(..., ) + (None, ) * 2]
35
+ # Add time channel
36
+ h = h + time_emb
37
+ # Second Conv
38
+ h = (self.silu(self.bnorm2(self.conv2(h))))
39
+ # Down or Upsample
40
+ return self.transform(h)
41
+
42
+
43
+ class SinusoidalPositionEmbeddings(nn.Module):
44
+ def __init__(self, dim):
45
+ super().__init__()
46
+ self.dim = dim
47
+
48
+ def forward(self, time):
49
+ device = time.device
50
+ half_dim = self.dim // 2
51
+ embeddings = math.log(10000) / (half_dim - 1)
52
+ embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
53
+ embeddings = time[:, None] * embeddings[None, :]
54
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
55
+ # TODO: Double check the ordering here
56
+ return embeddings
57
+
58
+
59
+ class SimpleUnet(nn.Module):
60
+ """
61
+ A simplified variant of the Unet architecture.
62
+ """
63
+ def __init__(self):
64
+ super().__init__()
65
+ image_channels = 3
66
+ down_channels = (32, 64, 128, 256, 512)
67
+ up_channels = (512, 256, 128, 64, 32)
68
+ out_dim = 3
69
+ time_emb_dim = 32
70
+
71
+ # Time embedding
72
+ self.time_mlp = nn.Sequential(
73
+ SinusoidalPositionEmbeddings(time_emb_dim),
74
+ nn.Linear(time_emb_dim, time_emb_dim),
75
+ nn.ReLU()
76
+ )
77
+
78
+ # Initial projection
79
+ self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
80
+
81
+ # Downsample
82
+ self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
83
+ time_emb_dim) \
84
+ for i in range(len(down_channels)-1)])
85
+ # Upsample
86
+ self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
87
+ time_emb_dim, up=True) \
88
+ for i in range(len(up_channels)-1)])
89
+
90
+ # Edit: Corrected a bug found by Jakub C (see YouTube comment)
91
+ self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
92
+
93
+ def forward(self, x, timestep):
94
+ # Embedd time
95
+ t = self.time_mlp(timestep)
96
+ # Initial conv
97
+ x = self.conv0(x)
98
+ # Unet
99
+ residual_inputs = []
100
+ for down in self.downs:
101
+ x = down(x, t)
102
+ residual_inputs.append(x)
103
+ for up in self.ups:
104
+ residual_x = residual_inputs.pop()
105
+ # Add residual x as additional channels
106
+ x = torch.cat((x, residual_x), dim=1)
107
+ x = up(x, t)
108
+ return self.output(x)
109
+
110
+
111
+ model = SimpleUnet()
112
+
113
+ st.title("Generatig images using a diffusion model")
114
+ model.load_state_dict(torch.load("new_linear_model_1090.pt"))
115
+
116
+ result = st.button("Click to generate image")
117
+
118
+ if(result):
119
+ model()