"""Gradio front-end for a text-conditioned VAE/GAN image generator.

Builds an ``ImageGenerator`` (VAE encoder, GAN decoder and a CLIP-based
latent refiner), loads any saved weights from ``storage/models`` next to
this file, and exposes text-to-image generation through a Gradio interface.
"""

import logging
import sys
from pathlib import Path

import gradio as gr
import torch

# Add current directory to Python path so the local ``models`` package is
# importable regardless of the working directory the app is launched from.
sys.path.insert(0, str(Path(__file__).parent))

# Local imports (must come after the sys.path tweak above).
from models.vae_gan import VAEEncoder, GANDecoder  # noqa: E402
from models.clip_feedback import CLIPFeedback  # noqa: E402

logger = logging.getLogger(__name__)

# Directory holding the saved ``encoder.pth`` / ``decoder.pth`` state dicts.
WEIGHTS_DIR = Path(__file__).parent / 'storage' / 'models'

# Length of the random latent vector fed to the decoder.
LATENT_DIM = 512


class ImageGenerator:
    """Generate images from text by refining a random latent with CLIP feedback.

    Args:
        device: torch device string the models are moved to (default ``'cpu'``).
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.encoder = VAEEncoder().to(device)
        self.decoder = GANDecoder().to(device)
        self.clip_fb = CLIPFeedback(device)
        self.load_weights()
        # This class only runs inference: switch any dropout / batch-norm
        # layers the models may contain to evaluation behaviour. eval() does
        # not disable autograd, so latent refinement can still backprop.
        self.encoder.eval()
        self.decoder.eval()

    def load_weights(self):
        """Load saved encoder and decoder state dicts from ``WEIGHTS_DIR``.

        Each checkpoint is checked independently. A missing file leaves that
        model with its freshly initialised parameters and logs a warning,
        rather than crashing (decoder) or being silently skipped (encoder).
        """
        self._load_state(self.encoder, WEIGHTS_DIR / 'encoder.pth')
        self._load_state(self.decoder, WEIGHTS_DIR / 'decoder.pth')

    def _load_state(self, module, path):
        """Load the state dict at *path* into *module*; return True if loaded."""
        if not path.exists():
            logger.warning("No weights found at %s; using untrained parameters", path)
            return False
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources; consider
        # weights_only=True if the installed torch version supports it.
        state = torch.load(path, map_location=self.device)
        module.load_state_dict(state)
        return True

    def generate_image(self, text, iterations=3):
        """Generate an image for the prompt *text*.

        Args:
            text: prompt passed to ``CLIPFeedback.refine_latent``.
            iterations: number of refinement passes applied to the latent.

        Returns:
            The first (only) item of the decoder's output batch as a numpy
            array, with dimensions permuted from (C, H, W) to (H, W, C).
            Assumes the decoder returns (batch, channels, H, W); TODO confirm.
        """
        z = torch.randn(1, LATENT_DIM).to(self.device)
        for _ in range(iterations):
            z = self.clip_fb.refine_latent(z, self.decoder, text)
        with torch.no_grad():
            image = self.decoder(z)
        # Drop only the batch dimension; ``squeeze()`` would also remove a
        # size-1 channel/spatial dimension and make the permute fail.
        return image[0].cpu().permute(1, 2, 0).numpy()


# Built at import time so the Gradio app (and tools that import
# ``interface``) get a ready-to-use generator.
generator = ImageGenerator()

interface = gr.Interface(
    fn=generator.generate_image,
    inputs=gr.Textbox(label="Image Description"),
    outputs=gr.Image(label="Generated Image"),
    title="Multimodal AI Image Generator",
)

if __name__ == "__main__":
    interface.launch()