Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from diffusers import DiffusionPipeline
|
7 |
+
import requests
|
8 |
+
from bs4 import BeautifulSoup
|
9 |
+
import os
|
10 |
+
import time
|
11 |
+
import threading
|
12 |
+
from PIL import Image
|
13 |
+
import io
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
# ======================
|
17 |
+
# Configuration
|
18 |
+
# ======================
|
19 |
+
# Central settings for the app, grouped by pipeline stage.
CONFIG = dict(
    # Image scraping.
    scraping={
        # "{query}" is filled in with the user's search term.
        "search_url": "https://www.pexels.com/search/{query}/",
        "headers": {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        },
        # Upper bound on images taken from one results page.
        "max_images": 100,
        # Scrape time budget in seconds. The original note describes this as
        # a 3-hour budget "simulated for testing"; 10 is the shortened test
        # value, not 3 hours.
        "scrape_time": 10,
    },
    # GAN training hyper-parameters.
    training=dict(
        batch_size=4,
        epochs=10,
        lr=0.0002,
        latent_dim=100,  # length of the generator's noise vector
        img_size=64,     # images are img_size x img_size RGB
        num_workers=0,
    ),
    # Where scraped images and the trained generator weights are stored.
    paths=dict(
        dataset_dir="scraped_data",
        model_save="text2img_model.pth",
    ),
)
|
41 |
+
|
42 |
+
# ======================
|
43 |
+
# Web Scraping Module
|
44 |
+
# ======================
|
45 |
+
class WebScraper:
    """Download images for a search query on a background thread.

    Attributes:
        stop_event: Setting this event makes a running scrape stop before
            its next download.
        scraped_data: List of ``{"text": query, "image": path}`` records,
            one per image saved to disk.
    """

    # Seconds to wait on a single HTTP request before giving up, so a
    # stalled connection cannot hang the scraping thread forever.
    REQUEST_TIMEOUT = 15

    def __init__(self):
        self.stop_event = threading.Event()
        self.scraped_data = []

    @staticmethod
    def _image_name(index):
        """Return a file name that is unique for each download.

        A whole-second timestamp alone collides when several images are
        saved within the same second, so the nanosecond clock and the
        position in the result list are combined.
        """
        return f"{time.time_ns()}_{index}.jpg"

    def scrape_images(self, query):
        """Fetch the search page for ``query`` and save the images it lists.

        Each saved image is recorded in ``self.scraped_data`` paired with
        the query text. Errors are printed rather than raised because this
        method runs as a thread target.

        Args:
            query: Search term inserted into the configured search URL.
        """
        from urllib.parse import quote

        scraping = CONFIG["scraping"]
        dataset_dir = CONFIG["paths"]["dataset_dir"]
        try:
            # Percent-encode the term so spaces, slashes, etc. cannot
            # change the shape of the URL.
            search_url = scraping["search_url"].format(
                query=quote(query.strip(), safe="")
            )
            response = requests.get(
                search_url,
                headers=scraping["headers"],
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')

            # Make sure the target directory exists even when this method
            # is called directly rather than through start_scraping().
            os.makedirs(dataset_dir, exist_ok=True)

            # Extract image URLs (example selector - needs adjustment for actual site)
            img_tags = soup.find_all('img', {'class': 'photo-item__img'})
            for index, img in enumerate(img_tags[:scraping["max_images"]]):
                if self.stop_event.is_set():
                    break

                img_url = img.get('src')
                if not img_url:
                    # A tag without a source should not abort the whole run.
                    continue

                try:
                    img_response = requests.get(
                        img_url, timeout=self.REQUEST_TIMEOUT
                    )
                    # Do not save an HTTP error page as if it were a JPEG.
                    img_response.raise_for_status()

                    img_path = os.path.join(dataset_dir, self._image_name(index))
                    with open(img_path, 'wb') as f:
                        f.write(img_response.content)

                    # Store text-image pair (text = query)
                    self.scraped_data.append({"text": query, "image": img_path})
                except (requests.RequestException, OSError) as e:
                    print(f"Error downloading image: {e}")
        except Exception as e:
            # Thread boundary: report the failure instead of letting the
            # thread die with an unhandled traceback.
            print(f"Scraping error: {e}")

    def start_scraping(self, query):
        """Start ``scrape_images`` on a new thread and return a status string.

        Args:
            query: Search term entered by the user.

        Returns:
            A short message for the status box.
        """
        if not query or not query.strip():
            return "Please enter a search query."

        self.stop_event.clear()
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)

        thread = threading.Thread(target=self.scrape_images, args=(query,))
        thread.start()
        return "Scraping started..."
