lazerkat committed on
Commit
3192df2
·
verified ·
1 Parent(s): 1748b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -78
app.py CHANGED
@@ -6,11 +6,137 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  from PIL import Image
8
  import numpy as np
 
9
 
10
  # ============================================================================
11
- # DIFFUSION Model Architecture
12
  # ============================================================================
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class Diffusion:
15
  def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
16
  self.timesteps = timesteps
@@ -20,14 +146,16 @@ class Diffusion:
20
  self.alpha_bars = torch.cumprod(self.alphas, dim=0)
21
 
22
  @torch.no_grad()
23
- def sample(self, model, x, steps=None):
24
  model.eval()
25
  if steps is None:
26
  steps = self.timesteps
27
 
 
 
28
  for t in reversed(range(steps)):
29
  t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
30
- predicted_noise = model(x, t_batch)
31
 
32
  alpha = self.alphas[t]
33
  alpha_bar = self.alpha_bars[t]
@@ -45,103 +173,77 @@ class Diffusion:
45
  return x
46
 
47
 
48
- class UNet(nn.Module):
49
- def __init__(self, in_channels=3, out_channels=3):
50
- super().__init__()
51
-
52
- # Encoder
53
- self.enc1 = self.conv_block(in_channels, 64)
54
- self.enc2 = self.conv_block(64, 128)
55
- self.enc3 = self.conv_block(128, 256)
56
-
57
- # Bottleneck
58
- self.bottleneck = self.conv_block(256, 512)
59
-
60
- # Decoder
61
- self.dec3 = self.conv_block(512 + 256, 256)
62
- self.dec2 = self.conv_block(256 + 128, 128)
63
- self.dec1 = self.conv_block(128 + 64, 64)
64
-
65
- # Time embedding
66
- self.time_embed = nn.Sequential(
67
- nn.Linear(1, 128),
68
- nn.ReLU(),
69
- nn.Linear(128, 128)
70
- )
71
-
72
- self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
73
- self.final = nn.Conv2d(64, out_channels, 1)
74
-
75
- self.pool = nn.MaxPool2d(2)
76
-
77
- def conv_block(self, in_ch, out_ch):
78
- return nn.Sequential(
79
- nn.Conv2d(in_ch, out_ch, 3, padding=1),
80
- nn.BatchNorm2d(out_ch),
81
- nn.ReLU(inplace=True),
82
- nn.Conv2d(out_ch, out_ch, 3, padding=1),
83
- nn.BatchNorm2d(out_ch),
84
- nn.ReLU(inplace=True)
85
- )
86
-
87
- def forward(self, x, t):
88
- # Time embedding
89
- t_embed = self.time_embed(t.float().unsqueeze(-1))
90
- t_embed = t_embed.unsqueeze(-1).unsqueeze(-1)
91
-
92
- # Encoder
93
- e1 = self.enc1(x)
94
- e2 = self.enc2(self.pool(e1))
95
- e3 = self.enc3(self.pool(e2))
96
-
97
- # Bottleneck
98
- b = self.bottleneck(self.pool(e3))
99
- b = b + t_embed.repeat(1, 1, b.shape[2], b.shape[3]) if b.shape[1] == t_embed.shape[1] else b
100
-
101
- # Decoder
102
- d3 = self.dec3(torch.cat([self.up(b), e3], dim=1))
103
- d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
104
- d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))
105
-
106
- return self.final(d1)
107
-
108
-
109
  # Global variables
110
  model = None
111
  device = None
 
 
 
 
 
 
 
 
 
 
112
 
113
  # Download and load model
114
  def initialize_model():
115
- global model, device
116
 
117
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
118
 
 
119
  model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
120
  model_path = "newest.pth"
121
 
122
- if not os.path.exists(model_path):
123
- urllib.request.urlretrieve(model_url, model_path)
124
 
 
125
  checkpoint = torch.load(model_path, map_location=device)
126
 
127
- model = UNet().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  model.load_state_dict(checkpoint['model_state_dict'])
129
  model.eval()
