# Hugging Face Spaces page banner: "Running on Zero"
import json
import random
import uuid

import gradio as gr
# NOTE(review): `spaces` is kept ahead of `torch`, as in the original file;
# ZeroGPU presumably needs it imported before CUDA is touched - confirm.
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Model setup (runs once at import time) --------------------------------

# The diffusion pipeline is pinned to the first CUDA device.
device = torch.device("cuda:0")

# Chat LLM that rewrites the user's request into English prompts.
# It is loaded in fp16 and placed by `device_map="auto"`.
llm = AutoModelForCausalLM.from_pretrained(
    "Azure99/blossom-v5-9b",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Azure99/blossom-v5-9b")

# Text-to-image pipeline, fp16 safetensors weights, watermarker disabled.
diffusion_pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    torch_dtype=torch.float16,
    use_safetensors=True,
    add_watermarker=False,
    variant="fp16",
).to(device)
def get_input_ids(inst, bot_prefix):
    """Tokenize a single-turn chat prompt for the LLM.

    Args:
        inst: The instruction text placed in the ``|Human|`` turn.
        bot_prefix: Text that pre-fills the start of the ``|Bot|`` turn, so
            the model continues from it instead of starting a fresh reply.

    Returns:
        Whatever ``tokenizer.encode`` returns for the assembled prompt
        (called with ``add_special_tokens=True``).
    """
    return tokenizer.encode(
        "A chat between a human and an artificial intelligence bot. "
        "The bot gives helpful, detailed, and polite answers to the human's questions.\n"
        f"|Human|: {inst}\n|Bot|: {bot_prefix}",
        add_special_tokens=True,
    )
