# imagin_1B / model.py
# Uploaded by Khelendramee ("Create model.py", commit 288bf40, verified).
import torch
import torch.nn as nn
import cv2
import os
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from datetime import datetime
from scipy.fftpack import dct
# Shared preprocessing pipeline: resize to the network's 256x256 input size,
# convert to a tensor, then scale pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
class UNetGenerator(nn.Module):
    """3-level U-Net generator for pix2pix-style image-to-image translation.

    Maps (N, in_channels, H, W) -> (N, out_channels, H, W); H and W must be
    divisible by 8 because of the three stride-2 encoder stages.
    """

    def __init__(self, in_channels, out_channels):
        super(UNetGenerator, self).__init__()
        # Encoder: each stage halves the spatial resolution.
        self.down1 = self.conv_block(in_channels, 64, down=True)
        self.down2 = self.conv_block(64, 128, down=True)
        self.down3 = self.conv_block(128, 256, down=True)
        # Bottleneck keeps the resolution (3x3 conv, stride 1).
        self.bottleneck = self.conv_block(256, 512)
        # Decoder: each stage doubles the resolution; the input channel count
        # includes the concatenated skip connection from the encoder.
        self.up3 = self.conv_block(512 + 256, 256, up=True)
        self.up2 = self.conv_block(256 + 128, 128, up=True)
        self.up1 = self.conv_block(128 + 64, 64, up=True)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        # Fix: bound the output to [-1, 1] so it matches the Normalize((0.5,)*3,
        # (0.5,)*3) preprocessing and the `* 0.5 + 0.5` de-normalization used at
        # generation time. Without it the output range was unbounded. nn.Tanh has
        # no parameters, so saved state_dicts remain loadable.
        self.output_activation = nn.Tanh()

    def conv_block(self, in_ch, out_ch, down=False, up=False):
        """Build one Conv/ConvTranspose + BatchNorm + ReLU stage.

        down=True  -> stride-2 conv (halves H and W)
        up=True    -> stride-2 transposed conv (doubles H and W)
        neither    -> stride-1 3x3 conv (same resolution)
        """
        layers = []
        if down:
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
        elif up:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False))
        else:
            layers.append(nn.Conv2d(in_ch, out_ch, 3, 1, 1, bias=False))
        layers.extend([
            nn.BatchNorm2d(out_ch),  # bias=False above because BN has its own shift
            nn.ReLU(inplace=True)
        ])
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        # Bottleneck (same resolution as d3)
        bottleneck = self.bottleneck(d3)
        # Decoder with skip connections concatenated along the channel axis.
        u3 = self.up3(torch.cat([bottleneck, d3], dim=1))
        u2 = self.up2(torch.cat([u3, d2], dim=1))
        u1 = self.up1(torch.cat([u2, d1], dim=1))
        return self.output_activation(self.final_conv(u1))
class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator.

    Scores the (input, candidate) image pair, concatenated along the channel
    axis, and emits a map of per-patch logits (no sigmoid — pair with
    BCEWithLogitsLoss).
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        widths = [64, 128, 256, 512]
        stages = []
        prev = in_channels * 2  # input and candidate are stacked channel-wise
        for idx, w in enumerate(widths):
            # The very first stage skips BatchNorm, as is conventional.
            stages.append(self.conv_block(prev, w, norm=idx > 0))
            prev = w
        # Final 1-channel projection producing the patch logit map.
        stages.append(nn.Conv2d(prev, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*stages)

    def conv_block(self, in_ch, out_ch, norm=True):
        """One stride-2 Conv(-BN)-LeakyReLU stage that halves H and W."""
        block = [nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False)]
        if norm:
            block.append(nn.BatchNorm2d(out_ch))
        block.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*block)

    def forward(self, x, y):
        pair = torch.cat([x, y], dim=1)
        return self.model(pair)
# Instantiate the RGB-to-RGB generator and the conditional discriminator.
generator = UNetGenerator(3, 3)
discriminator = Discriminator(3)