130
 
131
- return "Model loaded successfully!"
 
 
 
 
 
 
 
 
 
 
132
 
133
  # Generate image
134
- def generate_image():
135
- global model, device
136
 
137
- if model is None:
138
  return None
139
 
140
- diffusion = Diffusion(timesteps=1000, device=device)
141
 
142
  with torch.no_grad():
143
- noise = torch.randn(1, 3, 64, 64).to(device)
144
- generated = diffusion.sample(model, noise, steps=100)
145
 
146
  # Convert to image
147
  image = generated.cpu().squeeze(0)
@@ -153,24 +255,34 @@ def generate_image():
153
  return Image.fromarray(image)
154
 
155
  # Create interface
156
- with gr.Blocks(title="RandomDiffusion") as demo:
157
  gr.Markdown("# 🎨 RandomDiffusion")
158
- gr.Markdown("Random image generation using diffusion")
159
 
160
  status = gr.Textbox(label="Status", value="Loading model...", interactive=False)
161
 
162
  with gr.Row():
163
- generate_btn = gr.Button("Generate Random Image", variant="primary")
 
 
 
 
 
 
 
164
 
165
  output_image = gr.Image(label="Generated Image", type="pil")
166
 
 
167
  demo.load(
168
  lambda: initialize_model(),
169
  outputs=[status]
170
  )
171
 
 
172
  generate_btn.click(
173
  generate_image,
 
174
  outputs=[output_image]
175
  )
176
 
 
6
  import torch.nn.functional as F
7
  from PIL import Image
8
  import numpy as np
9
+ import json
10
 
11
  # ============================================================================
12
+ # DIFFUSION Model Architecture (from your training code)
13
  # ============================================================================
14
 
15
class TextEncoder(nn.Module):
    """Encode padded token-id sequences into one vector per sequence.

    Token ids are embedded (index 0 is the padding index), passed through a
    single-layer bidirectional LSTM, and the final hidden states of the two
    directions are concatenated and linearly projected to ``hidden_dim``.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size per direction and size of the output.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence the doubled input width.
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        """Return a ``(batch, hidden_dim)`` encoding for ``(batch, seq)`` token ids."""
        _, (final_states, _) = self.lstm(self.embedding(x))
        # The last two entries of the final hidden state are the forward and
        # backward directions respectively.
        forward_state = final_states[-2]
        backward_state = final_states[-1]
        return self.fc(torch.cat((forward_state, backward_state), dim=1))
29
+
30
+
31
class DownBlock(nn.Module):
    """Encoder stage: two 3x3 conv + batch-norm layers, conditioned and pooled.

    After the first conv/norm, per-channel offsets derived from the time and
    text embeddings are added to the feature map. The stage returns both the
    full-resolution features (for the skip connection) and a 2x max-pooled
    copy (for the next stage).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        time_emb_dim: Width of the time embedding fed to ``time_mlp``.
        text_emb_dim: Width of the text embedding fed to ``text_mlp``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.time_mlp = self._conditioning_mlp(time_emb_dim, out_channels)
        self.text_mlp = self._conditioning_mlp(text_emb_dim, out_channels)
        self.pool = nn.MaxPool2d(2)

    @staticmethod
    def _conditioning_mlp(emb_dim, channels):
        """Build the Linear -> SiLU -> Linear projection from an embedding to channels."""
        return nn.Sequential(
            nn.Linear(emb_dim, channels),
            nn.SiLU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x, t_emb, text_emb):
        """Return ``(features, pooled_features)`` for the given input and embeddings."""
        # Trailing singleton axes let the per-channel offsets broadcast over
        # the spatial dimensions.
        time_offset = self.time_mlp(t_emb)[..., None, None]
        text_offset = self.text_mlp(text_emb)[..., None, None]

        features = self.norm1(self.conv1(x))
        features = F.relu(features + time_offset + text_offset)
        features = F.relu(self.norm2(self.conv2(features)))
        return features, self.pool(features)
