# Azure99's picture
# Update app.py
# 772f8c2 verified
import json
import random
import uuid
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repository ids for the two models loaded below.
LLM_REPO_ID = "Azure99/blossom-v5.1-9b"
DIFFUSION_REPO_ID = "playgroundai/playground-v2.5-1024px-aesthetic"

# The diffusion pipeline is pinned to the first CUDA device; the LLM is
# placed by `device_map="auto"`.
device = torch.device("cuda:0")

# Prompt-rewriting LLM and its tokenizer, both loaded from the same repo.
llm = AutoModelForCausalLM.from_pretrained(
    LLM_REPO_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(LLM_REPO_ID)

# Text-to-image pipeline, loaded from fp16 safetensors weights with the
# watermarker disabled, then moved onto `device`.
diffusion_pipe = DiffusionPipeline.from_pretrained(
    DIFFUSION_REPO_ID,
    torch_dtype=torch.float16,
    use_safetensors=True,
    add_watermarker=False,
    variant="fp16",
)
diffusion_pipe = diffusion_pipe.to(device)
def get_input_ids(inst, bot_prefix):
    """Tokenize a single-turn chat prompt for the LLM.

    The prompt is a fixed system preamble, followed by the human turn
    containing ``inst`` and the start of the bot turn pre-filled with
    ``bot_prefix`` so the model continues from that text.

    Args:
        inst: Instruction text placed in the ``|Human|`` turn.
        bot_prefix: Text placed at the start of the ``|Bot|`` turn.

    Returns:
        The result of ``tokenizer.encode`` on the assembled prompt, with
        special tokens added.
    """
    system_preamble = (
        "A chat between a human and an artificial intelligence bot. "
        "The bot gives helpful, detailed, and polite answers to the human's questions.\n"
    )
    dialogue = f"|Human|: {inst}\n|Bot|: {bot_prefix}"
    chat_text = system_preamble + dialogue
    return tokenizer.encode(chat_text, add_special_tokens=True)
def save_image(img):
    """Save ``img`` under a fresh random ``.png`` file name.

    Args:
        img: Any object exposing ``save(path)`` (the caller passes the
            images returned by the diffusion pipeline).

    Returns:
        The generated file name, relative to the current working directory.
    """
    # A UUID4 name keeps outputs of different requests from colliding.
    file_name = f"{uuid.uuid4()}.png"
    img.save(file_name)
    return file_name
# Few-shot instruction (written in Chinese) sent to the LLM. It asks the model
# to extract the picture description from the user's drawing request,
# translate it to English, and expand it with extra detail, answering with a
# single JSON object that has the string fields `description`,
# `en_description` and `expanded_description`.
# `$USER_PROMPT` is a placeholder that `generate` replaces with the
# JSON-encoded user input.
LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
[输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出一个完整的json,不要输出任何解释或其他无关内容。
下面是一些示例:
[作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
[输出]->{"description": "落霞与孤鹜齐飞,秋水共长天一色", "en_description": "The setting sun and the solitary duck fly together, the autumn water shares a single hue with the vast sky", "expanded_description": "A lone duck gracefully gliding across the tranquil surface of a shimmering lake, bathed in the warm golden glow of the setting sun, creating a breathtaking scene of natural beauty and tranquility."}
[作画要求]->"原神中的可莉"
[输出]->{"description": "原神中的可莉", "en_description": "Klee in Genshin Impact", "expanded_description": "An artistic portrait of Klee from Genshin Impact, standing in a vibrant meadow with colorful explosions of her elemental abilities in the background."}
[作画要求]->"create an image for me. a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
[输出]->{"description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "en_description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "expanded_description": "A close-up portrait of an elegant woman with rich brown skin, wearing a stunning transparent, prismatic, and intricately detailed Nemes headdress, striking a confident and alluring over-the-shoulder pose."}
[作画要求]->"一只高贵的柯基犬,素描画风格\n根据上面的描述生成一张图片吧!"
[输出]->{"description": "一只高贵的柯基犬,素描画风格", "en_description": "A noble corgi dog, sketch style", "expanded_description": "A majestic corgi with a regal bearing, depicted in a detailed and intricate pencil sketch, capturing the essence of its noble lineage and dignified presence."}
[作画要求]->$USER_PROMPT
[输出]->'''
# Opening of the expected JSON answer. It is pre-filled as the start of the
# bot turn and prepended again to the decoded continuation in `generate`, so
# the combined text is meant to be a complete JSON object.
BOT_PREFIX = '{"description": "'
def _extract_prompts(llm_result, fallback):
    """Pull the English and expanded prompts out of the LLM's JSON answer.

    Args:
        llm_result: Raw text produced by the LLM, expected to be a JSON
            object with ``en_description`` and ``expanded_description``.
        fallback: Prompt to use for any field that cannot be recovered.

    Returns:
        A ``(en_prompt, expanded_prompt)`` tuple of strings. Each element
        falls back to ``fallback`` independently when the JSON is invalid,
        is not an object, or the field is missing, empty or not a string.
    """
    try:
        parsed = json.loads(llm_result)
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        print("error, fallback to original prompt", repr(err))
        return fallback, fallback

    if not isinstance(parsed, dict):
        print("error, fallback to original prompt")
        return fallback, fallback

    def pick(field):
        value = parsed.get(field)
        if isinstance(value, str) and value.strip():
            return value
        print("error, fallback to original prompt", field)
        return fallback

    return pick("en_description"), pick("expanded_description")


@spaces.GPU(enable_queue=True)
def generate(
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Turn a free-form drawing request into two generated images.

    The request is first rewritten by the LLM into an English description
    and an expanded description; both are then rendered by the diffusion
    pipeline in a single batch. If the LLM answer cannot be parsed, the
    original ``prompt`` is used in place of the missing description(s).

    Args:
        prompt: The user's drawing request.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A list with the file names of the saved images, in the order
        ``[expanded-prompt image, English-prompt image]``.
    """
    # The user text is JSON-encoded so it appears quoted like the few-shot
    # examples inside LLM_PROMPT.
    instruction = LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False))
    input_ids = get_input_ids(instruction, BOT_PREFIX)
    generation_kwargs = dict(
        input_ids=torch.tensor([input_ids]).to(llm.device),
        do_sample=True,
        max_new_tokens=512,
        temperature=0.5,
        top_p=0.85,
        top_k=50,
        repetition_penalty=1.05,
    )
    llm_result = llm.generate(**generation_kwargs)
    # Drop the echoed prompt tokens; keep only the newly generated ones.
    llm_result = llm_result.cpu()[0][len(input_ids):]
    # BOT_PREFIX was part of the prompt, so re-attach it before parsing.
    llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)

    print("----------")
    print(prompt)
    print(llm_result)

    en_prompt, expanded_prompt = _extract_prompts(llm_result, prompt)

    seed = random.randint(0, 2147483647)
    generator = torch.Generator().manual_seed(seed)
    images = diffusion_pipe(
        prompt=[expanded_prompt, en_prompt],
        negative_prompt=None,
        width=1024,
        height=1024,
        guidance_scale=3,
        num_inference_steps=25,
        generator=generator,
        num_images_per_prompt=1,
        use_resolution_binning=True,
        output_type="pil",
    ).images

    return [save_image(img) for img in images]
# Page-level CSS passed to gr.Blocks: caps the container width at 560px and
# centers <h1> headings.
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
'''
# Gradio UI: a one-line prompt box with a Run button, and a two-column
# gallery for the generated images.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm that the Gallery belongs inside the Group
# and that gr.on sits directly under the Blocks context.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Blossom & Playground v2.5")

    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", columns=2, rows=1, show_label=False)

    # Pressing Enter in the text box and clicking Run both call `generate`.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=[prompt],
        outputs=[result],
        api_name="run",
    )
# Script entry point: enable the request queue (at most 20 pending requests)
# and start the Gradio server.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()