def save_image(img):
    """Save ``img`` under a random ``<uuid4>.png`` name and return that name.

    Args:
        img: Any object exposing ``save(path)``; ``generate`` passes the
            images returned by the diffusion pipeline.

    Returns:
        The generated file name. It is a relative path, so the file lands in
        the process's current working directory.
    """
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name
# Few-shot instruction (in Chinese) sent to the LLM. It asks the model to
# extract the picture description from the user's request, translate it to
# English, expand it with extra detail, and answer with one JSON object with
# the string fields `description`, `en_description`, `expanded_description`.
# `$USER_PROMPT` is replaced by the JSON-encoded user request in `generate`.
# NOTE(review): this is a non-raw string, so the `\n` in the last example is
# a real newline, whereas `json.dumps` escapes newlines in the user prompt -
# left as in the original; confirm whether that mismatch is intended.
LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
[输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出一个完整的json,不要输出任何解释或其他无关内容。
下面是一些示例:
[作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
[输出]->{"description": "落霞与孤鹜齐飞,秋水共长天一色", "en_description": "The setting sun and the solitary duck fly together, the autumn water shares a single hue with the vast sky", "expanded_description": "A lone duck gracefully gliding across the tranquil surface of a shimmering lake, bathed in the warm golden glow of the setting sun, creating a breathtaking scene of natural beauty and tranquility."}
[作画要求]->"原神中的可莉"
[输出]->{"description": "原神中的可莉", "en_description": "Klee in Genshin Impact", "expanded_description": "An artistic portrait of Klee from Genshin Impact, standing in a vibrant meadow with colorful explosions of her elemental abilities in the background."}
[作画要求]->"create an image for me. a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
[输出]->{"description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "en_description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "expanded_description": "A close-up portrait of an elegant woman with rich brown skin, wearing a stunning transparent, prismatic, and intricately detailed Nemes headdress, striking a confident and alluring over-the-shoulder pose."}
[作画要求]->"一只高贵的柯基犬,素描画风格\n根据上面的描述生成一张图片吧!"
[输出]->{"description": "一只高贵的柯基犬,素描画风格", "en_description": "A noble corgi dog, sketch style", "expanded_description": "A majestic corgi with a regal bearing, depicted in a detailed and intricate pencil sketch, capturing the essence of its noble lineage and dignified presence."}
[作画要求]->$USER_PROMPT
[输出]->'''

# Pre-filled start of the bot's reply, so the model continues a JSON object.
# `generate` prepends it again to the decoded completion before parsing.
BOT_PREFIX = '{"description": "'
def _is_usable_prompt(value):
    """Return True when ``value`` is a non-blank string."""
    return isinstance(value, str) and bool(value.strip())


def _extract_prompts(llm_result, fallback):
    """Pull the English and expanded prompts out of the LLM's JSON reply.

    Args:
        llm_result: Text expected to start with a JSON object holding the
            ``en_description`` and ``expanded_description`` string fields.
        fallback: Prompt used for any field that is missing, blank, or not a
            string, and for both fields when the reply is not valid JSON.

    Returns:
        A ``(en_prompt, expanded_prompt)`` tuple of strings.
    """
    try:
        # raw_decode stops at the end of the first JSON value, so anything the
        # model keeps generating after the closing brace is ignored instead
        # of invalidating the whole reply.
        parsed, _ = json.JSONDecoder().raw_decode(llm_result)
    except json.JSONDecodeError:
        parsed = None

    fields = parsed if isinstance(parsed, dict) else {}
    candidates = (fields.get("en_description"), fields.get("expanded_description"))
    if not all(_is_usable_prompt(value) for value in candidates):
        print("error, fallback to original prompt")
    return tuple(
        value if _is_usable_prompt(value) else fallback for value in candidates
    )


# NOTE(review): `spaces` is imported by this file and the page banner says the
# Space runs on Zero, so the GPU entry point is decorated here. Confirm the
# default duration is long enough for the LLM pass plus two diffusion images.
@spaces.GPU
def generate(
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Turn a free-form drawing request into two generated images.

    The LLM first rewrites the request into an English description and an
    expanded description; the diffusion pipeline then renders one 1024x1024
    image for each of the two. If the LLM reply cannot be used, the original
    ``prompt`` is sent to the diffusion pipeline instead.

    Args:
        prompt: The user's request, in any language.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared so Gradio can mirror tqdm progress in the UI.

    Returns:
        A list with the file names of the saved images, in the order
        ``[en_description image, expanded_description image]``.
    """
    # JSON-encode the request so it is quoted like the few-shot examples.
    instruction = LLM_PROMPT.replace(
        "$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)
    )
    input_ids = get_input_ids(instruction, BOT_PREFIX)

    output_ids = llm.generate(
        input_ids=torch.tensor([input_ids]).to(llm.device),
        do_sample=True,
        max_new_tokens=512,
        temperature=0.5,
        top_p=0.85,
        top_k=50,
        repetition_penalty=1.05,
    )
    # Drop the echoed prompt tokens; only the newly generated tail is decoded.
    completion_ids = output_ids.cpu()[0][len(input_ids):]
    # The model continued from BOT_PREFIX, so re-attach it to rebuild the JSON.
    llm_result = BOT_PREFIX + tokenizer.decode(completion_ids, skip_special_tokens=True)

    print("----------")
    print(prompt)
    print(llm_result)

    en_prompt, expanded_prompt = _extract_prompts(llm_result, prompt)

    # A fresh random seed per request, within the signed 32-bit range.
    seed = random.randint(0, 2147483647)
    generator = torch.Generator().manual_seed(seed)

    images = diffusion_pipe(
        prompt=[en_prompt, expanded_prompt],
        negative_prompt=None,
        width=1024,
        height=1024,
        guidance_scale=3,
        num_inference_steps=25,
        generator=generator,
        num_images_per_prompt=1,
        use_resolution_binning=True,
        output_type="pil",
    ).images

    return [save_image(img) for img in images]
# Page-level CSS: cap the container width and centre the title.
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
'''

# UI: a single-line prompt box with a Run button, and a two-column gallery
# for the two generated images.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Blossom Playground v2.5")
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", columns=2, rows=1, show_label=False)

    # Pressing Enter in the prompt box and clicking Run both call `generate`;
    # the handler is also registered under the API name "run".
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
        ],
        outputs=[result],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=20, then start the app.
    demo.queue(max_size=20).launch()