K00B404 committed on
Commit bfcb186
1 Parent(s): 46fe3b4

Update app.py

Files changed (1):
  app.py +86 -231
app.py CHANGED
@@ -1,47 +1,25 @@
  import os
  import torch
  import torch.nn as nn
- import numpy as np
- import torch.optim as optim
  import torchvision.transforms as transforms
- from torch.utils.data import DataLoader, Dataset, Subset
+ from torch.utils.data import DataLoader, Dataset
  from PIL import Image
- from datasets import load_dataset
- from transformers import T5EncoderModel, T5Tokenizer, DistilBertModel, DistilBertTokenizer
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
  import matplotlib.pyplot as plt
- from rich import print as rp
- from credits import HUGGINGFACE_TOKEN, HUGGINGFACE_W_TOKEN, WANDB_API_KEY
- import wandb  # Import wandb
- import torchvision.utils as vutils  # To save image grids
- write_token = HUGGINGFACE_W_TOKEN
- read_token = HUGGINGFACE_TOKEN
+ device = "cpu"
 
  class TextEncoder(nn.Module):
-     def __init__(self, encoder_model_name, encoder_type="t5", device='cpu'):
+     def __init__(self, encoder_model_name):
          super(TextEncoder, self).__init__()
-         self.device = device
-         self.encoder_type = encoder_type
-         if encoder_type == "t5":
-             self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
-             self.encoder = T5EncoderModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
-         elif encoder_type == "distilbert":
-             self.tokenizer = DistilBertTokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
-             self.encoder = DistilBertModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
-         else:
-             raise ValueError(f"Invalid encoder_type: {encoder_type}. Choose from 't5' or 'distilbert'.")
-         self.encoder.to(self.device)
+         self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
+         self.encoder = T5ForConditionalGeneration.from_pretrained(encoder_model_name)
+         self.encoder.to(device)
 
      def encode_text(self, text):
-         if isinstance(text, str):
-             text = [text]
-
-         embeddings = []
-         for t in text:
-             inputs = self.tokenizer(t, return_tensors="pt", padding=True, truncation=True).to(self.device)
-             outputs = self.encoder(**inputs)
-             embeddings.append(outputs.last_hidden_state[:, 0, :])
-
-         return torch.stack(embeddings)  # Combine embeddings into a batch
+         inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+         inputs = {key: value.to(device) for key, value in inputs.items()}
+         outputs = self.encoder.encoder(**inputs)
+         embeddings = outputs.last_hidden_state[:, 0, :]
+         return embeddings
 
  class ConditionalDiffusionModel(nn.Module):
      def __init__(self):
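The rewritten TextEncoder loads a full T5ForConditionalGeneration but calls only its encoder stack (self.encoder.encoder) and keeps the hidden state at the first token position. A minimal standalone sketch of the same call pattern, assuming the public google-t5/t5-small checkpoint (d_model = 512) with sentencepiece available:

    from transformers import T5ForConditionalGeneration, T5Tokenizer
    import torch

    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

    inputs = tokenizer(["A big ape."], return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model.encoder(**inputs)  # encoder stack only; the decoder is never used
    embeddings = outputs.last_hidden_state[:, 0, :]  # first-token embedding, as in encode_text
    print(embeddings.shape)  # torch.Size([1, 512]) for t5-small

Note that T5EncoderModel, used before this commit, would load the same encoder weights without the decoder; the new version loads the decoder weights as well even though only the encoder is exercised.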
@@ -81,233 +59,110 @@ class TextToImageModel(nn.Module):
      def forward(self, text):
          text_embeddings = self.text_encoder.encode_text(text)
          image_embeddings = self.conditional_diffusion_model(text_embeddings)
-         input_image = torch.rand((image_embeddings.shape[0], 3, 128, 128)).to(text_embeddings.device)
-         for _ in range(6):  # Upsample the image 6 times
+         input_image = torch.rand((1, 3, 64, 64))  # Initialize input image with random values
+         for i in range(6):  # Upsample the image 6 times
              input_image = self.super_resolution_diffusion_model(input_image)
          return input_image
 
  class CustomDataset(Dataset):
-     def __init__(self, annotations_file, train_img_dir, size_sqr=128):
+     def __init__(self, annotations_file, img_dir, transform=None):
          with open(annotations_file, 'r') as f:
              lines = f.readlines()
-         self.transform = transforms.Compose([
-             transforms.Resize((size_sqr, size_sqr)),
-             transforms.ToTensor(),
-         ])
          self.img_labels = [line.strip().split(' ', 1) for line in lines]
-         self.train_img_dir = train_img_dir
+         self.img_dir = img_dir
+         self.transform = transform
 
      def __len__(self):
          return len(self.img_labels)
 
      def __getitem__(self, idx):
          img_name, text = self.img_labels[idx]
-         img_path = os.path.join(self.train_img_dir, img_name)
+         img_path = os.path.join(self.img_dir, img_name)
          image = Image.open(img_path).convert("RGB")
          if self.transform:
              image = self.transform(image)
          return text, image
 
- class HuggingDataset(Dataset):  # New class for HuggingFace dataset
-     def __init__(self, dataset_name="vera365/lexica_dataset", size_sqr=128, limit=None):
-         self.dataset = load_dataset(dataset_name, token=read_token, cache_dir='./datasets')
-         self.transform = transforms.Compose([
-             transforms.Resize((size_sqr, size_sqr)),
-             transforms.ToTensor(),
-         ])
-         # Apply limit if specified
-         if limit is not None:
-             self.dataset = Subset(self.dataset["train"], range(limit))
-
-     def __len__(self):
-         return len(self.dataset["train"])
-
-     def __getitem__(self, idx):
-         item = self.dataset["train"][idx]
-         image = item["image"].convert("RGB")
-         text = item["subject"]
-         if self.transform:
-             image = self.transform(image)
-         return text, image
-
- class StorageHandler:
-     def __init__(self, storage_dir="./image_gen_storage", hub_model_name="K00B404/tiny_image_gen", push_dataset=False, dataset_name="K00B404/custom_image_descriptions_dataset"):
-         self.model_name = hub_model_name
-         self.dataset_name = dataset_name
-         self.push_dataset = push_dataset
-         self.storage_dir = storage_dir
-
-     def save_checkpoint(self, model, optimizer, scheduler, epoch, checkpoint_path):
-         checkpoint = {
-             'model_state_dict': model.state_dict(),
-             'optimizer_state_dict': optimizer.state_dict(),
-             'scheduler_state_dict': scheduler.state_dict(),
-             'epoch': epoch
-         }
-         torch.save(checkpoint, checkpoint_path)
-
-     def load_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
-         if os.path.isfile(checkpoint_path):
-             checkpoint = torch.load(checkpoint_path)
-             model.load_state_dict(checkpoint['model_state_dict'])
-             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-             epoch = checkpoint['epoch']
-             return epoch, scheduler
-         else:
-             return 0, scheduler
-
-     def push_dataset(self, dataset):
-         if dataset:
-             dataset.push_to_hub(self.dataset_name, token=write_token)
-
-     def push(self, model, tokenizer):
-         model.push_to_hub(self.model_name, token=write_token)
-         tokenizer.push_to_hub(self.model_name, token=write_token)
-         # Optionally push dataset to Hugging Face Hub
-
- class Common:
-     def __init__(self, device='cpu', wandb_log=False):
-         self.wandb_log = wandb_log
-         self.device = device
-         self.terminal_log = rp
-
-         if self.wandb_log:
-             # Initialize wandb
-             # self.wandb = wandb.login(key=WANDB_API_KEY)  # Assuming you have already logged in. If not, use: wandb.login(key='YOUR_WANDB_API_KEY')
-             self.wandb = wandb.init(project="my-image-generation-project",
-                                     config={
-                                         "learning_rate": learning_rate,
-                                         "batch_size": batch_size,
-                                         "num_epochs": num_epochs,
-                                         "encoder_model": encoder
-                                     })
-
-     def train(self, model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_path):
-         for epoch in range(start_epoch, num_epochs):
-             model.train()
-             for i, (text_batch, image_batch) in enumerate(dataloader):
-                 image_batch = image_batch.to(self.device)
-                 image_size = 128  # Increase image size
-                 # Generate a starting image with the correct shape
-                 input_image = torch.randn((image_batch.shape[0], 3, image_size // 8, image_size // 8)).to(device)
-
-                 optimizer.zero_grad()
-                 images = model(text_batch)
-                 loss = criterion(images, image_batch)
-                 loss.backward()
-                 optimizer.step()
-
-                 if self.wandb_log:
-                     # Log loss and learning rate
-                     self.wandb.log({"train_loss": loss.item(), "lr": optimizer.param_groups[0]['lr']})
-
-                 if i % 25 == 0:
-                     # Save a grid of real and generated images for monitoring
-                     img_grid_real = vutils.make_grid(image_batch[:4], padding=2, normalize=True)
-                     img_grid_fake = vutils.make_grid(input_image[:4], padding=2, normalize=True)
-
-                     plt.figure(figsize=(15, 15))
-                     plt.subplot(1, 2, 1)
-                     plt.axis("off")
-                     plt.title("Real Images")
-                     plt.imshow(np.transpose(img_grid_real.cpu(), (1, 2, 0)))
-
-                     plt.subplot(1, 2, 2)
-                     plt.axis("off")
-                     plt.title("Generated Images")
-                     plt.imshow(np.transpose(img_grid_fake.cpu(), (1, 2, 0)))
-                     plt.savefig(f'generated_images_epoch_{epoch+1}_batch_{i}.png')
-                     plt.close()
-
-             # Validation step
-             val_loss = self.evaluate(model, dataloader, criterion)
-             scheduler.step(val_loss)  # Update scheduler with validation loss
-
-             image = self.test_inference(model, "A house next to a river.")
-             self.visualize_image(image, f'generated_image_epoch_{epoch + 1}.png')
-             self.terminal_log(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Loss: {val_loss}')
-             StorageHandler().save_checkpoint(model, optimizer, scheduler, epoch + 1, checkpoint_path)
-
-         self.terminal_log("Training completed.")
-
-     def evaluate(self, model, dataloader, criterion):
-         model.eval()
-         total_loss = 0
-         with torch.no_grad():
-             for i, (text_batch, image_batch) in enumerate(dataloader):
-                 image_batch = image_batch.to(self.device)
-                 images = model(text_batch)
-                 loss = criterion(images, image_batch)
-                 total_loss += loss.item()
-
-         avg_loss = total_loss / len(dataloader)
-         self.terminal_log(f'Validation Loss: {avg_loss}')
-         return avg_loss
-
-     def test_inference(self, model, text):
-         model.eval()
-         with torch.no_grad():
-             if isinstance(text, str):
-                 generated_image = model(text)
-             else:
-                 generated_image = [model(t) for t in text]
-         return generated_image
+ def save_checkpoint(model, optimizer, epoch, checkpoint_path):
+     checkpoint = {
+         'model_state_dict': model.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'epoch': epoch
+     }
+     torch.save(checkpoint, checkpoint_path)
+
+ def load_checkpoint(model, optimizer, checkpoint_path):
+     if os.path.isfile(checkpoint_path):
+         checkpoint = torch.load(checkpoint_path)
+         model.load_state_dict(checkpoint['model_state_dict'])
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         epoch = checkpoint['epoch']
+         return epoch
+     else:
+         return 0
+
+ def test_inference(model, text):
+     model.eval()
+     with torch.no_grad():
+         generated_image = model(text)
+     return generated_image
 
-     def visualize_image(self, image_tensor, filename='generated_image.png'):
-         image_tensor = image_tensor.squeeze(0).cpu().detach()
-         image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min())  # Normalize to [0, 1]
-         image_tensor = image_tensor.permute(1, 2, 0)  # Change from (C, H, W) to (H, W, C)
-         plt.imshow(image_tensor)
-         plt.savefig(filename)
-         # plt.show()
+ def visualize_image(image_tensor):
+     image_tensor = image_tensor.squeeze(0).cpu().detach()
+     image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min())  # Normalize to [0, 1]
+     image_tensor = image_tensor.permute(1, 2, 0)  # Change from (C, H, W) to (H, W, C)
+     plt.imshow(image_tensor)
+     plt.show()
 
  if __name__ == "__main__":
-     batch_size = 1
+     # Define hyperparameters and paths
+     batch_size = 4
      learning_rate = 1e-4
-     num_epochs = 500
-     encoder = "google-t5/t5-small"
-     checkpoint_path = './models/image_gen'
-     os.makedirs(checkpoint_path, exist_ok=True)
-     checkpoint_file = f"{checkpoint_path}/checkpoint_backup.pth"
-     use_huggingface_dataset = False  # <-- Toggle between datasets
-     limit_huggingface_dataset = 1000  # <-- Set the limit for HuggingFace dataset
-     train_img_dir = './train_images'
-     annotations_file = f'{train_img_dir}/annotations.txt'
-     storage_dir = "./image_gen_storage"
-     os.makedirs(storage_dir, exist_ok=True)
-     hub_model_name = "K00B404/tiny_image_gen"
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     text_encoder = TextEncoder(encoder, encoder_type="t5", device=device)
+     num_epochs = 1000
+     checkpoint_path = 'checkpoint.pth'
+     annotations_file = 'annotations.txt'
+     img_dir = 'images/'
+
+     # Initialize models
+     text_encoder = TextEncoder("google-t5/t5-small")
      conditional_diffusion_model = ConditionalDiffusionModel()
      super_resolution_diffusion_model = SuperResolutionDiffusionModel()
      text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
-     text_to_image_model.to(device)
 
-     optimizer = optim.AdamW(text_to_image_model.parameters(), lr=learning_rate)
+     # Define optimizer and criterion
+     optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
      criterion = nn.MSELoss()
-     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.001, patience=2)
 
-     common = Common(device, True)
-     storage = StorageHandler(checkpoint_path, hub_model_name)
+     # Load checkpoint if available
+     start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)
 
-     start_epoch, scheduler = storage.load_checkpoint(text_to_image_model, optimizer, scheduler, checkpoint_file)
-
-     if use_huggingface_dataset:
-         dataset = HuggingDataset(size_sqr=128, limit=limit_huggingface_dataset)
-     else:
-         dataset = CustomDataset(annotations_file, train_img_dir, size_sqr=128)
+     # Define transformations for the images
+     transform = transforms.Compose([
+         transforms.Resize((64, 64)),
+         transforms.ToTensor(),
+     ])
 
+     # Initialize dataset and dataloader
+     dataset = CustomDataset(annotations_file, img_dir, transform)
      dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
-     common.train(text_to_image_model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_file)
-
-     sample_texts = ["A big ape.", "A yellow banana."]
-     for sample_text in sample_texts:
-         generated_image = common.test_inference(text_to_image_model, sample_text)
-         common.visualize_image(generated_image)
+     # Training loop
+     text_to_image_model.train()
+     for epoch in range(start_epoch, num_epochs):
+         for i, (text_batch, image_batch) in enumerate(dataloader):
+             optimizer.zero_grad()
+             images = text_to_image_model(text_batch)
+             target_images = image_batch.to(device)
+             loss = criterion(images, target_images)
+             loss.backward()
+             optimizer.step()
+
+         print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
+         save_checkpoint(text_to_image_model, optimizer, epoch+1, checkpoint_path)
+
+     print("Training completed.")
+
+     # Test inference
+     sample_text = "A big ape."
+     generated_image = test_inference(text_to_image_model, sample_text)
+     visualize_image(generated_image)
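CustomDataset splits each annotation line on the first space only, so annotations.txt is expected to hold one image filename (relative to img_dir) followed by its caption per line; a hypothetical two-line example:

    ape_001.jpg A big ape.
    banana_007.jpg A yellow banana.

Everything after the first space, spaces included, is kept as the caption string.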