Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionPipeline | |
# Hugging Face access token, read from the environment (e.g. a Space secret
# named HF_TOKEN) instead of being committed to source. NOTE(review): the token
# that used to be hard-coded here is exposed in the repository history and
# should be revoked. When HF_TOKEN is unset this is None; how the download then
# authenticates depends on the installed huggingface_hub version — confirm.
auth_token = os.environ.get("HF_TOKEN")
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"  # device the pipeline is moved to and latents are sampled on

# Load the Stable Diffusion model: the "fp16" revision of the weights, loaded
# as float16 and moved to `device`.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=auth_token,
).to(device)

# --- Text-to-image inference ---
def infer(
    prompt,
    width=512,
    height=512,
    num_images=1,
    num_inference_steps=100,
    seed_number=-1,
    guidance_scale_value=10,
):
    """Generate images for *prompt* with the module-level ``pipe``.

    The generation settings used to be hard-coded locals. They are now keyword
    parameters whose defaults are those same values, so ``infer(prompt)``
    behaves as before. Gradio passes only the prompt.

    Args:
        prompt: Text description to render.
        width: Output width in pixels; the initial latents are ``width // 8`` wide.
        height: Output height in pixels; the initial latents are ``height // 8`` tall.
        num_images: How many images to generate (one pipeline call each).
            0 returns an empty list.
        num_inference_steps: Denoising steps passed to the pipeline.
        seed_number: Seed used for every image when >= 0; a negative value
            draws a fresh random seed for each image.
        guidance_scale_value: ``guidance_scale`` passed to the pipeline.

    Returns:
        list: One entry per image — the first element of the pipeline's
        ``'sample'`` output (presumably a PIL image; this depends on the
        installed diffusers version).
    """
    generator = torch.Generator(device=device)
    # torch.autocast expects a device *type* ("cuda"), not e.g. "cuda:0".
    autocast_type = torch.device(device).type
    images = []
    for _ in range(num_images):
        # Pick a seed (random or fixed), log it, and reset the generator to it
        # so the printed seed reproduces this image's starting noise.
        seed = generator.seed() if seed_number < 0 else seed_number
        print('seed=' + str(seed))
        generator.manual_seed(seed)
        latents = torch.randn(
            (1, pipe.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=device,
        )
        with torch.autocast(autocast_type):
            result = pipe(
                [prompt],
                width=width,
                height=height,
                guidance_scale=guidance_scale_value,
                num_inference_steps=num_inference_steps,
                latents=latents,
            )
        images.append(result['sample'][0])
    return images


# --- Chinese-text-to-image inference ---
import requests

# Google Apps Script web app that infer2 sends the user's prompt to (as the
# `test` query parameter). Its response body is used as the generation prompt —
# presumably a Chinese-to-English translation; confirm against the script.
_TRANSLATE_URL = 'https://script.google.com/macros/s/AKfycbyS9bk2G2jvJXzyFUN--nj0Lr8Zi8x_jdJOWnh_dRInxd8uko5KEZxwtoG-WcWdAnpa/exec'
# Seconds to wait for the Apps Script endpoint before giving up.
_TRANSLATE_TIMEOUT = 30


def infer2(prompt):
    """Generate images from a (Chinese) description by rewriting it first.

    The prompt is sent to the Apps Script endpoint in ``_TRANSLATE_URL``. The
    response text is then passed to :func:`infer`, which used to be duplicated
    here verbatim with the same generation settings.

    Args:
        prompt: Description text as typed by the user.

    Returns:
        list: Whatever ``infer`` returns for the rewritten prompt.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.RequestException: On network failure or timeout.
    """
    # Pass the prompt via `params` so requests URL-encodes it. Raw string
    # concatenation let '&', '#', '+' or '=' in the prompt split or truncate
    # the query string.
    response = requests.get(
        _TRANSLATE_URL,
        params={"test": prompt},
        timeout=_TRANSLATE_TIMEOUT,
    )
    # Fail loudly instead of feeding an error page to the model as a prompt.
    response.raise_for_status()
    return infer(response.text)


# --- Gradio web UI ---
# Build the page: a title, a description textbox, a convert button and a
# two-column gallery that receives the generated images.
with gr.Blocks() as demo:
    # Title reads "Text description to illustration".
    gr.Markdown("<h1><center>文字描述變插圖</center></h1>")
    # Label: "Enter a description".
    prompt_box = gr.Textbox(label='輸入描述')
    # Label: "Convert".
    submit_btn = gr.Button("轉換")
    # Label: "Generated images" (hidden via show_label=False).
    result_gallery = gr.Gallery(label="產生圖片", show_label=False)
    result_gallery = result_gallery.style(grid=[2], height="auto")
    # Run `infer` on the textbox contents when the button is clicked.
    # Swap in `infer2` to accept Chinese descriptions.
    submit_btn.click(fn=infer, inputs=prompt_box, outputs=result_gallery)

demo.launch(share=True)