gaur3009 committed on
Commit
07f6f3d
·
verified ·
1 Parent(s): 35eba1d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -0
app.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import os
import threading
import time
import uuid
from urllib.parse import quote

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from bs4 import BeautifulSoup
from diffusers import DiffusionPipeline
from PIL import Image
from torch.utils.data import Dataset, DataLoader
15
+
16
+ # ======================
17
+ # Configuration
18
+ # ======================
19
# Central settings read by the scraper, the training loop and the model
# loader below.
CONFIG = {
    "scraping": {
        # "{query}" is substituted with the user's search text.
        "search_url": "https://www.pexels.com/search/{query}/",
        "headers": {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        },
        # Upper bound on the number of <img> tags processed per scrape.
        "max_images": 100,
        # Scraping time budget in seconds. 10 is a short value for testing;
        # a real 3-hour run would be 10800.
        # NOTE(review): nothing in the code shown here reads this key.
        "scrape_time": 10,
    },
    "training": {
        "batch_size": 4,
        "epochs": 10,
        "lr": 0.0002,
        # Length of the random noise vector fed to the generator.
        "latent_dim": 100,
        # Generated images are img_size x img_size RGB.
        "img_size": 64,
        "num_workers": 0,
    },
    "paths": {
        "dataset_dir": "scraped_data",
        "model_save": "text2img_model.pth",
    },
}
41
+
42
+ # ======================
43
+ # Web Scraping Module
44
+ # ======================
45
class WebScraper:
    """Download images for a search query and record text/image pairs.

    Attributes:
        stop_event: ``threading.Event`` checked before each download; when it
            is set the scrape loop exits early.
        scraped_data: list of ``{"text": query, "image": path}`` dicts, one
            per image that was saved successfully.
    """

    # Seconds to wait on a single HTTP request, so a stalled connection
    # cannot hang the worker thread indefinitely.
    REQUEST_TIMEOUT = 15

    def __init__(self):
        self.stop_event = threading.Event()
        self.scraped_data = []

    def scrape_images(self, query):
        """Fetch the search page for *query* and save the images it lists.

        Each saved image is appended to ``self.scraped_data`` paired with the
        query text. Errors are printed rather than raised: a failure to fetch
        or parse the search page ends the scrape, while a failure on a single
        image only skips that image.

        Args:
            query: search text substituted into the configured search URL.
        """
        scraping = CONFIG["scraping"]
        dataset_dir = CONFIG["paths"]["dataset_dir"]
        # Percent-encode the query so spaces and reserved characters cannot
        # corrupt the URL path.
        search_url = scraping["search_url"].format(query=quote(str(query), safe=""))

        try:
            response = requests.get(
                search_url,
                headers=scraping["headers"],
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            soup = BeautifulSoup(response.content, "html.parser")
            # Extract image URLs (example selector - needs adjustment for actual site)
            img_tags = soup.find_all("img", {"class": "photo-item__img"})
            os.makedirs(dataset_dir, exist_ok=True)
        except Exception as e:
            print(f"Scraping error: {e}")
            return

        for img in img_tags[: scraping["max_images"]]:
            if self.stop_event.is_set():
                break
            img_url = img.get("src")
            if not img_url:
                # A tag without a source must not abort the whole scrape.
                continue
            try:
                img_response = requests.get(img_url, timeout=self.REQUEST_TIMEOUT)
                # Do not store an HTTP error page as if it were a JPEG.
                img_response.raise_for_status()

                # A timestamp alone repeats for every image saved within the
                # same second, so files would overwrite each other; the
                # random suffix keeps names unique.
                img_name = f"{int(time.time())}_{uuid.uuid4().hex}.jpg"
                img_path = os.path.join(dataset_dir, img_name)
                with open(img_path, "wb") as f:
                    f.write(img_response.content)

                # Store text-image pair (text = query)
                self.scraped_data.append({"text": query, "image": img_path})
            except Exception as e:
                print(f"Error downloading image: {e}")

    def start_scraping(self, query):
        """Start ``scrape_images(query)`` on a background thread.

        Clears ``stop_event``, makes sure the dataset directory exists and
        returns immediately.

        Args:
            query: search text forwarded to ``scrape_images``.

        Returns:
            The status string ``"Scraping started..."``.
        """
        self.stop_event.clear()
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)

        # Daemon thread: an unfinished scrape must not keep the process alive
        # after the app shuts down.
        thread = threading.Thread(
            target=self.scrape_images, args=(query,), daemon=True
        )
        thread.start()
        return "Scraping started..."