# GAN objective on raw logits, plus a pixel-wise reconstruction objective.
adversarial_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

# Adam with the usual pix2pix hyper-parameters (lr=2e-4, beta1=0.5).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
# Training loop (simplified)
def train(input_image, target_image, i):
    """Run one combined optimization step on a single batch.

    Updates the discriminator first, then the generator, using the
    module-level models, losses and optimizers.

    Returns (i, discriminator_loss, generator_loss) as plain Python values.
    """
    # ---- Discriminator step ----
    optimizer_D.zero_grad()
    pred_real = discriminator(input_image, target_image)
    generated = generator(input_image)
    # detach() so the discriminator step does not backprop into the generator.
    pred_fake = discriminator(input_image, generated.detach())
    loss_fake = adversarial_loss(pred_fake, torch.zeros_like(pred_fake))
    loss_real = adversarial_loss(pred_real, torch.ones_like(pred_real))
    d_loss = (loss_fake + loss_real) / 2
    d_loss.backward()
    optimizer_D.step()

    # ---- Generator step ----
    optimizer_G.zero_grad()
    pred_for_g = discriminator(input_image, generated)
    # Adversarial term plus heavily-weighted L1 reconstruction term.
    g_loss = (adversarial_loss(pred_for_g, torch.ones_like(pred_for_g))
              + 100 * l1_loss(generated, target_image))
    g_loss.backward()
    optimizer_G.step()

    return i, d_loss.item(), g_loss.item()
def model_save():
    """Persist the generator and discriminator weights under ./model/.

    Creates the directory on first use; existing weight files are overwritten.
    """
    model_dir = os.path.join('.', 'model')
    # exist_ok avoids the check-then-create race of the original isdir() test.
    os.makedirs(model_dir, exist_ok=True)
    final_generator_path = os.path.join(model_dir, 'generator.pth')
    final_discriminator_path = os.path.join(model_dir, 'discriminator.pth')
    torch.save(generator.state_dict(), final_generator_path)
    torch.save(discriminator.state_dict(), final_discriminator_path)
    print("Final model saved")
