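# Scraped from a Hugging Face file view; page metadata kept below as comments.
# Commit author: callum-canavan
# Latest commit: "Update app filename" (45c0347)
# Hub scan status: no virus; file size: 2.02 kB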
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import gradio as gr
import torch
import numpy as np
# Stage 1: first pipeline of the cascade, loaded from the fp16 weight variant
# in half precision. Its prompt encoder is also what predict() uses to build
# the prompt embeddings.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
# NOTE: remove the xformers line below if torch.__version__ >= 2.0.0.
stage_1.enable_xformers_memory_efficient_attention()
stage_1.enable_model_cpu_offload()
# Stage 2: second pipeline of the cascade. It is loaded without a text
# encoder (text_encoder=None); predict() feeds it the prompt embeddings
# produced by stage_1 instead.
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    text_encoder=None,
)
# NOTE: remove the xformers line below if torch.__version__ >= 2.0.0.
stage_2.enable_xformers_memory_efficient_attention()
stage_2.enable_model_cpu_offload()
# Stage 3: the x4 upscaler pipeline. It reuses the feature extractor, safety
# checker and watermarker objects already loaded on stage_1 rather than
# loading its own.
safety_modules = {
    component: getattr(stage_1, component)
    for component in ("feature_extractor", "safety_checker", "watermarker")
}
stage_3 = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    torch_dtype=torch.float16,
    **safety_modules,
)
# NOTE: remove the xformers line below if torch.__version__ >= 2.0.0.
stage_3.enable_xformers_memory_efficient_attention()
stage_3.enable_model_cpu_offload()
def predict(prompt):
    """Generate one image for *prompt* by chaining the three pipeline stages.

    The prompt is encoded once with stage_1; the resulting embeddings drive
    stage_1 and stage_2, whose tensor outputs ("pt") are handed on to the
    next stage. stage_3 receives the raw prompt text together with the
    stage_2 output.

    Args:
        prompt: Text describing the desired image (supplied by the Gradio
            "text" input).

    Returns:
        The first entry of stage_3's ``images`` output.
    """
    cond_embeds, uncond_embeds = stage_1.encode_prompt(prompt)

    # Seed is fixed at 0 on every call; the same generator object is then
    # threaded through all three stages.
    rng = torch.manual_seed(0)

    # Arguments common to the two embedding-driven stages.
    shared_kwargs = {
        "prompt_embeds": cond_embeds,
        "negative_prompt_embeds": uncond_embeds,
        "generator": rng,
        "output_type": "pt",
    }

    first_pass = stage_1(**shared_kwargs).images
    second_pass = stage_2(image=first_pass, **shared_kwargs).images

    upscaled = stage_3(
        prompt=prompt,
        image=second_pass,
        generator=rng,
        noise_level=100,
    )
    return upscaled.images[0]
# Web UI: a single text box in, a single image out, backed by predict().
gradio_app = gr.Interface(
    predict,
    inputs="text",
    outputs="image",
    description="Enter a text string to generate an image.",
    title="Text to Image Generator",
)
if __name__ == "__main__":
    # Serve on 0.0.0.0 (all network interfaces) instead of Gradio's default
    # host, so the app is reachable from other machines.
    gradio_app.launch(server_name="0.0.0.0")