lazerkat commited on
Commit
3338afe
·
verified ·
1 Parent(s): 7ed64ab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ import urllib.request
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ # Global variables
12
+ model = None
13
+ checkpoint = None
14
+ device = None
15
+
16
+ # Download and load the model
17
+ def initialize_model():
18
+ global model, checkpoint, device
19
+
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ model_url = "https://huggingface.co/lazerkat/randomdiffusion/resolve/main/newest.pth"
23
+ model_path = "newest.pth"
24
+
25
+ # Download if not already present
26
+ if not os.path.exists(model_path):
27
+ gr.Info("Downloading model...")
28
+ urllib.request.urlretrieve(model_url, model_path)
29
+
30
+ # Load checkpoint
31
+ checkpoint = torch.load(model_path, map_location=device)
32
+
33
+ # Recreate the model architecture
34
+ from train import DiffusionUNet # Import directly from training script
35
+ model = DiffusionUNet(vocab_size=checkpoint['vocab_size']).to(device)
36
+ model.load_state_dict(checkpoint['model_state_dict'])
37
+ model.eval()
38
+
39
+ return "Model loaded successfully!"
40
+
41
+ # Generate image from prompt
42
+ def generate_image(prompt):
43
+ global model, checkpoint, device
44
+
45
+ if model is None:
46
+ return None, "Model not loaded yet. Please wait for initialization."
47
+
48
+ # Tokenize prompt using the saved vocab
49
+ vocab_data = checkpoint['word_to_idx']
50
+ max_len = 20
51
+ words = [w.strip('.,!?"\'') for w in prompt.lower().split()][:max_len]
52
+ indices = [vocab_data.get(w, 1) for w in words]
53
+ indices += [0] * (max_len - len(indices))
54
+ text_tokens = torch.tensor(indices).unsqueeze(0).to(device)
55
+
56
+ # Diffusion sampling
57
+ from train import Diffusion
58
+ diffusion = Diffusion(timesteps=500, device=device)
59
+
60
+ with torch.no_grad():
61
+ generated = diffusion.sample(model, text_tokens, image_size=64, batch_size=1)
62
+
63
+ # Convert to PIL image
64
+ image = generated.cpu().squeeze(0)
65
+ image = (image + 1) / 2
66
+ image = image.clamp(0, 1)
67
+ image = image.permute(1, 2, 0).numpy()
68
+ image = (image * 255).astype(np.uint8)
69
+ img = Image.fromarray(image)
70
+
71
+ return img, f"Generated image for: '{prompt}'"
72
+
73
+ # Create the interface
74
+ with gr.Blocks(title="RandomDiffusion", theme=gr.themes.Soft()) as demo:
75
+ gr.Markdown("# RandomDiffusion")
76
+ gr.Markdown("Text-to-Image Diffusion Model")
77
+
78
+ # Model status
79
+ status = gr.Textbox(label="Model Status", value="Initializing...", interactive=False)
80
+
81
+ # Image generation
82
+ with gr.Row():
83
+ with gr.Column():
84
+ prompt = gr.Textbox(label="Enter Prompt", placeholder="a beautiful landscape")
85
+ generate_btn = gr.Button("Generate")
86
+ with gr.Column():
87
+ output_image = gr.Image(label="Generated Image", type="pil")
88
+ result_text = gr.Textbox(label="Result")
89
+
90
+ # Load model on startup
91
+ demo.load(
92
+ lambda: initialize_model(),
93
+ inputs=[],
94
+ outputs=[status]
95
+ )
96
+
97
+ # Generate on button click
98
+ generate_btn.click(
99
+ generate_image,
100
+ inputs=[prompt],
101
+ outputs=[output_image, result_text]
102
+ )
103
+
104
+ if __name__ == "__main__":
105
+ demo.launch(share=True)