Update app.py
Browse files
app.py
CHANGED
@@ -1,47 +1,25 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
-
import numpy as np
|
5 |
-
import torch.optim as optim
|
6 |
import torchvision.transforms as transforms
|
7 |
-
from torch.utils.data import DataLoader, Dataset
|
8 |
from PIL import Image
|
9 |
-
from
|
10 |
-
from transformers import T5EncoderModel, T5Tokenizer, DistilBertModel, DistilBertTokenizer
|
11 |
import matplotlib.pyplot as plt
|
12 |
-
|
13 |
-
from credits import HUGGINGFACE_TOKEN, HUGGINGFACE_W_TOKEN, WANDB_API_KEY
|
14 |
-
import wandb # Import wandb
|
15 |
-
import torchvision.utils as vutils # To save image grids
|
16 |
-
write_token = HUGGINGFACE_W_TOKEN
|
17 |
-
read_token = HUGGINGFACE_TOKEN
|
18 |
-
|
19 |
class TextEncoder(nn.Module):
|
20 |
-
def __init__(self, encoder_model_name
|
21 |
super(TextEncoder, self).__init__()
|
22 |
-
self.
|
23 |
-
self.
|
24 |
-
|
25 |
-
self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
|
26 |
-
self.encoder = T5EncoderModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
|
27 |
-
elif encoder_type == "distilbert":
|
28 |
-
self.tokenizer = DistilBertTokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
|
29 |
-
self.encoder = DistilBertModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
|
30 |
-
else:
|
31 |
-
raise ValueError(f"Invalid encoder_type: {encoder_type}. Choose from 't5' or 'distilbert'.")
|
32 |
-
self.encoder.to(self.device)
|
33 |
|
34 |
def encode_text(self, text):
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
embeddings = []
|
39 |
-
|
40 |
-
inputs = self.tokenizer(t, return_tensors="pt", padding=True, truncation=True).to(self.device)
|
41 |
-
outputs = self.encoder(**inputs)
|
42 |
-
embeddings.append(outputs.last_hidden_state[:, 0, :])
|
43 |
-
|
44 |
-
return torch.stack(embeddings) # Combine embeddings into a batch
|
45 |
|
46 |
class ConditionalDiffusionModel(nn.Module):
|
47 |
def __init__(self):
|
@@ -81,233 +59,110 @@ class TextToImageModel(nn.Module):
|
|
81 |
def forward(self, text):
|
82 |
text_embeddings = self.text_encoder.encode_text(text)
|
83 |
image_embeddings = self.conditional_diffusion_model(text_embeddings)
|
84 |
-
input_image = torch.rand((
|
85 |
-
for
|
86 |
input_image = self.super_resolution_diffusion_model(input_image)
|
87 |
return input_image
|
88 |
|
89 |
class CustomDataset(Dataset):
|
90 |
-
def __init__(self, annotations_file,
|
91 |
with open(annotations_file, 'r') as f:
|
92 |
lines = f.readlines()
|
93 |
-
self.transform = transforms.Compose([
|
94 |
-
transforms.Resize((size_sqr, size_sqr)),
|
95 |
-
transforms.ToTensor(),
|
96 |
-
])
|
97 |
self.img_labels = [line.strip().split(' ', 1) for line in lines]
|
98 |
-
self.
|
|
|
99 |
|
100 |
def __len__(self):
|
101 |
return len(self.img_labels)
|
102 |
|
103 |
def __getitem__(self, idx):
|
104 |
img_name, text = self.img_labels[idx]
|
105 |
-
img_path = os.path.join(self.
|
106 |
image = Image.open(img_path).convert("RGB")
|
107 |
if self.transform:
|
108 |
image = self.transform(image)
|
109 |
return text, image
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
transforms.ToTensor(),
|
118 |
-
])
|
119 |
-
# Apply limit if specified
|
120 |
-
if limit is not None:
|
121 |
-
self.dataset = Subset(self.dataset["train"], range(limit))
|
122 |
-
|
123 |
-
def __len__(self):
|
124 |
-
return len(self.dataset["train"])
|
125 |
-
|
126 |
-
def __getitem__(self, idx):
|
127 |
-
item = self.dataset["train"][idx]
|
128 |
-
|
129 |
-
image =item["image"].convert("RGB")
|
130 |
-
text = item["subject"]
|
131 |
-
if self.transform:
|
132 |
-
image = self.transform(image)
|
133 |
-
return text, image
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
class StorageHandler:
|
138 |
-
def __init__(self, storage_dir="./image_gen_storage", hub_model_name="K00B404/tiny_image_gen", push_dataset=False, dataset_name="K00B404/custom_image_descriptions_dataset"):
|
139 |
-
self.model_name = hub_model_name
|
140 |
-
self.dataset_name = dataset_name
|
141 |
-
self.push_dataset = push_dataset
|
142 |
-
self.storage_dir = storage_dir
|
143 |
-
|
144 |
-
def save_checkpoint(self, model, optimizer, scheduler, epoch, checkpoint_path):
|
145 |
-
checkpoint = {
|
146 |
-
'model_state_dict': model.state_dict(),
|
147 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
148 |
-
'scheduler_state_dict': scheduler.state_dict(),
|
149 |
-
'epoch': epoch
|
150 |
-
}
|
151 |
-
torch.save(checkpoint, checkpoint_path)
|
152 |
-
|
153 |
-
def load_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
|
154 |
-
if os.path.isfile(checkpoint_path):
|
155 |
-
checkpoint = torch.load(checkpoint_path)
|
156 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
157 |
-
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
158 |
-
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
159 |
-
epoch = checkpoint['epoch']
|
160 |
-
return epoch, scheduler
|
161 |
-
else:
|
162 |
-
return 0, scheduler
|
163 |
-
|
164 |
-
def push_dataset(self, dataset):
|
165 |
-
if dataset:
|
166 |
-
dataset.push_to_hub(self.dataset_name, token=write_token)
|
167 |
-
|
168 |
-
def push(self, model, tokenizer):
|
169 |
-
model.push_to_hub(self.model_name, token=write_token)
|
170 |
-
tokenizer.push_to_hub(self.model_name, token=write_token)
|
171 |
-
# Optionally push dataset to Hugging Face Hub
|
172 |
-
|
173 |
-
class Common:
|
174 |
-
def __init__(self, device='cpu', wandb_log=False):
|
175 |
-
self.wandb_log = wandb_log
|
176 |
-
self.device = device
|
177 |
-
self.terminal_log = rp
|
178 |
-
|
179 |
-
if self.wandb_log:
|
180 |
-
# Initialize wandb
|
181 |
-
#self.wandb = wandb.login(key=WANDB_API_KEY) # Assuming you have already logged in. If not, use: wandb.login(key='YOUR_WANDB_API_KEY')
|
182 |
-
self.wandb = wandb.init(project="my-image-generation-project",
|
183 |
-
config={
|
184 |
-
"learning_rate": learning_rate,
|
185 |
-
"batch_size": batch_size,
|
186 |
-
"num_epochs": num_epochs,
|
187 |
-
"encoder_model": encoder
|
188 |
-
})
|
189 |
-
|
190 |
-
|
191 |
-
def train(self, model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_path):
|
192 |
-
for epoch in range(start_epoch, num_epochs):
|
193 |
-
model.train()
|
194 |
-
for i, (text_batch, image_batch) in enumerate(dataloader):
|
195 |
-
image_batch = image_batch.to(self.device)
|
196 |
-
image_size = 128 # Increase image size
|
197 |
-
# Generate a starting image with the correct shape
|
198 |
-
input_image = torch.randn((image_batch.shape[0], 3, image_size//8, image_size//8)).to(device)
|
199 |
-
|
200 |
-
optimizer.zero_grad()
|
201 |
-
images = model(text_batch)
|
202 |
-
loss = criterion(images, image_batch)
|
203 |
-
loss.backward()
|
204 |
-
optimizer.step()
|
205 |
-
|
206 |
-
if self.wandb_log:
|
207 |
-
# Log loss and learning rate
|
208 |
-
self.wandb.log({"train_loss": loss.item(), "lr": optimizer.param_groups[0]['lr']})
|
209 |
-
|
210 |
-
if i % 25 == 0:
|
211 |
-
# Save a grid of real and generated images for monitoring
|
212 |
-
img_grid_real = vutils.make_grid(image_batch[:4], padding=2, normalize=True)
|
213 |
-
img_grid_fake = vutils.make_grid(input_image[:4], padding=2, normalize=True)
|
214 |
-
|
215 |
-
plt.figure(figsize=(15,15))
|
216 |
-
plt.subplot(1,2,1)
|
217 |
-
plt.axis("off")
|
218 |
-
plt.title("Real Images")
|
219 |
-
plt.imshow(np.transpose(img_grid_real.cpu(),(1,2,0)))
|
220 |
-
|
221 |
-
plt.subplot(1,2,2)
|
222 |
-
plt.axis("off")
|
223 |
-
plt.title("Generated Images")
|
224 |
-
plt.imshow(np.transpose(img_grid_fake.cpu(),(1,2,0)))
|
225 |
-
plt.savefig(f'generated_images_epoch_{epoch+1}_batch_{i}.png')
|
226 |
-
plt.close()
|
227 |
-
|
228 |
-
# Validation step
|
229 |
-
val_loss = self.evaluate(model, dataloader, criterion)
|
230 |
-
scheduler.step(val_loss) # Update scheduler with validation loss
|
231 |
-
|
232 |
-
image = self.test_inference(model, "A house next to a river.")
|
233 |
-
self.visualize_image(image, f'generated_image_epoch_{epoch + 1}.png')
|
234 |
-
self.terminal_log(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Loss: {val_loss}')
|
235 |
-
StorageHandler().save_checkpoint(model, optimizer, scheduler, epoch + 1, checkpoint_path)
|
236 |
-
|
237 |
-
self.terminal_log("Training completed.")
|
238 |
-
|
239 |
-
def evaluate(self, model, dataloader, criterion):
|
240 |
-
model.eval()
|
241 |
-
total_loss = 0
|
242 |
-
with torch.no_grad():
|
243 |
-
for i, (text_batch, image_batch) in enumerate(dataloader):
|
244 |
-
image_batch = image_batch.to(self.device)
|
245 |
-
images = model(text_batch)
|
246 |
-
loss = criterion(images, image_batch)
|
247 |
-
total_loss += loss.item()
|
248 |
-
|
249 |
-
avg_loss = total_loss / len(dataloader)
|
250 |
-
self.terminal_log(f'Validation Loss: {avg_loss}')
|
251 |
-
return avg_loss
|
252 |
-
|
253 |
-
def test_inference(self, model, text):
|
254 |
-
model.eval()
|
255 |
-
with torch.no_grad():
|
256 |
-
if isinstance(text, str):
|
257 |
-
generated_image = model(text)
|
258 |
-
else:
|
259 |
-
generated_image = [model(t) for t in text]
|
260 |
-
return generated_image
|
261 |
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
#plt.show()
|
269 |
|
270 |
if __name__ == "__main__":
|
271 |
-
|
|
|
272 |
learning_rate = 1e-4
|
273 |
-
num_epochs =
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
train_img_dir = './train_images'
|
281 |
-
annotations_file = f'{train_img_dir}/annotations.txt'
|
282 |
-
storage_dir = "./image_gen_storage"
|
283 |
-
os.makedirs(storage_dir, exist_ok=True)
|
284 |
-
hub_model_name = "K00B404/tiny_image_gen"
|
285 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
286 |
-
text_encoder = TextEncoder(encoder, encoder_type="t5", device=device)
|
287 |
conditional_diffusion_model = ConditionalDiffusionModel()
|
288 |
super_resolution_diffusion_model = SuperResolutionDiffusionModel()
|
289 |
text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
|
290 |
-
text_to_image_model.to(device)
|
291 |
|
292 |
-
optimizer
|
|
|
293 |
criterion = nn.MSELoss()
|
294 |
-
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.001, patience=2)
|
295 |
|
296 |
-
|
297 |
-
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
dataset = CustomDataset(annotations_file, train_img_dir, size_sqr=128)
|
305 |
|
|
|
|
|
306 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import torch.nn as nn
|
|
|
|
|
4 |
import torchvision.transforms as transforms
|
5 |
+
from torch.utils.data import DataLoader, Dataset
|
6 |
from PIL import Image
|
7 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
+
device ="cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
class TextEncoder(nn.Module):
|
11 |
+
def __init__(self, encoder_model_name):
|
12 |
super(TextEncoder, self).__init__()
|
13 |
+
self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
|
14 |
+
self.encoder = T5ForConditionalGeneration.from_pretrained(encoder_model_name)
|
15 |
+
self.encoder.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def encode_text(self, text):
|
18 |
+
inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
19 |
+
inputs = {key: value.to(device) for key, value in inputs.items()}
|
20 |
+
outputs = self.encoder.encoder(**inputs)
|
21 |
+
embeddings = outputs.last_hidden_state[:, 0, :]
|
22 |
+
return embeddings
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
class ConditionalDiffusionModel(nn.Module):
|
25 |
def __init__(self):
|
|
|
59 |
def forward(self, text):
|
60 |
text_embeddings = self.text_encoder.encode_text(text)
|
61 |
image_embeddings = self.conditional_diffusion_model(text_embeddings)
|
62 |
+
input_image = torch.rand((1, 3, 64, 64)) # Initialize input image with random values
|
63 |
+
for i in range(6): # Upsample the image 6 times
|
64 |
input_image = self.super_resolution_diffusion_model(input_image)
|
65 |
return input_image
|
66 |
|
67 |
class CustomDataset(Dataset):
|
68 |
+
def __init__(self, annotations_file, img_dir, transform=None):
|
69 |
with open(annotations_file, 'r') as f:
|
70 |
lines = f.readlines()
|
|
|
|
|
|
|
|
|
71 |
self.img_labels = [line.strip().split(' ', 1) for line in lines]
|
72 |
+
self.img_dir = img_dir
|
73 |
+
self.transform = transform
|
74 |
|
75 |
def __len__(self):
|
76 |
return len(self.img_labels)
|
77 |
|
78 |
def __getitem__(self, idx):
|
79 |
img_name, text = self.img_labels[idx]
|
80 |
+
img_path = os.path.join(self.img_dir, img_name)
|
81 |
image = Image.open(img_path).convert("RGB")
|
82 |
if self.transform:
|
83 |
image = self.transform(image)
|
84 |
return text, image
|
85 |
|
86 |
+
def save_checkpoint(model, optimizer, epoch, checkpoint_path):
|
87 |
+
checkpoint = {
|
88 |
+
'model_state_dict': model.state_dict(),
|
89 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
90 |
+
'epoch': epoch
|
91 |
+
}
|
92 |
+
torch.save(checkpoint, checkpoint_path)
|
93 |
+
|
94 |
+
def load_checkpoint(model, optimizer, checkpoint_path):
|
95 |
+
if os.path.isfile(checkpoint_path):
|
96 |
+
checkpoint = torch.load(checkpoint_path)
|
97 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
98 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
99 |
+
epoch = checkpoint['epoch']
|
100 |
+
return epoch
|
101 |
+
else:
|
102 |
+
return 0
|
103 |
|
104 |
+
def test_inference(model, text):
|
105 |
+
model.eval()
|
106 |
+
with torch.no_grad():
|
107 |
+
generated_image = model(text)
|
108 |
+
return generated_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
+
def visualize_image(image_tensor):
|
111 |
+
image_tensor = image_tensor.squeeze(0).cpu().detach()
|
112 |
+
image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min()) # Normalize to [0, 1]
|
113 |
+
image_tensor = image_tensor.permute(1, 2, 0) # Change from (C, H, W) to (H, W, C)
|
114 |
+
plt.imshow(image_tensor)
|
115 |
+
plt.show()
|
|
|
116 |
|
117 |
if __name__ == "__main__":
|
118 |
+
# Define hyperparameters and paths
|
119 |
+
batch_size = 4
|
120 |
learning_rate = 1e-4
|
121 |
+
num_epochs = 1000
|
122 |
+
checkpoint_path = 'checkpoint.pth'
|
123 |
+
annotations_file = 'annotations.txt'
|
124 |
+
img_dir = 'images/'
|
125 |
+
|
126 |
+
# Initialize models
|
127 |
+
text_encoder = TextEncoder("google-t5/t5-small")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
conditional_diffusion_model = ConditionalDiffusionModel()
|
129 |
super_resolution_diffusion_model = SuperResolutionDiffusionModel()
|
130 |
text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
|
|
|
131 |
|
132 |
+
# Define optimizer and criterion
|
133 |
+
optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
|
134 |
criterion = nn.MSELoss()
|
|
|
135 |
|
136 |
+
# Load checkpoint if available
|
137 |
+
start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)
|
138 |
|
139 |
+
# Define transformations for the images
|
140 |
+
transform = transforms.Compose([
|
141 |
+
transforms.Resize((64, 64)),
|
142 |
+
transforms.ToTensor(),
|
143 |
+
])
|
|
|
144 |
|
145 |
+
# Initialize dataset and dataloader
|
146 |
+
dataset = CustomDataset(annotations_file, img_dir, transform)
|
147 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
148 |
|
149 |
+
# Training loop
|
150 |
+
text_to_image_model.train()
|
151 |
+
for epoch in range(start_epoch, num_epochs):
|
152 |
+
for i, (text_batch, image_batch) in enumerate(dataloader):
|
153 |
+
optimizer.zero_grad()
|
154 |
+
images = text_to_image_model(text_batch)
|
155 |
+
target_images = image_batch.to(device)
|
156 |
+
loss = criterion(images, target_images)
|
157 |
+
loss.backward()
|
158 |
+
optimizer.step()
|
159 |
+
|
160 |
+
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
|
161 |
+
save_checkpoint(text_to_image_model, optimizer, epoch+1, checkpoint_path)
|
162 |
+
|
163 |
+
print("Training completed.")
|
164 |
+
|
165 |
+
# Test inference
|
166 |
+
sample_text = "A big ape."
|
167 |
+
generated_image = test_inference(text_to_image_model, sample_text)
|
168 |
+
visualize_image(generated_image)
|