K00B404 committed on
Commit 852f11e · verified · 1 Parent(s): aaa2b0e

Update app.py

Files changed (1): app.py (+96 -11)
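
This commit wires text prompts into the pix2pix training pipeline: `Pix2PixDataset` now downloads `combined_data.csv` from the `K00B404/pix2pix_flux_set` dataset, matches each original image's filename against the CSV's `image_path` column, and returns CLIP-tokenized `original_prompt` and `enhanced_prompt` fields alongside the image pair. `train_model` folds a prompt-guidance term into the L1 reconstruction loss, and a new `gradio_inference` variant accepts keywords that `chat_with_bot` expands into an enhanced prompt.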
app.py CHANGED
@@ -7,6 +7,7 @@ from datasets import load_dataset
 from huggingface_hub import Repository
 from huggingface_hub import HfApi, HfFolder, Repository, create_repo
 import os
+import pandas as pd
 import gradio as gr
 from PIL import Image
 import numpy as np
@@ -52,10 +53,17 @@ def load_model():
     return model
 
 
+import os
+import pandas as pd
 
-# Dataset class remains the same
 class Pix2PixDataset(torch.utils.data.Dataset):
-    def __init__(self, ds, transform):
+    def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
+        if not os.path.exists(csv_path):
+            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
+
+        self.data = pd.read_csv(csv_path)
+        self.clip_tokenizer = clip_tokenizer
+
         self.originals = [x for x in ds["train"] if x['label'] == 0]
         self.targets = [x for x in ds["train"] if x['label'] == 1]
         assert len(self.originals) == len(self.targets)
@@ -67,8 +75,59 @@ class Pix2PixDataset(torch.utils.data.Dataset):
         return len(self.originals)
 
     def __getitem__(self, idx):
+        # Get original and target images
         original_img = self.originals[idx]['image']
         target_img = self.targets[idx]['image']
+
+        # Convert PIL images
+        original = original_img.convert('RGB')
+        target = target_img.convert('RGB')
+
+        # Extract the filename from the original image's path (assuming it has a 'filename' field or path)
+        original_img_path = self.originals[idx]['image'].filename  # Assuming it has this attribute
+        original_img_filename = os.path.basename(original_img_path)
+
+        # Match the image filename with the `image_path` column in the CSV
+        matched_row = self.data[self.data['image_path'].str.contains(original_img_filename)]
+
+        if matched_row.empty:
+            raise ValueError(f"No matching entry found in the CSV for image {original_img_filename}")
+
+        # Get the prompts from the matched row
+        original_prompt = matched_row['original_prompt'].values[0]
+        enhanced_prompt = matched_row['enhanced_prompt'].values[0]
+
+        # Tokenize the prompts using CLIP tokenizer
+        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+        # Return transformed images and tokenized prompts
+        return self.transform(original), self.transform(target), original_tokens, enhanced_tokens
+
+
+# Dataset class remains the same
+class Pix2PixDataset_old(torch.utils.data.Dataset):
+    def __init__(self, ds, transform, clip_tokenizer=None, csv_path='combined_data.csv'):
+        if not os.path.exists(csv_path):
+            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
+
+        self.data = pd.read_csv(csv_path)
+        self.clip_tokenizer = clip_tokenizer  # unused here; kept for parity with the new class
+
+        self.originals = [x for x in ds["train"] if x['label'] == 0]
+        self.targets = [x for x in ds["train"] if x['label'] == 1]
+        assert len(self.originals) == len(self.targets)
+        print(f"Number of original images: {len(self.originals)}")
+        print(f"Number of target images: {len(self.targets)}")
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.originals)
+
+    def __getitem__(self, idx):
+        original_img = self.originals[idx]['image']
+        # TODO: get original_img file name and match with image_path in self.data....then tokenize the prompts with clip_tokenizer
+        target_img = self.targets[idx]['image']
         original = original_img.convert('RGB')
         target = target_img.convert('RGB')
         return self.transform(original), self.transform(target)
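
For reference, a minimal sketch of how the updated dataset might be instantiated, assuming the tokenizer comes from the transformers library; the tokenizer checkpoint and image size below are illustrative, not taken from this commit:

    from datasets import load_dataset
    from torchvision import transforms
    from transformers import CLIPTokenizer

    # Illustrative wiring; app.py defines ds, transform, and clip_tokenizer elsewhere.
    ds = load_dataset("K00B404/pix2pix_flux_set")
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # assumed size
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
    original, target, original_tokens, enhanced_tokens = dataset[0]
    print(original.shape)                   # e.g. torch.Size([3, 256, 256])
    print(original_tokens.input_ids.shape)  # [1, seq_len], capped at 77 tokens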
@@ -188,7 +247,7 @@ def prepare_input(image, device='cpu'):
     input_tensor = transform(image).unsqueeze(0).to(device)
     return input_tensor
 
-def run_inference(image):
+def run_inference(image, prompt):
     """Run inference on a single image"""
     global global_model
     if global_model is None:
@@ -219,41 +278,67 @@ def train_model(epochs):
         transforms.ToTensor(),
     ])
 
-    dataset = Pix2PixDataset(ds, transform)
+    # Initialize the dataset and dataloader
+    dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
     model = global_model
-    criterion = nn.L1Loss()
+    criterion = nn.L1Loss()  # L1 loss for image reconstruction
     optimizer = optim.Adam(model.parameters(), lr=LR)
     output_text = []
 
     for epoch in range(epochs):
         model.train()
-        for i, (original, target) in enumerate(dataloader):
+        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+            # Move images and prompt tokens to the appropriate device (CPU or GPU)
             original, target = original.to(device), target.to(device)
+            original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
+            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
+
             optimizer.zero_grad()
+
+            # Forward pass through the model
             output = model(target)
-            loss = criterion(output, original)
-            loss.backward()
+
+            # Compute image reconstruction loss
+            img_loss = criterion(output, original)
+
+            # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
+            prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
+
+            # Combine losses
+            total_loss = img_loss + 0.1 * prompt_loss  # Weight the prompt guidance loss with 0.1 to balance
+            total_loss.backward()
+
+            # Optimizer step
             optimizer.step()
 
             if i % 10 == 0:
-                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
+                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                 print(status)
                 output_text.append(status)
 
+        # Push model to Hugging Face Hub at the end of each epoch
         to_hub(model)
 
-    global_model = model
+    global_model = model  # Update the global model after training
    return model, "\n".join(output_text)
 
-
 def gradio_train(epochs):
     """Gradio training interface function"""
     model, training_log = train_model(int(epochs))
     to_hub(model)
     return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
 
+def gradio_inference(input_image, keywords):
+    """Gradio inference interface function"""
+    # Generate an enhanced prompt using the chat bot
+    enhanced_prompt = chat_with_bot(keywords)
+
+    # Run inference on the input image
+    output_image = run_inference(input_image, enhanced_prompt)
+
+    return input_image, output_image, keywords, enhanced_prompt
 def gradio_inference(input_image):
     """Gradio inference interface function"""
     return input_image, run_inference(input_image)
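
One caveat worth flagging: the in-line comment above speaks of prompt embeddings, but the committed prompt_loss subtracts raw token IDs. Those are integer indices that carry no gradient back to the model, so the term only shifts the reported loss by a data-dependent constant, and the subtraction fails outright when the two prompts tokenize to different lengths. A sketch of the embedding-based distance the comment describes, assuming a frozen CLIPTextModel from transformers (not part of this commit):

    import torch
    from transformers import CLIPTextModel

    # Hypothetical helper, not in app.py: embed both prompts with a frozen
    # CLIP text encoder and measure how far apart they sit.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval()

    @torch.no_grad()
    def prompt_distance(original_tokens, enhanced_tokens):
        orig_emb = text_encoder(**original_tokens).pooler_output  # [B, 512]
        enh_emb = text_encoder(**enhanced_tokens).pooler_output   # [B, 512]
        return torch.norm(orig_emb - enh_emb, p=2, dim=-1).mean()

Even computed this way, the distance is constant with respect to the pix2pix model's parameters; it only becomes a trainable signal if the model itself consumes the prompt embeddings.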
 
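
As committed, each __getitem__ pads only to its own prompt's length, so the default DataLoader collate can fail to stack prompts of different lengths across a batch. A minimal sketch of one fix, padding every prompt to CLIP's 77-token context limit (an assumption, not part of this commit):

    # Hypothetical tweak inside Pix2PixDataset.__getitem__: fixed-length padding
    # gives every sample the same token shape, so the default collate can
    # stack a batch without shape mismatches.
    original_tokens = self.clip_tokenizer(
        original_prompt,
        return_tensors="pt",
        padding="max_length",  # pad to max_length rather than per-sample longest
        truncation=True,
        max_length=77,
    )
    enhanced_tokens = self.clip_tokenizer(
        enhanced_prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77,
    )

An alternative is a custom collate_fn that pads per batch, which avoids always paying for the full 77 tokens.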