def extract_frames(video_path, output_dir):
    """Decode a video and save every frame as a 256x256 JPEG in output_dir.

    Frames are named frame_00000.jpg, frame_00001.jpg, ... in decode order.
    The output directory is created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream (or unreadable file: zero frames are written).
                break
            # Resize to the model's fixed input resolution before saving.
            frame = cv2.resize(frame, (256, 256))
            frame_filename = os.path.join(output_dir, f"frame_{frame_count:05d}.jpg")
            cv2.imwrite(frame_filename, frame)
            frame_count += 1
    finally:
        # Fix: release the capture even if resize/imwrite raises — the
        # original leaked the handle on any exception inside the loop.
        cap.release()
class Pix2PixDataset(Dataset):
    """Paired-image dataset for pix2pix training.

    Expects root_dir/input and root_dir/target directories containing files
    with identical names; item i is the (input, target) pair for the i-th
    filename.
    """

    def __init__(self, root_dir, transform=None):
        # `transform`, when given, is applied identically to both images.
        self.root_dir = root_dir
        self.transform = transform
        self.input_dir = os.path.join(root_dir, 'input')
        self.target_dir = os.path.join(root_dir, 'target')
        # Sorted for a deterministic index -> file mapping; os.listdir order
        # is filesystem-dependent.
        self.image_files = sorted(os.listdir(self.input_dir))

    def __len__(self):
        return len(self.image_files)

    # Fix: the original defined __getitem__ twice; the first definition was
    # dead code silently shadowed by the second. Only one is kept.
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        input_path = os.path.join(self.input_dir, img_name)
        target_path = os.path.join(self.target_dir, img_name)
        input_image = Image.open(input_path).convert('RGB')
        target_image = Image.open(target_path).convert('RGB')
        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)
        # Return both the input and the target image.
        return input_image, target_image
def model_train(epochs):
    """Train the GAN for `epochs` passes over the paired dataset.

    Reads prompts from text/input.json (a JSON list; entry i is paired with
    target frame_{i:05d}.jpg), renders each prompt to a deterministic DCT
    noise image under data/input/, then runs the pix2pix training loop.
    """
    # Read the prompt list.
    with open('text/input.json', 'r') as file:
        data = json.load(file)
    # Render each prompt to its input image; ensure the directory exists so
    # this works even before model_dataset() has been run.
    os.makedirs('data/input', exist_ok=True)
    for i in range(len(data)):
        text_to_noise(data[i]).save(f'data/input/frame_{i:05d}.jpg')
    # Load the paired dataset.
    dataset = Pix2PixDataset(root_dir='data/', transform=transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    # Training loop.
    for epoch in range(epochs):
        for step, (input_image, output_image) in enumerate(dataloader):
            # train() echoes the step index back alongside the two losses.
            _, d_loss, g_loss = train(input_image, output_image, step)
            print(f"Epoch [{epoch + 1}/{epochs}], Step [{step + 1}], D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}")
def model_dataset(video_path):
    """Build the training directory layout from a video.

    Extracts every frame into data/input and data/target (model_train later
    overwrites data/input with text-derived noise images).
    """
    # exist_ok avoids the check-then-create race of the original isdir() test.
    os.makedirs(os.path.join('.', 'data'), exist_ok=True)
    extract_frames(video_path, 'data/input')
    extract_frames(video_path, 'data/target')
def text_to_noise(text):
    """Encode `text` as a deterministic 256x256 RGB "noise" image.

    Character ordinals are laid out on a 256x256 grid (zero-padded), passed
    through a separable 2D DCT, and the coefficient plane is replicated over
    three channels.

    NOTE(review): coefficients are scaled by 255 and cast straight to uint8,
    so large/negative DCT values wrap around. That matches the original
    behavior (and whatever distribution a trained model saw), so it is
    deliberately kept.
    """
    size = 256 * 256
    # Fix: truncate overlong text — the original crashed in reshape for any
    # input longer than 65536 characters.
    numerical_values = [ord(char) for char in text[:size]]
    numerical_values += [0] * (size - len(numerical_values))
    # Separable 2D DCT: transform columns, then rows.
    dct_coefficients = np.reshape(numerical_values, (256, 256))
    dct_coefficients = dct(dct_coefficients, axis=0)
    dct_coefficients = dct(dct_coefficients, axis=1)
    # Replicate the single coefficient plane across 3 channels.
    noise_image = np.repeat(dct_coefficients[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray((noise_image * 255).astype(np.uint8))
def model_generate(text):
    """Generate one image from `text` with the trained generator.

    Loads model/generator.pth (so model_save() must have run first), feeds
    the text-derived noise image through the generator, and saves the result
    under ./generated/ with a timestamp filename.
    """
    # map_location keeps loading working on CPU-only hosts even if the
    # weights were saved from a GPU run.
    generator.load_state_dict(torch.load('model/generator.pth', map_location='cpu'))
    generator.eval()  # switch to evaluation mode (BatchNorm uses running stats)
    input_image = text_to_noise(text)
    input_image = input_image.convert('RGB')
    input_tensor = transform(input_image).unsqueeze(0)  # add batch dimension
    # Generate the image without tracking gradients.
    with torch.no_grad():
        output_tensor = generator(input_tensor)
    # De-normalize from [-1, 1] back to [0, 1] and convert to a PIL image.
    output_image = transforms.ToPILImage()(output_tensor.squeeze(0) * 0.5 + 0.5)
    os.makedirs(os.path.join('.', 'generated'), exist_ok=True)
    # Fix: compute the filename once — the original called datetime.now()
    # twice (and had a stray "f'" inside the message), so the printed name
    # never matched the file actually written.
    filename = f'{datetime.now()}.jpg'
    output_image.save(os.path.join('generated', filename))
    print(f"जनरेटेड इमेज सेव की गई: {filename}")