59
+
60
+
61
class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, skip concat, two conditioned convs.

    The input is upsampled, concatenated with the skip tensor along the
    channel axis, and passed through two 3x3 conv + batch-norm layers. As in
    the encoder, per-channel offsets from the time and text embeddings are
    added after the first conv/norm.

    Args:
        in_channels: Channels of the tensor being upsampled.
        skip_channels: Channels of the skip-connection tensor.
        out_channels: Channels produced by both convolutions.
        time_emb_dim: Width of the time embedding fed to ``time_mlp``.
        text_emb_dim: Width of the text embedding fed to ``text_mlp``.
    """

    def __init__(self, in_channels, skip_channels, out_channels, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        merged_channels = in_channels + skip_channels
        self.conv1 = nn.Conv2d(merged_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.time_mlp = self._conditioning_mlp(time_emb_dim, out_channels)
        self.text_mlp = self._conditioning_mlp(text_emb_dim, out_channels)

    @staticmethod
    def _conditioning_mlp(emb_dim, channels):
        """Build the Linear -> SiLU -> Linear projection from an embedding to channels."""
        return nn.Sequential(
            nn.Linear(emb_dim, channels),
            nn.SiLU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x, skip, t_emb, text_emb):
        """Return the decoded feature map at the skip tensor's resolution."""
        merged = torch.cat((self.up(x), skip), dim=1)

        # Trailing singleton axes let the per-channel offsets broadcast over
        # the spatial dimensions.
        time_offset = self.time_mlp(t_emb)[..., None, None]
        text_offset = self.text_mlp(text_emb)[..., None, None]

        features = self.norm1(self.conv1(merged))
        features = F.relu(features + time_offset + text_offset)
        return F.relu(self.norm2(self.conv2(features)))
90
+
91
+
92
class DiffusionUNet(nn.Module):
    """Text- and time-conditioned U-Net that predicts noise for an image batch.

    Layout: initial conv -> two ``DownBlock`` stages -> conditioned
    bottleneck -> two ``UpBlock`` stages with skip connections -> 1x1 output
    conv. The text prompt is encoded by ``TextEncoder`` and the timestep by a
    small MLP; both embeddings are passed to every stage.

    Because the encoder pools by 2 twice and the decoder upsamples by 2
    twice, the spatial size of the input should be divisible by 4 so that
    the skip tensors line up when concatenated.

    Args:
        vocab_size: Vocabulary size for the text encoder's embedding table.
        image_channels: Channels of the input and output images.
        base_channels: Channel width of the first stage.
        time_emb_dim: Width of the time embedding.
        text_emb_dim: Width of the text embedding.
    """

    def __init__(self, vocab_size, image_channels=3, base_channels=64, time_emb_dim=256, text_emb_dim=512):
        super().__init__()
        wide_channels = base_channels * 2

        self.text_encoder = TextEncoder(vocab_size, embed_dim=256, hidden_dim=text_emb_dim)
        # The raw timestep is fed as a single scalar feature per sample.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        self.init_conv = nn.Conv2d(image_channels, base_channels, 3, padding=1)
        self.down1 = DownBlock(base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.down2 = DownBlock(base_channels, wide_channels, time_emb_dim, text_emb_dim)

        self.bottleneck_conv1 = nn.Conv2d(wide_channels, wide_channels, 3, padding=1)
        self.bottleneck_conv2 = nn.Conv2d(wide_channels, wide_channels, 3, padding=1)
        self.bottleneck_norm1 = nn.BatchNorm2d(wide_channels)
        self.bottleneck_norm2 = nn.BatchNorm2d(wide_channels)
        self.bottleneck_time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, wide_channels),
            nn.SiLU(),
            nn.Linear(wide_channels, wide_channels),
        )
        self.bottleneck_text_mlp = nn.Sequential(
            nn.Linear(text_emb_dim, wide_channels),
            nn.SiLU(),
            nn.Linear(wide_channels, wide_channels),
        )

        self.up1 = UpBlock(wide_channels, wide_channels, base_channels, time_emb_dim, text_emb_dim)
        self.up2 = UpBlock(base_channels, base_channels, base_channels, time_emb_dim, text_emb_dim)
        self.out_conv = nn.Conv2d(base_channels, image_channels, 1)

    def forward(self, x, timesteps, text_tokens):
        """Run the U-Net on images ``x`` at ``timesteps`` conditioned on ``text_tokens``.

        Returns a tensor with ``image_channels`` channels at the input's
        spatial resolution.
        """
        text_emb = self.text_encoder(text_tokens)
        t_emb = self.time_mlp(timesteps.unsqueeze(-1).float())

        # Encoder: keep the un-pooled features of each stage as skips.
        skip_outer, pooled_outer = self.down1(self.init_conv(x), t_emb, text_emb)
        skip_inner, pooled_inner = self.down2(pooled_outer, t_emb, text_emb)

        # Bottleneck, conditioned the same way as the down/up stages.
        time_offset = self.bottleneck_time_mlp(t_emb)[..., None, None]
        text_offset = self.bottleneck_text_mlp(text_emb)[..., None, None]
        middle = self.bottleneck_norm1(self.bottleneck_conv1(pooled_inner))
        middle = F.relu(middle + time_offset + text_offset)
        middle = F.relu(self.bottleneck_norm2(self.bottleneck_conv2(middle)))

        # Decoder: consume the skips innermost first.
        decoded = self.up1(middle, skip_inner, t_emb, text_emb)
        decoded = self.up2(decoded, skip_outer, t_emb, text_emb)
        return self.out_conv(decoded)
