File size: 1,566 Bytes
d394cfe
76c04c7
 
d394cfe
76c04c7
 
d394cfe
 
76c04c7
 
 
d39cf5c
 
76c04c7
d39cf5c
 
 
 
 
 
 
76c04c7
 
 
 
 
 
 
d394cfe
d39cf5c
 
 
 
 
 
 
 
76c04c7
 
 
 
 
 
d39cf5c
 
76c04c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import sys
import torch
import gradio as gr
from pathlib import Path

# Add current directory to Python path
sys.path.insert(0, str(Path(__file__).parent))

# Local imports
from models.vae_gan import VAEEncoder, GANDecoder
from models.clip_feedback import CLIPFeedback

class ImageGenerator:
    """Generate an image from a text prompt by refining a random latent vector.

    Holds a ``VAEEncoder``, a ``GANDecoder`` and a ``CLIPFeedback`` helper.
    Only the decoder and the CLIP helper are used by ``generate_image``; the
    encoder is constructed and has its weights loaded but is not otherwise
    referenced in this class.
    """

    def __init__(self, device='cpu'):
        """Build the networks on *device* and load any saved weights.

        Args:
            device: Device identifier handed to ``.to()`` for the encoder and
                decoder, and to the ``CLIPFeedback`` constructor.
        """
        self.device = device
        self.encoder = VAEEncoder().to(device)
        self.decoder = GANDecoder().to(device)
        self.clip_fb = CLIPFeedback(device)
        self.load_weights()
        # This class is only used for inference, so put the networks in
        # evaluation mode (affects layers such as batch-norm and dropout;
        # gradients w.r.t. the latent still flow in eval mode).
        # NOTE(review): assumes both are torch.nn.Module subclasses, as their
        # use of .to() / .load_state_dict() suggests.
        self.encoder.eval()
        self.decoder.eval()

    def load_weights(self):
        """Load saved state dicts from ``storage/models`` next to this file.

        The encoder and decoder checkpoints are looked up independently, so a
        missing file only leaves that one network with its current weights
        instead of crashing or skipping the other. A warning is emitted for
        each checkpoint that is not found.
        """
        import warnings

        models_dir = Path(__file__).parent / 'storage' / 'models'
        checkpoints = (
            ('encoder.pth', self.encoder),
            ('decoder.pth', self.decoder),
        )
        for filename, module in checkpoints:
            weights_path = models_dir / filename
            if not weights_path.exists():
                warnings.warn(
                    f"Weights file not found: {weights_path}; "
                    "keeping the module's current weights.",
                    stacklevel=2,
                )
                continue
            state = torch.load(weights_path, map_location=self.device)
            module.load_state_dict(state)

    def generate_image(self, text, iterations=3):
        """Return an image array for *text*.

        A random latent of shape ``(1, 512)`` is passed through
        ``clip_fb.refine_latent`` *iterations* times and then decoded.

        Args:
            text: Prompt forwarded unchanged to ``refine_latent``.
            iterations: Number of refinement passes; ``0`` decodes the
                unrefined random latent.

        Returns:
            A NumPy array with the decoder output's batch axis removed and
            the remaining axes reordered with ``permute(1, 2, 0)``.
            Assumes the decoder returns a ``(1, C, H, W)`` tensor, giving
            ``(H, W, C)`` -- TODO confirm against ``GANDecoder``.
        """
        z = torch.randn(1, 512).to(self.device)
        for _ in range(iterations):
            z = self.clip_fb.refine_latent(z, self.decoder, text)
        with torch.no_grad():
            image = self.decoder(z)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop any other size-1 axis and make the permute below fail.
        return image.cpu().squeeze(0).permute(1, 2, 0).numpy()

# The model is built once, at import time, and its bound method is what the
# UI calls for every request.
_generator = ImageGenerator()

# UI components: one text box in, one image out.
_prompt_box = gr.Textbox(label="Image Description")
_image_view = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=_generator.generate_image,
    inputs=_prompt_box,
    outputs=_image_view,
    title="Multimodal AI Image Generator",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()