import time
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import BertTokenizer, BertModel

# Dimensionality of the VAE latent vector and of the text-conditioning embedding.
LATENT_DIM = 128
HIDDEN_DIM = 256


class TextEncoder(nn.Module):
    """Projects BERT's [CLS] embedding down to a fixed-size conditioning vector."""

    def __init__(self, output_size):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [CLS] token's hidden state serves as a sentence-level summary.
        return self.fc(outputs.last_hidden_state[:, 0, :])


class CVAE(nn.Module):
    """Conditional VAE over 16x16 RGBA images, conditioned on a text embedding."""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

        # Encoder: (B, 4, 16, 16) -> (B, 32, 16, 16) -> (B, 64, 8, 8)
        # -> (B, 128, 4, 4) -> flatten -> (B, HIDDEN_DIM)
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, HIDDEN_DIM)
        )

        # Image features are concatenated with the text encoding before
        # projecting to the parameters of the latent distribution.
        self.fc_mu = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)

        # Decoder: mirror of the encoder, ending in Tanh so outputs lie in [-1, 1].
        self.decoder_input = nn.Linear(LATENT_DIM + HIDDEN_DIM, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 4, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        x = torch.cat([x, c], dim=1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def decode(self, z, c):
        z = torch.cat([z, c], dim=1)
        x = self.decoder_input(z)
        x = x.view(-1, 128, 4, 4)
        return self.decoder(x)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar
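

# For reference: forward() returns (reconstruction, mu, logvar), the standard
# CVAE training interface. The training code is not part of this file, so the
# loss below is a minimal sketch of the usual ELBO objective (reconstruction
# term plus KL divergence), not the author's actual training loss; `kl_weight`
# is an assumed hyperparameter.
def cvae_loss(recon: torch.Tensor, target: torch.Tensor, mu: torch.Tensor,
              logvar: torch.Tensor, kl_weight: float = 1.0) -> torch.Tensor:
    recon_loss = nn.functional.mse_loss(recon, target, reduction="sum")
    # KL divergence between N(mu, sigma^2) and the standard normal prior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kl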


# Tokenizer used to turn prompts into BERT input ids.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def clean_image(image: Image.Image, threshold: float = 0.75) -> Image.Image:
    """Binarize the alpha channel of an RGBA image: pixels at or below the
    threshold become fully transparent, all others fully opaque."""
    np_image = np.array(image)
    alpha_channel = np_image[:, :, 3]
    cutoff = int(threshold * 255)
    alpha_channel[alpha_channel <= cutoff] = 0
    alpha_channel[alpha_channel > cutoff] = 255
    return Image.fromarray(np_image)


def generate_image(
    model: CVAE,
    text_prompt: str,
    device: torch.device,
    input_image: Optional[Image.Image] = None,
    img_control: float = 0.5
) -> Image.Image:
    encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = encoded_input['input_ids'].to(device)
    attention_mask = encoded_input['attention_mask'].to(device)

    with torch.no_grad():
        text_encoding = model.text_encoder(input_ids, attention_mask)
        # Sample a latent vector and decode it conditioned on the text encoding.
        z = torch.randn(1, LATENT_DIM, device=device)
        generated_image = model.decode(z, text_encoding)

    if input_image is not None:
        input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
        # ToTensor yields values in [0, 1]; rescale to the decoder's [-1, 1]
        # Tanh range so the blend mixes like with like.
        input_image = input_image * 2 - 1
        generated_image = img_control * input_image + (1 - img_control) * generated_image

    # Map from [-1, 1] back to [0, 1] before converting to a PIL image.
    generated_image = generated_image.squeeze(0).cpu()
    generated_image = ((generated_image + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated_image)


# Cache loaded models so repeated requests don't re-read the checkpoint.
_model_cache = {}


def load_model(model_path: str, device: torch.device) -> CVAE:
    if model_path not in _model_cache:
        text_encoder = TextEncoder(output_size=HIDDEN_DIM)
        model = CVAE(text_encoder).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        _model_cache[model_path] = model
    return _model_cache[model_path]
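

# Programmatic usage (illustrative; assumes a trained checkpoint exists at the
# given path, and the prompt is an example):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = load_model("BitRoss.pth", device)
#   image = generate_image(model, "a red potion bottle", device)
#   image.save("out.png")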


def generate_image_gradio(
    prompt: str,
    model_path: str,
    clean_image_flag: bool,
    size: int,
    input_image: Optional[Image.Image] = None,
    img_control: float = 0.5
) -> tuple[Image.Image, str]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        model = load_model(model_path, device)
    except Exception as e:
        raise gr.Error(f"Failed to load model: {e}")

    start_time = time.time()
    try:
        generated_image = generate_image(model, prompt, device, input_image, img_control)
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {e}")
    generation_time = time.time() - start_time

    if clean_image_flag:
        generated_image = clean_image(generated_image)

    try:
        # NEAREST resampling keeps the pixel-art look crisp when upscaling.
        generated_image = generated_image.resize((size, size), resample=Image.NEAREST)
    except Exception as e:
        raise gr.Error(f"Failed to resize image: {e}")

    return generated_image, f"Generation time: {generation_time:.4f} seconds"


def gradio_interface() -> gr.Blocks:
    with gr.Blocks() as demo:
        gr.Markdown("# Image Generator from Text Prompt")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Text Prompt")
                model_path = gr.Textbox(label="Model Path", value="BitRoss.pth")
                clean_image_flag = gr.Checkbox(label="Clean Image", value=False)
                size = gr.Slider(minimum=16, maximum=1024, step=16, label="Image Size", value=16)
                img_control = gr.Slider(minimum=0, maximum=1, step=0.1, label="Image Control", value=0.5)
                input_image = gr.Image(label="Input Image (optional)", type="pil")
                generate_button = gr.Button("Generate Image")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                generation_time = gr.Textbox(label="Generation Time")

        generate_button.click(
            fn=generate_image_gradio,
            inputs=[prompt, model_path, clean_image_flag, size, input_image, img_control],
            outputs=[output_image, generation_time],
            api_name="generate"
        )

    return demo
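

# Because the click handler sets api_name="generate", the endpoint is also
# callable over HTTP. Illustrative client-side sketch (argument values are
# examples; requires the gradio_client package):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   image_path, timing = client.predict(
#       "a red potion bottle", "BitRoss.pth", False, 64, None, 0.5,
#       api_name="/generate"
#   )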


if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )