import sys
from pathlib import Path

import torch
import gradio as gr

# Make the local ``models`` package importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent))

from models.vae_gan import VAEEncoder, GANDecoder
from models.clip_feedback import CLIPFeedback


class ImageGenerator:
    """Generates images from text by decoding a latent vector refined with CLIP feedback."""

    def __init__(self, device='cpu'):
        self.device = device
        self.encoder = VAEEncoder().to(device)
        self.decoder = GANDecoder().to(device)
        self.clip_fb = CLIPFeedback(device)
        self.load_weights()

    def load_weights(self):
        base_path = Path(__file__).parent
        encoder_path = base_path / 'storage' / 'models' / 'encoder.pth'
        decoder_path = base_path / 'storage' / 'models' / 'decoder.pth'

        # Load each checkpoint only if it exists, so the app can still start
        # (with randomly initialised weights) when a checkpoint is missing.
        if encoder_path.exists():
            self.encoder.load_state_dict(torch.load(encoder_path, map_location=self.device))
        if decoder_path.exists():
            self.decoder.load_state_dict(torch.load(decoder_path, map_location=self.device))

    def generate_image(self, text, iterations=3):
        # Start from a random latent vector and refine it with CLIP feedback
        # so the decoded image better matches the text prompt.
        z = torch.randn(1, 512).to(self.device)
        for _ in range(iterations):
            z = self.clip_fb.refine_latent(z, self.decoder, text)

        # Decode the refined latent; convert the (1, C, H, W) tensor to an
        # H x W x C numpy array for Gradio's Image component.
        with torch.no_grad():
            image = self.decoder(z)
        return image.cpu().squeeze(0).permute(1, 2, 0).numpy()


# The generator runs on the CPU by default; pass device='cuda' to
# ImageGenerator() to use a GPU when one is available.
generator = ImageGenerator()

interface = gr.Interface(
    fn=generator.generate_image,
    inputs=gr.Textbox(label="Image Description"),
    outputs=gr.Image(label="Generated Image"),
    title="Multimodal AI Image Generator",
)

if __name__ == "__main__":
    interface.launch()
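
# Note: interface.launch() serves the app locally (http://127.0.0.1:7860 by
# default); pass share=True to launch() for a temporary public Gradio link.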