85
+
86
+ # ======================
87
+ # Dataset and Models
88
+ # ======================
89
class TextImageDataset(Dataset):
    """Dataset over ``{"text": ..., "image": path}`` records.

    Args:
        data: indexable collection of dicts with a ``"text"`` value and an
            ``"image"`` file path.
        transform: optional callable applied to the loaded RGB PIL image.
            When omitted, items carry the PIL image itself.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load record *idx* and return ``{"text": ..., "image": ...}``."""
        item = self.data[idx]
        # Image.open keeps the file handle open until the image is closed;
        # convert() produces an independent RGB copy, so the source can be
        # closed as soon as the conversion is done.
        with Image.open(item["image"]) as source:
            image = source.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return {"text": item["text"], "image": image}
105
+
106
+ # Simplified Text-to-Image Generator
107
class TextConditionedGenerator(nn.Module):
    """Small fully connected generator conditioned on an integer text id.

    The text id is looked up in an embedding table, concatenated with a noise
    vector and mapped to a ``3 x img_size x img_size`` image through a
    ``Tanh`` output layer.

    Args:
        vocab_size: number of rows in the text embedding table.
        embed_dim: width of each text embedding.
        latent_dim: length of the noise vector; defaults to
            ``CONFIG["training"]["latent_dim"]``.
        img_size: side length of the generated image; defaults to
            ``CONFIG["training"]["img_size"]``.
    """

    def __init__(self, vocab_size=1000, embed_dim=128, latent_dim=None, img_size=None):
        super().__init__()
        # Fall back to the shared configuration only when the caller did not
        # supply explicit sizes.
        if latent_dim is None:
            latent_dim = CONFIG["training"]["latent_dim"]
        if img_size is None:
            img_size = CONFIG["training"]["img_size"]
        self.latent_dim = latent_dim
        self.img_size = img_size

        self.text_embedding = nn.Embedding(vocab_size, embed_dim)  # Simplified text embedding
        self.model = nn.Sequential(
            nn.Linear(embed_dim + latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            # In training mode BatchNorm1d needs more than one sample per
            # batch; call eval() before single-sample inference.
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 3 * img_size ** 2),
            nn.Tanh(),
        )

    def forward(self, text, noise):
        """Generate images from text ids and noise.

        Args:
            text: integer tensor of text ids, one per sample.
            noise: float tensor with ``latent_dim`` values per sample.

        Returns:
            Tensor viewed as ``(-1, 3, img_size, img_size)``.
        """
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], dim=1)
        flat = self.model(combined)
        return flat.view(-1, 3, self.img_size, self.img_size)
126
+
127
+ # ======================
128
+ # Training Utilities
129
+ # ======================
130
def _image_to_tensor(image):
    """Resize a PIL RGB image to the configured size and scale it to [-1, 1]."""
    img_size = CONFIG["training"]["img_size"]
    resized = image.resize((img_size, img_size))
    # [-1, 1] matches the range of the generator's Tanh output.
    pixels = np.asarray(resized, dtype=np.float32) / 127.5 - 1.0
    # PIL gives height x width x channels; the models use channels first.
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()


def train_model(scraper, progress=gr.Progress()):
    """Train the generator against a small discriminator on the scraped images.

    Args:
        scraper: object whose ``scraped_data`` attribute is a list of
            ``{"text": ..., "image": path}`` records.
        progress: Gradio progress tracker wrapped around the epoch loop.

    Returns:
        ``"Training completed!"`` after the generator weights are saved to
        ``CONFIG["paths"]["model_save"]``, or an explanatory message when no
        scraped data is available.
    """
    train_cfg = CONFIG["training"]
    latent_dim = train_cfg["latent_dim"]
    # Derived from the configured image size rather than a hard-coded 64.
    flat_dim = 3 * train_cfg["img_size"] ** 2

    # Copy the list: the scraper thread may still be appending to it.
    samples = list(scraper.scraped_data)
    if not samples:
        # A shuffling DataLoader cannot be built over an empty dataset.
        return "No scraped data available. Run scraping first."

    # The transform turns PIL images into equally sized tensors so that the
    # default collate function can stack them into a batch.
    dataset = TextImageDataset(samples, transform=_image_to_tensor)
    dataloader = DataLoader(
        dataset,
        batch_size=train_cfg["batch_size"],
        shuffle=True,
        num_workers=train_cfg["num_workers"],
    )

    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(flat_dim, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )

    optimizer_G = optim.Adam(generator.parameters(), lr=train_cfg["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=train_cfg["lr"])
    criterion = nn.BCELoss()

    for epoch in progress.tqdm(range(train_cfg["epochs"]), desc="Training"):
        for batch in dataloader:
            real_imgs = batch["image"]
            batch_size = real_imgs.size(0)
            if batch_size < 2:
                # The generator's BatchNorm1d layer raises in training mode
                # when given a single sample, so skip a leftover batch of one.
                continue

            # Train discriminator
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            noise = torch.randn(batch_size, latent_dim)
            # NOTE(review): random ids are used for conditioning; the scraped
            # text in batch["text"] is not used.
            text_ids = torch.randint(0, 1000, (batch_size,))
            fake_imgs = generator(text_ids, noise)

            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs.view(-1, flat_dim)), real_labels)
            # detach(): the discriminator step must not update the generator.
            fake_loss = criterion(
                discriminator(fake_imgs.detach().view(-1, flat_dim)), fake_labels
            )
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()
            validity = discriminator(fake_imgs.view(-1, flat_dim))
            g_loss = criterion(validity, torch.ones_like(validity))
            g_loss.backward()
            optimizer_G.step()

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return "Training completed!"
172
+
173
+ # ======================
174
+ # Inference Modules
175
+ # ======================
176
class ModelRunner:
    """Load the two image generators on first use and cache them.

    Attributes:
        pretrained_pipe: cached diffusion pipeline, or ``None`` until
            ``load_pretrained`` is first called.
        custom_model: cached ``TextConditionedGenerator``, or ``None`` until
            ``load_custom`` is first called.
    """

    def __init__(self):
        self.pretrained_pipe = None
        self.custom_model = None

    def load_pretrained(self):
        """Return the Stable Diffusion XL pipeline, loading it on first call."""
        if self.pretrained_pipe is None:
            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0"
            )
        return self.pretrained_pipe

    def load_custom(self):
        """Return the trained generator, loading its weights on first call.

        The weights are read from ``CONFIG["paths"]["model_save"]``.
        NOTE(review): the cached model is not refreshed if training is run
        again after the first load.
        """
        if self.custom_model is None:
            model = TextConditionedGenerator()
            # map_location lets weights saved on another device load on CPU.
            state = torch.load(CONFIG["paths"]["model_save"], map_location="cpu")
            model.load_state_dict(state)
            # Inference mode: BatchNorm1d then uses its running statistics
            # and accepts a batch containing a single sample.
            model.eval()
            self.custom_model = model
        return self.custom_model
