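"""Toy text-to-image pipeline.

A T5 encoder turns a caption into an embedding, a small MLP maps that
embedding to a conditioning signal, and a conv stack repeatedly refines a
conditioned noise image into a 64x64 output. This is an illustrative sketch,
not a real diffusion model: there is no noise schedule or denoising
objective, and training simply regresses generated images onto dataset
images with MSE.
"""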
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
import matplotlib.pyplot as plt
device ="cpu"
class TextEncoder(nn.Module):
    def __init__(self, encoder_model_name):
        super(TextEncoder, self).__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
        # Only the encoder stack is used, so load T5EncoderModel rather than
        # the full seq2seq model.
        self.encoder = T5EncoderModel.from_pretrained(encoder_model_name)
        self.encoder.to(device)

    def encode_text(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        outputs = self.encoder(**inputs)
        # T5 has no [CLS] token, so mean-pool the hidden states over the
        # non-padding positions instead of taking the first token.
        mask = inputs["attention_mask"].unsqueeze(-1)
        embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
        return embeddings

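# Maps a pooled text embedding to a 64-dim conditioning vector. Despite the
# name, this is a plain MLP rather than an actual diffusion model.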
class ConditionalDiffusionModel(nn.Module):
    def __init__(self):
        super(ConditionalDiffusionModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 768),  # t5-small embeddings are 512-dim
            nn.ReLU(),
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

    def forward(self, text_embeddings):
        return self.model(text_embeddings)

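# Despite the name, this network keeps the 64x64 resolution and refines the
# image with a small conv stack; a real super-resolution stage would also
# upsample between passes.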
class SuperResolutionDiffusionModel(nn.Module):
    def __init__(self):
        super(SuperResolutionDiffusionModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 3 is the number of color channels
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, input_image):
        return self.model(input_image)

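# End-to-end pipeline: encode the caption, derive a spatial conditioning map,
# and refine a conditioned noise image into the final 64x64 output.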
class TextToImageModel(nn.Module):
    def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
        super(TextToImageModel, self).__init__()
        self.text_encoder = text_encoder
        self.conditional_diffusion_model = conditional_diffusion_model
        self.super_resolution_diffusion_model = super_resolution_diffusion_model

    def forward(self, text):
        text_embeddings = self.text_encoder.encode_text(text)
        conditioning = self.conditional_diffusion_model(text_embeddings)  # (B, 64)
        batch_size = conditioning.size(0)
        # Broadcast the conditioning vector over an 8x8 grid, upsample it to
        # 64x64, and add it to random noise so the generated image depends
        # on the input text.
        conditioning = nn.functional.interpolate(
            conditioning.view(batch_size, 1, 8, 8), size=(64, 64), mode="nearest")
        input_image = torch.rand((batch_size, 3, 64, 64), device=conditioning.device) + conditioning
        for _ in range(6):  # six refinement passes over the image
            input_image = self.super_resolution_diffusion_model(input_image)
        return input_image

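# Loads (caption, image) pairs. Each line of the annotations file holds an
# image filename and its caption, separated by the first space, e.g.:
#     cat_001.jpg A small cat sitting on a mat.
# (The filename above is illustrative.)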
class CustomDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        with open(annotations_file, 'r') as f:
            lines = f.readlines()
        self.img_labels = [line.strip().split(' ', 1) for line in lines]
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, text = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return text, image

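# Checkpoint helpers: persist model/optimizer state together with the epoch
# counter so interrupted training can resume where it left off.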
def save_checkpoint(model, optimizer, epoch, checkpoint_path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, checkpoint_path)

def load_checkpoint(model, optimizer, checkpoint_path):
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        return epoch
    else:
        return 0

def test_inference(model, text):
    model.eval()
    with torch.no_grad():
        generated_image = model(text)
    return generated_image

def visualize_image(image_tensor):
    image_tensor = image_tensor.squeeze(0).cpu().detach()
    image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min() + 1e-8)  # Normalize to [0, 1]; epsilon guards constant images
    image_tensor = image_tensor.permute(1, 2, 0)  # Change from (C, H, W) to (H, W, C)
    plt.imshow(image_tensor)
    plt.show()

if __name__ == "__main__":
    # Define hyperparameters and paths
    batch_size = 4
    learning_rate = 1e-4
    num_epochs = 1000
    checkpoint_path = 'checkpoint.pth'
    annotations_file = 'annotations.txt'
    img_dir = 'images/'
    
    # Initialize models
    text_encoder = TextEncoder("google-t5/t5-small")
    conditional_diffusion_model = ConditionalDiffusionModel()
    super_resolution_diffusion_model = SuperResolutionDiffusionModel()
    text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
    text_to_image_model.to(device)

    # Define optimizer and criterion
    optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # Load checkpoint if available
    start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)

    # Define transformations for the images
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # Initialize dataset and dataloader
    dataset = CustomDataset(annotations_file, img_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    text_to_image_model.train()
    for epoch in range(start_epoch, num_epochs):
        for text_batch, image_batch in dataloader:
            optimizer.zero_grad()
            images = text_to_image_model(text_batch)
            target_images = image_batch.to(device)
            loss = criterion(images, target_images)
            loss.backward()
            optimizer.step()
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
        save_checkpoint(text_to_image_model, optimizer, epoch+1, checkpoint_path)

    print("Training completed.")
    
    # Test inference
    sample_text = "A big ape."
    generated_image = test_inference(text_to_image_model, sample_text)
    visualize_image(generated_image)