lazerkat committed
Commit 1748b4f · verified · 1 Parent(s): 7531a96

Update app.py

Files changed (1): app.py +125 -52
app.py CHANGED
@@ -1,105 +1,178 @@
  import gradio as gr
- import json
  import os
  import urllib.request
- from pathlib import Path
-
  import torch
  from PIL import Image
  import numpy as np

  # Global variables
  model = None
- checkpoint = None
  device = None

- # Download and load the model
  def initialize_model():
-     global model, checkpoint, device

      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

      model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
      model_path = "newest.pth"

-     # Download if not already present
      if not os.path.exists(model_path):
-         gr.Info("Downloading model...")
          urllib.request.urlretrieve(model_url, model_path)

-     # Load checkpoint
      checkpoint = torch.load(model_path, map_location=device)

-     # Recreate the model architecture
-     from train import DiffusionUNet  # Import directly from training script
-     model = DiffusionUNet(vocab_size=checkpoint['vocab_size']).to(device)
      model.load_state_dict(checkpoint['model_state_dict'])
      model.eval()

-     return "Model loaded successfully!"

- # Generate image from prompt
- def generate_image(prompt):
-     global model, checkpoint, device

      if model is None:
-         return None, "Model not loaded yet. Please wait for initialization."
-
-     # Tokenize prompt using the saved vocab
-     vocab_data = checkpoint['word_to_idx']
-     max_len = 20
-     words = [w.strip('.,!?"\'') for w in prompt.lower().split()][:max_len]
-     indices = [vocab_data.get(w, 1) for w in words]
-     indices += [0] * (max_len - len(indices))
-     text_tokens = torch.tensor(indices).unsqueeze(0).to(device)

-     # Diffusion sampling
-     from train import Diffusion
-     diffusion = Diffusion(timesteps=500, device=device)

      with torch.no_grad():
-         generated = diffusion.sample(model, text_tokens, image_size=64, batch_size=1)

-     # Convert to PIL image
      image = generated.cpu().squeeze(0)
      image = (image + 1) / 2
      image = image.clamp(0, 1)
      image = image.permute(1, 2, 0).numpy()
      image = (image * 255).astype(np.uint8)
-     img = Image.fromarray(image)

-     return img, f"Generated image for: '{prompt}'"

- # Create the interface
- with gr.Blocks(title="RandomDiffusion", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# RandomDiffusion")
-     gr.Markdown("Text-to-Image Diffusion Model")

-     # Model status
-     status = gr.Textbox(label="Model Status", value="Initializing...", interactive=False)

-     # Image generation
      with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(label="Enter Prompt", placeholder="a beautiful landscape")
-             generate_btn = gr.Button("Generate")
-         with gr.Column():
-             output_image = gr.Image(label="Generated Image", type="pil")
-             result_text = gr.Textbox(label="Result")
-
-     # Load model on startup
      demo.load(
          lambda: initialize_model(),
-         inputs=[],
          outputs=[status]
      )

-     # Generate on button click
      generate_btn.click(
          generate_image,
-         inputs=[prompt],
-         outputs=[output_image, result_text]
      )

  if __name__ == "__main__":
-     demo.launch(share=True)
 
  import gradio as gr
  import os
  import urllib.request
  import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
  from PIL import Image
  import numpy as np

+ # ============================================================================
+ # DIFFUSION Model Architecture
+ # ============================================================================
+
+ class Diffusion:
+     def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
+         self.timesteps = timesteps
+         self.device = device
+         self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
+         self.alphas = 1 - self.betas
+         self.alpha_bars = torch.cumprod(self.alphas, dim=0)
+
+     @torch.no_grad()
+     def sample(self, model, x, steps=None):
+         model.eval()
+         if steps is None:
+             steps = self.timesteps
+
+         for t in reversed(range(steps)):
+             t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
+             predicted_noise = model(x, t_batch)
+
+             alpha = self.alphas[t]
+             alpha_bar = self.alpha_bars[t]
+             beta = self.betas[t]
+
+             if t > 0:
+                 noise = torch.randn_like(x)
+             else:
+                 noise = 0
+
+             x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
+             x = x + torch.sqrt(beta) * noise
+
+         model.train()
+         return x
+
+
+ class UNet(nn.Module):
+     def __init__(self, in_channels=3, out_channels=3):
+         super().__init__()
+
+         # Encoder
+         self.enc1 = self.conv_block(in_channels, 64)
+         self.enc2 = self.conv_block(64, 128)
+         self.enc3 = self.conv_block(128, 256)
+
+         # Bottleneck
+         self.bottleneck = self.conv_block(256, 512)
+
+         # Decoder
+         self.dec3 = self.conv_block(512 + 256, 256)
+         self.dec2 = self.conv_block(256 + 128, 128)
+         self.dec1 = self.conv_block(128 + 64, 64)
+
+         # Time embedding
+         self.time_embed = nn.Sequential(
+             nn.Linear(1, 128),
+             nn.ReLU(),
+             nn.Linear(128, 128)
+         )
+
+         self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+         self.final = nn.Conv2d(64, out_channels, 1)
+
+         self.pool = nn.MaxPool2d(2)
+
+     def conv_block(self, in_ch, out_ch):
+         return nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, 3, padding=1),
+             nn.BatchNorm2d(out_ch),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_ch, out_ch, 3, padding=1),
+             nn.BatchNorm2d(out_ch),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x, t):
+         # Time embedding
+         t_embed = self.time_embed(t.float().unsqueeze(-1))
+         t_embed = t_embed.unsqueeze(-1).unsqueeze(-1)
+
+         # Encoder
+         e1 = self.enc1(x)
+         e2 = self.enc2(self.pool(e1))
+         e3 = self.enc3(self.pool(e2))
+
+         # Bottleneck
+         b = self.bottleneck(self.pool(e3))
+         b = b + t_embed.repeat(1, 1, b.shape[2], b.shape[3]) if b.shape[1] == t_embed.shape[1] else b
+
+         # Decoder
+         d3 = self.dec3(torch.cat([self.up(b), e3], dim=1))
+         d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
+         d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))
+
+         return self.final(d1)
+
+
  # Global variables
  model = None
  device = None

+ # Download and load model
  def initialize_model():
+     global model, device

      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

      model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
      model_path = "newest.pth"

      if not os.path.exists(model_path):
          urllib.request.urlretrieve(model_url, model_path)

      checkpoint = torch.load(model_path, map_location=device)

+     model = UNet().to(device)
      model.load_state_dict(checkpoint['model_state_dict'])
      model.eval()

+     return "Model loaded successfully!"

+ # Generate image
+ def generate_image():
+     global model, device

      if model is None:
+         return None

+     diffusion = Diffusion(timesteps=1000, device=device)

      with torch.no_grad():
+         noise = torch.randn(1, 3, 64, 64).to(device)
+         generated = diffusion.sample(model, noise, steps=100)

+     # Convert to image
      image = generated.cpu().squeeze(0)
      image = (image + 1) / 2
      image = image.clamp(0, 1)
      image = image.permute(1, 2, 0).numpy()
      image = (image * 255).astype(np.uint8)

+     return Image.fromarray(image)

+ # Create interface
+ with gr.Blocks(title="RandomDiffusion") as demo:
+     gr.Markdown("# 🎨 RandomDiffusion")
+     gr.Markdown("Random image generation using diffusion")

+     status = gr.Textbox(label="Status", value="Loading model...", interactive=False)

      with gr.Row():
+         generate_btn = gr.Button("Generate Random Image", variant="primary")
+
+     output_image = gr.Image(label="Generated Image", type="pil")
+
      demo.load(
          lambda: initialize_model(),
          outputs=[status]
      )

      generate_btn.click(
          generate_image,
+         outputs=[output_image]
      )

  if __name__ == "__main__":
+     demo.launch()
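
The new app drops the text prompt entirely and samples unconditionally: initialize_model() builds the bundled UNet and loads newest.pth, and generate_image() runs the Diffusion.sample loop from 64x64 Gaussian noise. A minimal sketch of exercising that path outside the Gradio UI is below; it is not part of this commit and assumes the updated app.py is importable locally and newest.pth is reachable.

# Hypothetical usage sketch (not part of the commit): run the new
# unconditional generation path without launching the Gradio interface.
import app                      # the updated app.py from this commit

app.initialize_model()          # downloads newest.pth if missing, builds UNet, loads weights
img = app.generate_image()      # 100 reverse-diffusion steps from 1x3x64x64 noise
if img is not None:
    img.save("sample.png")      # PIL image, values scaled from [-1, 1] to [0, 255]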