# projectai / app.py
# Matthew Frazer
# Update app.py
# 76c04c7 verified
import sys
import torch
import gradio as gr
from pathlib import Path
# Add current directory to Python path
sys.path.insert(0, str(Path(__file__).parent))
# Local imports
from models.vae_gan import VAEEncoder, GANDecoder
from models.clip_feedback import CLIPFeedback
class ImageGenerator:
    """Generate an image from a text prompt by refining a random latent.

    Holds a ``VAEEncoder`` / ``GANDecoder`` pair and a ``CLIPFeedback`` helper
    from the local ``models`` package. Only the decoder and the feedback
    helper are used by :meth:`generate_image`. Saved weights are loaded from
    ``storage/models`` next to this file when the checkpoint files exist.
    """

    def __init__(self, device='cpu'):
        """Build the networks on *device* and load any saved checkpoints.

        Args:
            device: torch device (e.g. ``'cpu'``) that the models and the
                latent vectors are placed on.
        """
        self.device = device
        self.encoder = VAEEncoder().to(device)
        self.decoder = GANDecoder().to(device)
        self.clip_fb = CLIPFeedback(device)
        self.load_weights()
        # This class only runs inference, so switch the networks to eval mode.
        # Any dropout / batch-norm layers (if the models contain them) then
        # behave deterministically and accept a batch of one.
        self.encoder.eval()
        self.decoder.eval()

    def load_weights(self):
        """Load encoder/decoder state dicts from ``storage/models`` if present.

        Each checkpoint is handled independently. A missing file leaves that
        network at its initial weights instead of raising, and the decoder
        (the network used for generation) no longer depends on
        ``encoder.pth`` existing in order to be loaded.
        """
        model_dir = Path(__file__).parent / 'storage' / 'models'
        checkpoints = (
            (self.encoder, model_dir / 'encoder.pth'),
            (self.decoder, model_dir / 'decoder.pth'),
        )
        for module, path in checkpoints:
            if path.exists():
                # NOTE: torch.load unpickles the file; only ship trusted checkpoints.
                module.load_state_dict(torch.load(path, map_location=self.device))

    def generate_image(self, text, iterations=3):
        """Return an image for *text* as an ``(H, W, C)`` numpy array.

        A random 512-dim latent is passed through ``CLIPFeedback.refine_latent``
        *iterations* times (presumably steering it toward *text*; confirm in
        ``models.clip_feedback``). The refined latent is then decoded without
        gradient tracking.

        Args:
            text: prompt forwarded to ``refine_latent``.
            iterations: number of refinement steps (0 decodes the raw latent).

        Returns:
            The decoded image with channels last. This assumes the decoder
            returns ``(1, C, H, W)``. NOTE(review): the pixel range is whatever
            ``GANDecoder`` emits; confirm it matches what ``gr.Image`` expects.
        """
        z = torch.randn(1, 512, device=self.device)
        for _ in range(iterations):
            z = self.clip_fb.refine_latent(z, self.decoder, text)
        with torch.no_grad():
            image = self.decoder(z)
        # Drop only the batch dimension. A bare squeeze() would also remove a
        # size-1 channel dim and make permute(1, 2, 0) fail on grayscale output.
        return image.squeeze(0).cpu().permute(1, 2, 0).numpy()
# Build the generator once at import time (its constructor creates the
# networks and loads any saved weights) and give its bound method to Gradio
# as the callback: one text box in, one image out.
_generator = ImageGenerator()
interface = gr.Interface(
    fn=_generator.generate_image,
    inputs=gr.Textbox(label="Image Description"),
    outputs=gr.Image(label="Generated Image"),
    title="Multimodal AI Image Generator",
)

if __name__ == "__main__":
    # Start the web UI only when this file is executed directly.
    interface.launch()