# Hugging Face Space page status at capture time: Runtime error
import gradio as gr | |
from PIL import Image | |
import os | |
import spaces | |
# os.environ['CUDA_VISIBLE_DEVICES'] = '7' | |
from OmniGen import OmniGenPipeline | |
# Model checkpoint and target device for the shared pipeline instance.
_MODEL_ID = "shitao/tmp-preview"
_DEVICE = "cuda"

# Built once when the module is imported; `generate_image` reuses this object
# for every request instead of reloading weights.
pipe = OmniGenPipeline.from_pretrained(_MODEL_ID)
pipe.to(_DEVICE)
# 示例处理函数:生成图像 | |
def generate_image(text, img1, img2, img3, height, width, guidance_scale): | |
input_images = [img1, img2, img3] | |
# 去除 None | |
input_images = [img for img in input_images if img is not None] | |
if len(input_images) == 0: | |
input_images = None | |
output = pipe( | |
prompt=text, | |
input_images=input_images, | |
height=height, | |
width=width, | |
guidance_scale=guidance_scale, | |
img_guidance_scale=1.6, | |
separate_cfg_infer=True, | |
use_kv_cache=False | |
) | |
img = output[0] | |
return img | |
# Gradio 接口 | |
with gr.Blocks() as demo: | |
gr.Markdown("## Text + Multiple Images to Image Generator") | |
with gr.Row(): | |
with gr.Column(): | |
# 文本输入框 | |
prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here...") | |
# 图片上传框 | |
image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath") | |
image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath") | |
image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath") | |
# 高度和宽度滑块 | |
height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1024, step=16) | |
width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=16) | |
# 引导尺度输入 | |
guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1) | |
# 生成按钮 | |
generate_button = gr.Button("Generate Image") | |
with gr.Column(): | |
# 输出图像框 | |
output_image = gr.Image(label="Output Image") | |
# 按钮点击事件 | |
generate_button.click( | |
generate_image, | |
inputs=[prompt_input, image_input_1, image_input_2, image_input_3, height_input, width_input, guidance_scale_input], | |
outputs=output_image | |
) | |
# 启动应用 | |
demo.launch() |