Update app.py
app.py CHANGED
@@ -7,6 +7,7 @@ from datasets import load_dataset
 from huggingface_hub import Repository
 from huggingface_hub import HfApi, HfFolder, Repository, create_repo
 import os
+import pandas as pd
 import gradio as gr
 from PIL import Image
 import numpy as np
@@ -52,10 +53,17 @@ def load_model():
     return model
 
 
+import os
+import pandas as pd
 
-# Dataset class remains the same
 class Pix2PixDataset(torch.utils.data.Dataset):
-    def __init__(self, ds, transform):
+    def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
+        if not os.path.exists(csv_path):
+            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
+
+        self.data = pd.read_csv(csv_path)
+        self.clip_tokenizer = clip_tokenizer
+
         self.originals = [x for x in ds["train"] if x['label'] == 0]
         self.targets = [x for x in ds["train"] if x['label'] == 1]
         assert len(self.originals) == len(self.targets)
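The download step shells out to `wget`, which assumes that binary is available in the Space image. A minimal alternative sketch using `huggingface_hub`, which app.py already imports (the call below fetches the same CSV into the local cache and returns its path):

```python
from huggingface_hub import hf_hub_download

# Equivalent to the wget call above, but without relying on a system binary
csv_path = hf_hub_download(
    repo_id="K00B404/pix2pix_flux_set",
    filename="combined_data.csv",
    repo_type="dataset",
)
```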
@@ -67,8 +75,59 @@ class Pix2PixDataset(torch.utils.data.Dataset):
         return len(self.originals)
 
     def __getitem__(self, idx):
+        # Get original and target images
         original_img = self.originals[idx]['image']
         target_img = self.targets[idx]['image']
+
+        # Convert PIL images
+        original = original_img.convert('RGB')
+        target = target_img.convert('RGB')
+
+        # Extract the filename from the original image's path (assuming it has a 'filename' field or path)
+        original_img_path = self.originals[idx]['image'].filename  # Assuming it has this attribute
+        original_img_filename = os.path.basename(original_img_path)
+
+        # Match the image filename with the `image_path` column in the CSV
+        matched_row = self.data[self.data['image_path'].str.contains(original_img_filename)]
+
+        if matched_row.empty:
+            raise ValueError(f"No matching entry found in the CSV for image {original_img_filename}")
+
+        # Get the prompts from the matched row
+        original_prompt = matched_row['original_prompt'].values[0]
+        enhanced_prompt = matched_row['enhanced_prompt'].values[0]
+
+        # Tokenize the prompts using CLIP tokenizer
+        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+        # Return transformed images and tokenized prompts
+        return self.transform(original), self.transform(target), original_tokens, enhanced_tokens
+
+
+# Dataset class remains the same
+class Pix2PixDataset_old(torch.utils.data.Dataset):
+    def __init__(self, ds, transform, csv_path='combined_data.csv'):
+        if not os.path.exists(csv_path):
+            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
+
+        self.data = pd.read_csv(csv_path)
+        self.clip_tokenizer = clip_tokenizer
+
+        self.originals = [x for x in ds["train"] if x['label'] == 0]
+        self.targets = [x for x in ds["train"] if x['label'] == 1]
+        assert len(self.originals) == len(self.targets)
+        print(f"Number of original images: {len(self.originals)}")
+        print(f"Number of target images: {len(self.targets)}")
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.originals)
+
+    def __getitem__(self, idx):
+        original_img = self.originals[idx]['image']
+        # TODO: get original_img file name and match with image_path in self.data....then tokenize the prompts with clip_tokenizer
+        target_img = self.targets[idx]['image']
         original = original_img.convert('RGB')
         target = target_img.convert('RGB')
        return self.transform(original), self.transform(target)
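A sketch of how the updated dataset might be wired up, using the `Pix2PixDataset` class from the hunk above. The tokenizer checkpoint, dataset id, and image size here are assumptions; app.py presumably creates `clip_tokenizer` and `ds` elsewhere. Note also that tokenizing each item with `padding=True` pads only to that single prompt's length, so batching variable-length prompts through the default `DataLoader` collate may require `padding='max_length'` or a custom `collate_fn`:

```python
from datasets import load_dataset
from transformers import CLIPTokenizer
from torchvision import transforms

# Assumed checkpoint and dataset id; app.py may use different ones
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
ds = load_dataset("K00B404/pix2pix_flux_set")

transform = transforms.Compose([
    transforms.Resize((256, 256)),  # illustrative size
    transforms.ToTensor(),
])

dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
original, target, orig_tokens, enh_tokens = dataset[0]
print(original.shape, orig_tokens.input_ids.shape)
```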
@@ -188,7 +247,7 @@ def prepare_input(image, device='cpu'):
     input_tensor = transform(image).unsqueeze(0).to(device)
     return input_tensor
 
-def run_inference(image):
+def run_inference(image, prompt):
     """Run inference on a single image"""
     global global_model
     if global_model is None:
@@ -219,41 +278,67 @@ def train_model(epochs):
         transforms.ToTensor(),
     ])
 
-    dataset = Pix2PixDataset(ds, transform)
+    # Initialize the dataset and dataloader
+    dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
     model = global_model
-    criterion = nn.L1Loss()
+    criterion = nn.L1Loss()  # L1 loss for image reconstruction
     optimizer = optim.Adam(model.parameters(), lr=LR)
     output_text = []
 
     for epoch in range(epochs):
         model.train()
-        for i, (original, target) in enumerate(dataloader):
+        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+            # Move images and prompt embeddings to the appropriate device (CPU or GPU)
             original, target = original.to(device), target.to(device)
+            original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
+            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
+
             optimizer.zero_grad()
+
+            # Forward pass through the model
             output = model(target)
-            loss = criterion(output, original)
-            loss.backward()
+
+            # Compute image reconstruction loss
+            img_loss = criterion(output, original)
+
+            # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
+            prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
+
+            # Combine losses
+            total_loss = img_loss + 0.1 * prompt_loss  # Weight the prompt guidance loss with 0.1 to balance
+            total_loss.backward()
+
+            # Optimizer step
             optimizer.step()
 
             if i % 10 == 0:
-                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
+                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                 print(status)
                 output_text.append(status)
 
+        # Push model to Hugging Face Hub at the end of each epoch
         to_hub(model)
 
-    global_model = model
+    global_model = model  # Update the global model after training
     return model, "\n".join(output_text)
 
-
 def gradio_train(epochs):
     """Gradio training interface function"""
     model, training_log = train_model(int(epochs))
     to_hub(model)
     return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
 
+def gradio_inference(input_image, keywords):
+    """Gradio inference interface function"""
+    # Generate an enhanced prompt using the chat bot
+    enhanced_prompt = chat_with_bot(keywords)
+
+    # Run inference on the input image
+    output_image = run_inference(input_image, chat_with_bot(keywords))
+
+    return input_image, output_image, keywords, enhanced_prompt
 def gradio_inference(input_image):
     """Gradio inference interface function"""
     return input_image, run_inference(input_image)
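As committed, `prompt_loss` is computed on raw `input_ids`, i.e. integer token ids rather than embeddings, and `torch.norm` rejects integer tensors, so in practice a cast to float is needed. A toy sketch of the combined objective under those assumptions (shapes and vocab size are illustrative):

```python
import torch
import torch.nn as nn

criterion = nn.L1Loss()

# Stand-ins for one training batch
output = torch.rand(1, 3, 256, 256)          # model reconstruction
original = torch.rand(1, 3, 256, 256)        # ground-truth image
orig_ids = torch.randint(0, 49408, (1, 77))  # CLIP token ids (vocab size 49408)
enh_ids = torch.randint(0, 49408, (1, 77))

img_loss = criterion(output, original)
# Cast the integer ids to float before taking the L2 norm
prompt_loss = torch.norm(orig_ids.float() - enh_ids.float(), p=2)
total_loss = img_loss + 0.1 * prompt_loss
print(f"img={img_loss.item():.4f}, prompt={prompt_loss.item():.2f}, total={total_loss.item():.2f}")
```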
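Note that after this commit the file contains two `def gradio_inference(...)` definitions. In Python the later definition rebinds the name, so the original single-argument version is the one actually called; a minimal demonstration:

```python
def gradio_inference(input_image, keywords):
    return "keyword version"

def gradio_inference(input_image):
    return "single-argument version"

# The later def wins: this prints "single-argument version"
print(gradio_inference("img"))
```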