|
85 |
+
|
86 |
+
# ======================
|
87 |
+
# Dataset and Models
|
88 |
+
# ======================
|
89 |
+
class TextImageDataset(Dataset):
    """Dataset over ``{"text": ..., "image": path}`` records.

    Args:
        data: Indexable collection of records; each record's ``"image"``
            entry is opened with PIL when the item is requested.
        transform: Optional callable applied to the RGB image before it
            is returned.
    """

    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load record ``idx`` and return its text with the RGB image."""
        record = self.data[idx]
        picture = Image.open(record["image"]).convert("RGB")
        # Without a (truthy) transform the PIL image is passed through.
        picture = self.transform(picture) if self.transform else picture
        return dict(text=record["text"], image=picture)
|
105 |
+
|
106 |
+
# Simplified Text-to-Image Generator
|
107 |
+
class TextConditionedGenerator(nn.Module):
    """Fully connected generator: token id + noise vector -> RGB image.

    The token id is embedded, concatenated with the noise vector, and
    passed through an MLP ending in ``Tanh``, so outputs lie in [-1, 1].

    Args:
        vocab_size: Number of distinct token ids the embedding accepts.
        embed_dim: Size of the text embedding vector.
        latent_dim: Size of the noise vector. Defaults to
            ``CONFIG["training"]["latent_dim"]``.
        img_size: Side length of the square output image. Defaults to
            ``CONFIG["training"]["img_size"]``.
    """

    def __init__(self, vocab_size=1000, embed_dim=128, latent_dim=None, img_size=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = CONFIG["training"]["latent_dim"]
        if img_size is None:
            img_size = CONFIG["training"]["img_size"]

        # Kept so forward() reshapes with the same size the last layer was
        # built for, even if CONFIG is edited after construction.
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.img_size = img_size

        self.text_embedding = nn.Embedding(vocab_size, embed_dim)  # Simplified text embedding
        self.model = nn.Sequential(
            nn.Linear(embed_dim + latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 3 * img_size ** 2),
            nn.Tanh(),
        )

    def forward(self, text, noise):
        """Generate a batch of images.

        Args:
            text: Integer tensor of token ids, one per sample.
            noise: Float tensor with one ``latent_dim``-sized row per sample.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)``.
        """
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], dim=1)
        img = self.model(combined)
        return img.view(-1, 3, self.img_size, self.img_size)
|
126 |
+
|
127 |
+
# ======================
|
128 |
+
# Training Utilities
|
129 |
+
# ======================
|
130 |
+
def _image_to_tensor(image, size):
    """Resize a PIL RGB image to size x size and return a CHW float tensor in [-1, 1]."""
    resized = image.resize((size, size))
    pixels = np.asarray(resized, dtype=np.float32) / 127.5 - 1.0
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()


def train_model(scraper, progress=gr.Progress()):
    """Train the generator against a small discriminator on the scraped images.

    Args:
        scraper: Object whose ``scraped_data`` attribute is a list of
            ``{"text": ..., "image": path}`` records.
        progress: Gradio progress tracker; its ``tqdm`` wraps the epoch loop.

    Returns:
        A status string for the UI. The generator's weights are written to
        ``CONFIG["paths"]["model_save"]`` when at least one training step ran.
    """
    from functools import partial

    training = CONFIG["training"]
    img_size = training["img_size"]
    flat_dim = 3 * img_size ** 2

    # Snapshot the list: the scraper thread may still be appending to it.
    records = list(scraper.scraped_data)
    if not records:
        return "No scraped images available. Run scraping first."

    # The DataLoader's default collation cannot stack PIL images, so each
    # image is converted to a fixed-size tensor in the generator's [-1, 1]
    # output range.
    dataset = TextImageDataset(
        records, transform=partial(_image_to_tensor, size=img_size)
    )
    dataloader = DataLoader(
        dataset,
        batch_size=training["batch_size"],
        shuffle=True,
        num_workers=training["num_workers"],
    )

    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(flat_dim, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )

    optimizer_G = optim.Adam(generator.parameters(), lr=training["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=training["lr"])
    criterion = nn.BCELoss()

    # Token ids must stay inside the generator's embedding table.
    vocab_size = generator.text_embedding.num_embeddings
    steps = 0

    for epoch in progress.tqdm(range(training["epochs"]), desc="Training"):
        for batch in dataloader:
            real_imgs = batch["image"]
            n = real_imgs.size(0)
            if n < 2:
                # BatchNorm1d in training mode cannot compute batch
                # statistics from a single sample; skip a leftover batch.
                continue

            # Train discriminator
            real_labels = torch.ones(n, 1)
            fake_labels = torch.zeros(n, 1)

            noise = torch.randn(n, training["latent_dim"])
            # NOTE: token ids are random, as in the original code; the
            # batch's "text" field is not used for conditioning.
            fake_imgs = generator(torch.randint(0, vocab_size, (n,)), noise)

            optimizer_D.zero_grad()
            real_loss = criterion(
                discriminator(real_imgs.reshape(-1, flat_dim)), real_labels
            )
            fake_loss = criterion(
                discriminator(fake_imgs.detach().reshape(-1, flat_dim)), fake_labels
            )
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()
            validity = discriminator(fake_imgs.reshape(-1, flat_dim))
            g_loss = criterion(validity, torch.ones_like(validity))
            g_loss.backward()
            optimizer_G.step()
            steps += 1

    if steps == 0:
        return "Not enough images to train (need at least 2)."

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return "Training completed!"
|
172 |
+
|
173 |
+
# ======================
|
174 |
+
# Inference Modules
|
175 |
+
# ======================
|
176 |
+
class ModelRunner:
    """Lazily load and cache the two image generators used for inference.

    Attributes:
        pretrained_pipe: Cached diffusion pipeline, or ``None`` until first use.
        custom_model: Cached locally trained generator, or ``None`` until first use.
    """

    def __init__(self):
        self.pretrained_pipe = None
        self.custom_model = None

    def load_pretrained(self):
        """Return the pretrained pipeline, loading it on the first call."""
        if self.pretrained_pipe is None:
            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0"
            )
        return self.pretrained_pipe

    def load_custom(self):
        """Return the locally trained generator, loading it on the first call.

        The weights are read from ``CONFIG["paths"]["model_save"]``.
        NOTE: the model is cached, so weights saved by a later training run
        are not picked up until ``custom_model`` is reset to ``None``.
        """
        if self.custom_model is None:
            model = TextConditionedGenerator()
            # map_location keeps loading working on a CPU-only host even if
            # the checkpoint was written from another device.
            state = torch.load(CONFIG["paths"]["model_save"], map_location="cpu")
            model.load_state_dict(state)
            # Inference mode: BatchNorm must use its running statistics,
            # otherwise a single-sample batch raises an error.
            model.eval()
            self.custom_model = model
        return self.custom_model
|
194 |
+
|
195 |
+
# ======================
|
196 |
+
# Gradio Interface
|
197 |
+
# ======================
|
198 |
+
# Build the UI: the left column scrapes and trains, the right column generates.
with gr.Blocks() as app:
    scraper = WebScraper()
    model_runner = ModelRunner()

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(label="Search Query")
            scrape_btn = gr.Button("Start Scraping")
            scrape_status = gr.Textbox(label="Scraping Status")

            train_btn = gr.Button("Start Training")
            training_status = gr.Textbox(label="Training Status")

        with gr.Column():
            prompt_input = gr.Textbox(label="Generation Prompt")
            model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type")
            generate_btn = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")

    def _run_training(progress=gr.Progress()):
        """Train on the shared scraper's data.

        ``inputs`` of an event listener must be Gradio components, so the
        scraper is captured by closure instead of being passed as an input.
        The ``gr.Progress()`` default is kept on this wrapper so Gradio can
        inject the progress tracker.
        """
        return train_model(scraper, progress)

    # Event Handlers
    scrape_btn.click(
        fn=scraper.start_scraping,
        inputs=query_input,
        outputs=scrape_status,
    )

    train_btn.click(
        fn=_run_training,
        inputs=None,
        outputs=training_status,
    )

    # generate_image is looked up when the button is clicked, by which time
    # the module-level function defined further down exists.
    generate_btn.click(
        fn=lambda prompt, model_type: generate_image(prompt, model_type, model_runner),
        inputs=[prompt_input, model_choice],
        outputs=output_image,
    )
|
235 |
+
|
236 |
+
def generate_image(prompt, model_type, runner):
    """Generate one image with the selected model.

    Args:
        prompt: Text prompt. Only the pretrained pipeline uses it.
        model_type: ``"Pretrained"`` selects the diffusion pipeline; any
            other value selects the locally trained generator.
        runner: Object providing ``load_pretrained()`` and ``load_custom()``.

    Returns:
        The first image produced by the pipeline, or a PIL image built from
        the custom generator's output.
    """
    if model_type == "Pretrained":
        pipe = runner.load_pretrained()
        # The pipeline's image is returned as-is; it must not go through
        # the array rescaling below.
        return pipe(prompt).images[0]

    model = runner.load_custom()
    # BatchNorm needs eval mode to handle a batch of one.
    model.eval()

    # Simplified generation process
    # NOTE: the token id is random, as in the original code, so the prompt
    # does not influence the custom model's output.
    with torch.no_grad():
        noise = torch.randn(1, CONFIG["training"]["latent_dim"])
        fake = model(torch.randint(0, 1000, (1,)), noise)

    image = fake.squeeze(0).permute(1, 2, 0).numpy()
    # Scale from the generator's [-1, 1] to [0, 1]; clip so rounding cannot
    # wrap around in the uint8 cast.
    image = np.clip((image + 1) / 2, 0.0, 1.0)
    return Image.fromarray((image * 255).astype(np.uint8))
|
249 |
+
|
250 |
+
# Launch the Gradio app only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app.launch()
|