138
+
139
+
140
  class Diffusion:
141
  def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
142
  self.timesteps = timesteps
 
146
  self.alpha_bars = torch.cumprod(self.alphas, dim=0)
147
 
148
  @torch.no_grad()
149
+ def sample(self, model, text_tokens, image_size=64, steps=None):
150
  model.eval()
151
  if steps is None:
152
  steps = self.timesteps
153
 
154
+ x = torch.randn(1, 3, image_size, image_size).to(self.device)
155
+
156
  for t in reversed(range(steps)):
157
  t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
158
+ predicted_noise = model(x, t_batch, text_tokens)
159
 
160
  alpha = self.alphas[t]
161
  alpha_bar = self.alpha_bars[t]
 
173
  return x
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  # Global variables
177
  model = None
178
  device = None
179
+ vocab_data = None
180
+
181
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless that file already exists.

    Status messages are printed before and after the transfer; there is no
    incremental progress reporting.

    The data is first written to ``<filename>.part`` and only renamed to
    ``filename`` once the transfer has finished. This way an interrupted
    download never leaves a truncated file that a later run would mistake
    for a complete one.

    Args:
        url: Source URL, opened with ``urllib.request.urlretrieve``.
        filename: Destination path on the local filesystem.

    Raises:
        OSError: If the transfer or the final rename fails (this includes
            ``urllib.error.URLError``). No file is left at ``filename`` or
            at the temporary path in that case.
    """
    if os.path.exists(filename):
        print(f"{filename} already exists")
        return

    print(f"Downloading {filename}...")
    partial_path = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # os.replace is atomic on the same filesystem, so `filename` only
        # ever appears fully written.
        os.replace(partial_path, filename)
    finally:
        # After a successful rename the partial file is gone; anything still
        # here is the leftover of a failed transfer.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Downloaded {filename}")
189
 
190
  # Download and load model
191
def initialize_model():
    """Download the checkpoint, build the model, and publish it to the globals.

    Selects CUDA when available (otherwise CPU), fetches ``newest.pth`` if it
    is not already on disk, reads the vocabulary entries and the model
    weights from the checkpoint, and stores the results in the module-level
    ``model``, ``device`` and ``vocab_data``.

    ``model`` and ``vocab_data`` are assigned only after the weights have
    loaded successfully, so a failed load never leaves a randomly
    initialised network visible to ``generate_image``.

    Returns:
        A status message for the UI.

    Raises:
        Whatever the download, ``torch.load`` or ``load_state_dict`` raise;
        ``KeyError`` if the checkpoint lacks one of the expected entries.
    """
    global model, device, vocab_data

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Download model and vocab (both live in the same checkpoint file)
    model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
    model_path = "newest.pth"
    download_file(model_url, model_path)

    # NOTE(security): torch.load unpickles the file. The checkpoint is
    # fetched over the network, so only point model_url at a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Get vocab info from checkpoint
    loaded_vocab = {
        'vocab': checkpoint['vocab'],
        'word_to_idx': checkpoint['word_to_idx'],
        'vocab_size': checkpoint['vocab_size'],
    }

    # Create model with correct vocab size and load its weights into a local
    # first; the globals are untouched if anything below raises.
    network = DiffusionUNet(
        vocab_size=loaded_vocab['vocab_size'],
        image_channels=3,
        base_channels=64,
    ).to(device)
    network.load_state_dict(checkpoint['model_state_dict'])
    network.eval()

    vocab_data = loaded_vocab
    model = network

    print(f"Model loaded successfully! Vocab size: {vocab_data['vocab_size']}")
    return "✅ Model loaded successfully! You can now generate images."
225
+
226
def tokenize_text(text, max_len=20):
    """Convert a prompt into a ``(1, max_len)`` tensor of vocabulary indices.

    The text is lower-cased and split on whitespace; at most ``max_len``
    words are kept and each has leading/trailing ``. , ! ? " '`` stripped.
    Words missing from ``vocab_data['word_to_idx']`` map to the ``<UNK>``
    index (or 1 if the vocabulary has no ``<UNK>`` entry). Short prompts are
    right-padded with 0, the PAD token. The tensor is moved to the global
    ``device``.
    """
    def index_of(word):
        table = vocab_data['word_to_idx']
        return table.get(word, table.get('<UNK>', 1))

    kept_words = text.lower().split()[:max_len]
    indices = [index_of(word.strip('.,!?"\'')) for word in kept_words]

    # Right-pad with the PAD token (0) up to the fixed length.
    indices.extend([0] * (max_len - len(indices)))

    return torch.tensor(indices).unsqueeze(0).to(device)
234
 
235
  # Generate image
236
+ def generate_image(prompt):
237
+ global model, device, vocab_data
238
 
239
+ if model is None or vocab_data is None:
240
  return None
241
 
242
+ diffusion = Diffusion(timesteps=500, device=device) # Use 500 timesteps like training
243
 
244
  with torch.no_grad():
245
+ text_tokens = tokenize_text(prompt)
246
+ generated = diffusion.sample(model, text_tokens, image_size=64, steps=500)
247
 
248
  # Convert to image
249
  image = generated.cpu().squeeze(0)
 
255
  return Image.fromarray(image)
256
 
257
  # Create interface
258
+ with gr.Blocks(title="RandomDiffusion Text-to-Image") as demo:
259
  gr.Markdown("# 🎨 RandomDiffusion")
260
+ gr.Markdown("Text-to-Image generation using diffusion model")
261
 
262
  status = gr.Textbox(label="Status", value="Loading model...", interactive=False)
263
 
264
  with gr.Row():
265
+ prompt_input = gr.Textbox(
266
+ label="Prompt",
267
+ value="a beautiful landscape",
268
+ placeholder="Enter your text prompt here..."
269
+ )
270
+
271
+ with gr.Row():
272
+ generate_btn = gr.Button("Generate Image", variant="primary")
273
 
274
  output_image = gr.Image(label="Generated Image", type="pil")
275
 
276
+ # Load model on startup
277
  demo.load(
278
  lambda: initialize_model(),
279
  outputs=[status]
280
  )
281
 
282
+ # Generate on button click
283
  generate_btn.click(
284
  generate_image,
285
+ inputs=[prompt_input],
286
  outputs=[output_image]
287
  )
288