# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
import open_clip | |
import mediapy as media | |
from optim_utils import * | |
import argparse | |
# Build the run configuration from the bundled JSON file. The resulting
# namespace is handed to open_clip below and to optimize_prompt in the
# request handlers, which overwrite some of its fields per request.
args = argparse.Namespace()
vars(args).update(read_json("sample_config.json"))
args.print_step = None  # override the config value; presumably silences periodic logging in optimize_prompt — verify
args.counter = 0  # number of inference requests served by this process

# Load the CLIP model (and its preprocessing transform) once at import time,
# on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
    args.clip_model,
    pretrained=args.clip_pretrain,
    device=device,
)
def inference(target_image, prompt_len, iter):
    """Learn a hard text prompt for a target image.

    Args:
        target_image: The image from the Gradio image input (configured with
            ``type="pil"``). It is passed as the single element of
            ``target_images`` to ``optimize_prompt``.
        prompt_len: Requested number of prompt tokens, or ``None`` for the
            default of 8. Clamped to ``[1, 75]``, the same upper limit that
            ``inference_text`` enforces and the UI advertises.
        iter: Requested number of optimization steps, or ``None`` for the
            default of 1000. Clamped to ``[1, 3000]`` to bound runtime. (The
            name shadows the builtin but is kept for caller compatibility.)

    Returns:
        Whatever ``optimize_prompt`` returns; the UI shows it in the
        "Learned Prompt" textbox.
    """
    # Simple per-process request counter, printed for usage logging.
    args.counter += 1
    print(args.counter)

    # `args` is module-global, so the per-request settings are overwritten on
    # every call. Previously the image path skipped the upper clamps that the
    # text path applies. Non-positive values are also meaningless for both
    # settings, so they are raised to 1.
    if prompt_len is not None:
        args.prompt_len = max(1, min(int(prompt_len), 75))
    else:
        args.prompt_len = 8
    if iter is not None:
        args.iter = max(1, min(int(iter), 3000))
    else:
        args.iter = 1000

    learned_prompt = optimize_prompt(model, preprocess, args, device, target_images=[target_image])
    return learned_prompt
def inference_text(target_prompt, prompt_len, iter):
    """Distill a (long) text prompt into a shorter learned hard prompt.

    Args:
        target_prompt: The text from the "Target Prompt" textbox. It is passed
            as the single element of ``target_prompts`` to ``optimize_prompt``.
        prompt_len: Requested number of prompt tokens, or ``None`` for the
            default of 8. Clamped to ``[1, 75]``.
        iter: Requested number of optimization steps, or ``None`` for the
            default of 1000. Clamped to ``[1, 3000]`` to bound runtime. (The
            name shadows the builtin but is kept for caller compatibility.)

    Returns:
        Whatever ``optimize_prompt`` returns; the UI shows it in the
        "Learned Prompt" textbox.
    """
    # Simple per-process request counter, printed for usage logging.
    args.counter += 1
    print(args.counter)

    # `args` is module-global, so the per-request settings are overwritten on
    # every call. The lower bound of 1 keeps zero or negative UI input from
    # reaching optimize_prompt.
    if prompt_len is not None:
        args.prompt_len = max(1, min(int(prompt_len), 75))
    else:
        args.prompt_len = 8
    if iter is not None:
        args.iter = max(1, min(int(iter), 3000))
    else:
        args.iter = 1000

    learned_prompt = optimize_prompt(model, preprocess, args, device, target_prompts=[target_prompt])
    return learned_prompt
# NOTE: Gradio only tracks tqdm progress when an event handler declares a
# `progress=gr.Progress(track_tqdm=True)` default argument. A bare
# module-level gr.Progress(...) has no effect, so none is created here.

# Page layout: header text, then a row with input controls on the left and the
# learned prompt on the right, then clickable examples at the bottom.
demo = gr.Blocks()

with demo:
    gr.Markdown("# PEZ Dispenser")
    gr.Markdown("## Hard Prompts Made Easy (PEZ)")
    gr.Markdown("*Want to generate a text prompt for your image that is useful for Stable Diffusion?*")
    gr.Markdown("This space can either generate a text fragment that describes your image, or it can shorten an existing text prompt. This space is using OpenCLIP-ViT/H, the same text encoder used by Stable Diffusion V2. After you generate a prompt, try it out on Stable Diffusion [here](https://huggingface.co/stabilityai/stable-diffusion-2-1-base), [here](https://huggingface.co/spaces/stabilityai/stable-diffusion) or on [Midjourney](https://docs.midjourney.com/). For a quick PEZ demo, try clicking on one of the examples at the bottom of this page.")
    gr.Markdown("For additional details, you can check out the [paper](https://arxiv.org/abs/2302.03668) and the code on [Github](https://github.com/YuxinWenRick/hard-prompts-made-easy).")
    gr.Markdown("Note: Generation with 1000 steps takes ~60 seconds with a T4. Don't want to wait? You can also run on [Google Colab](https://colab.research.google.com/drive/1VSFps4siwASXDwhK_o29dKA9COvTnG8A?usp=sharing). Or, you can reduce the number of steps.")
    # The closing tag was previously the malformed "<p/>".
    gr.HTML("""
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
<br/>
<a href="https://huggingface.co/spaces/tomg-group-umd/pez-dispenser?duplicate=true">
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
</p>""")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Image to Prompt")
            # gr.Image replaces the deprecated gr.inputs.Image wrapper.
            input_image = gr.Image(type="pil", label="Target Image")
            image_button = gr.Button("Generate Prompt")

            gr.Markdown("### Long Prompt to Short Prompt")
            input_prompt = gr.Textbox(label="Target Prompt")
            prompt_button = gr.Button("Distill Prompt")

            # Shared by both handlers; the limits in the labels are enforced
            # by clamping inside the handlers.
            prompt_len_field = gr.Number(label="Prompt Length (max 75, recommend 8-16)", value=8)
            num_step_field = gr.Number(label="Optimization Steps (max 3000 because of limited resources)", value=1000)

        with gr.Column():
            gr.Markdown("### Learned Prompt")
            # gr.Textbox replaces the deprecated gr.outputs.Textbox wrapper.
            output_prompt = gr.Textbox(label="Learned Prompt")

    image_button.click(inference, inputs=[input_image, prompt_len_field, num_step_field], outputs=output_prompt)
    prompt_button.click(inference_text, inputs=[input_prompt, prompt_len_field, num_step_field], outputs=output_prompt)

    # cache_examples=True runs each example through its handler ahead of time,
    # so clicking an example shows a stored result instead of re-optimizing.
    gr.Examples([["sample.jpeg", 8, 1000]], inputs=[input_image, prompt_len_field, num_step_field], fn=inference, outputs=output_prompt, cache_examples=True)
    gr.Examples([["digital concept art of old wooden cabin in florida swamp, trending on artstation", 3, 1000]], inputs=[input_prompt, prompt_len_field, num_step_field], fn=inference_text, outputs=output_prompt, cache_examples=True)

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_pez-dispenser)")

# Enable the request queue via Blocks.queue(), which replaces the deprecated
# launch(enable_queue=True) flag.
demo.queue().launch()