194
+
195
+ # ======================
196
+ # Gradio Interface
197
+ # ======================
198
# Build the UI: scraping and training controls in the left column, image
# generation in the right column.
with gr.Blocks() as app:
    scraper = WebScraper()
    model_runner = ModelRunner()

    def _run_training(progress=gr.Progress()):
        """Train on the shared scraper's data, forwarding Gradio's progress."""
        return train_model(scraper, progress)

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(label="Search Query")
            scrape_btn = gr.Button("Start Scraping")
            scrape_status = gr.Textbox(label="Scraping Status")

            train_btn = gr.Button("Start Training")
            training_status = gr.Textbox(label="Training Status")

        with gr.Column():
            prompt_input = gr.Textbox(label="Generation Prompt")
            model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type")
            generate_btn = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")

    # Event Handlers
    scrape_btn.click(
        fn=scraper.start_scraping,
        inputs=query_input,
        outputs=scrape_status,
    )

    # The scraper is a plain Python object, not a Gradio component, so it
    # cannot be listed under `inputs`; the wrapper closes over it instead.
    train_btn.click(
        fn=_run_training,
        inputs=None,
        outputs=training_status,
    )

    # generate_image is resolved when the button is clicked, so it may be
    # defined later in the module.
    generate_btn.click(
        fn=lambda prompt, model_type: generate_image(prompt, model_type, model_runner),
        inputs=[prompt_input, model_choice],
        outputs=output_image,
    )
235
+
236
def generate_image(prompt, model_type, runner):
    """Generate one image with the selected model.

    Args:
        prompt: text prompt; passed to the pipeline when *model_type* is
            ``"Pretrained"`` and not used by the custom generator.
        model_type: ``"Pretrained"`` selects the diffusion pipeline; any
            other value selects the custom generator.
        runner: object providing ``load_pretrained()`` and ``load_custom()``.

    Returns:
        The first image produced by the pipeline, or a PIL image built from
        the custom generator's output.
    """
    if model_type == "Pretrained":
        pipe = runner.load_pretrained()
        # The pipeline result is returned as-is: it is already an image, and
        # sending it through the array conversion below would fail.
        return pipe(prompt).images[0]

    model = runner.load_custom()
    # eval() is required for a single sample: BatchNorm1d raises in training
    # mode when the batch holds only one item.
    model.eval()

    # Simplified generation process
    noise = torch.randn(1, CONFIG["training"]["latent_dim"])
    with torch.no_grad():
        fake = model(torch.randint(0, 1000, (1,)), noise)

    # Drop only the batch axis, then move channels last for PIL.
    image = fake.squeeze(0).permute(1, 2, 0).numpy()
    # Scale to [0,1]; clipping guards the uint8 cast against rounding drift.
    image = np.clip((image + 1) / 2, 0.0, 1.0)

    return Image.fromarray((image * 255).astype(np.uint8))
249
+
250
# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    app.launch()