Spaces:
Running
Running
import os | |
import numpy as np | |
from prodiapy import Prodia | |
import gradio as gr | |
import json | |
import requests | |
import base64 | |
import random | |
import time | |
# Values offered in the "Style template" dropdown. The leading None entry
# means "no preset": generate_image omits style_preset from the request then.
STYLE_PRESETS = [
    None,
    "3d-model",
    "analog-film",
    "anime",
    "cinematic",
    "comic-book",
    "digital-art",
    "enhance",
    "fantasy-art",
    "isometric",
    "line-art",
    "low-poly",
    "neon-punk",
    "origami",
    "photographic",
    "pixel-art",
    "texture",
    "craft-clay",
]

# Largest value of a signed 32-bit integer; upper bound for the seed slider
# and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
class Prodia:
    """Minimal HTTP client for the Prodia REST API (PhotoMaker + job polling).

    Note: this local class shadows the ``Prodia`` name imported from
    ``prodiapy`` at the top of the file.
    """

    # Job states after which polling stops.
    _TERMINAL_STATUSES = ("succeeded", "failed")

    def __init__(self, api_key=None, base=None, timeout=60):
        """Store the API root, the auth header and the per-request timeout.

        Args:
            api_key: Prodia API key. When ``None``, the ``PRODIA_API_KEY``
                environment variable is read when the client is constructed
                (not once at class-definition time).
            base: API root URL; defaults to ``https://api.prodia.com/v1``.
            timeout: Seconds handed to ``requests`` as the ``timeout`` of
                every call, so a stalled connection cannot block forever.
        """
        if api_key is None:
            api_key = os.getenv("PRODIA_API_KEY")
        self.base = base or "https://api.prodia.com/v1"
        self.timeout = timeout
        self.headers = {
            "X-Prodia-Key": api_key
        }

    def photomaker(self, params):
        """POST ``params`` to ``{base}/photomaker`` and return the decoded JSON."""
        # Log the request, but not the base64 image payload itself: it is
        # large and contains the user's photos.
        summary = {
            key: (f"<{len(value)} image(s) omitted>" if key == "imageData" else value)
            for key, value in params.items()
        }
        print(summary)
        response = self._post(f"{self.base}/photomaker", params)
        return response.json()

    def get_job(self, job_id):
        """GET ``{base}/job/{job_id}`` and return the decoded JSON."""
        response = self._get(f"{self.base}/job/{job_id}")
        return response.json()

    def wait(self, job, timeout=None, interval=0.25):
        """Poll a job until its status is ``succeeded`` or ``failed``.

        Args:
            job: Job dict with at least the ``status`` and ``job`` keys.
            timeout: Optional overall limit in seconds. ``None`` (default)
                keeps polling until a terminal status is seen.
            interval: Seconds to sleep between polls.

        Returns:
            The last job dict fetched (``job`` itself if already terminal).

        Raises:
            TimeoutError: ``timeout`` elapsed before a terminal status.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        job_result = job
        while job_result['status'] not in self._TERMINAL_STATUSES:
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Prodia job {job['job']} did not finish within {timeout} seconds"
                )
            time.sleep(interval)
            job_result = self.get_job(job['job'])
        return job_result

    def _post(self, url, params):
        """POST ``params`` as a JSON body; raise unless the status is 200."""
        headers = {
            **self.headers,
            "Content-Type": "application/json"
        }
        response = requests.post(
            url, headers=headers, data=json.dumps(params), timeout=self.timeout
        )
        return self._checked(response)

    def _get(self, url):
        """GET ``url`` with the auth header; raise unless the status is 200."""
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        return self._checked(response)

    @staticmethod
    def _checked(response):
        """Return ``response`` if its status code is 200, else raise."""
        if response.status_code != 200:
            raise Exception(f"Bad Prodia Response: {response.status_code}")
        return response
# Shared module-level API client. No key is passed here, so Prodia.__init__
# falls back to the PRODIA_API_KEY environment variable.
client = Prodia()
def generate_image(upload_images, prompt, negative_prompt, style_preset, steps, cfg_scale, strength, seed, progress=gr.Progress(track_tqdm=True)):
    """Submit a PhotoMaker job built from the UI inputs and return the image URL.

    Args:
        upload_images: Paths of the uploaded reference photos; each one is read
            and sent base64-encoded under ``imageData``.
        prompt: Text prompt; must contain the trigger word ``img``.
        negative_prompt: Forwarded unchanged as ``negative_prompt``.
        style_preset: Added to the request only when it is a non-None entry of
            STYLE_PRESETS.
        steps, cfg_scale, strength: Forwarded unchanged under the same keys.
        seed: ``0`` means "pick a random seed in [1, MAX_SEED]".
        progress: Gradio progress tracker (not used directly in the body).

    Returns:
        The ``imageUrl`` field of the finished job.

    Raises:
        gr.Error: the prompt lacks ``img``, no image was uploaded, or the job
            finished with status ``failed``.
    """
    error_if_no_img(prompt)
    # With nothing uploaded the file input yields None/empty; surface a
    # readable message instead of a TypeError from the comprehension below.
    if not upload_images:
        raise gr.Error("Please upload at least one photo")
    print(upload_images)
    params = {
        "imageData": [file_to_base64(img) for img in upload_images],
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "strength": strength,
        "seed": seed if seed != 0 else random.randint(1, MAX_SEED)
    }
    if style_preset is not None and style_preset in STYLE_PRESETS:
        params['style_preset'] = style_preset
    job = client.photomaker(params)
    res = client.wait(job)
    # Report a failed job to the user rather than silently returning nothing.
    if res['status'] == "failed":
        raise gr.Error("Image generation failed, please try again")
    return res['imageUrl']
def error_if_no_img(prompt):
    """Raise a Gradio error unless the prompt contains the substring 'img'."""
    if "img" in prompt:
        return
    raise gr.Error("Prompt must contain 'img'")
def swap_to_gallery(images):
    """Return three updates: show `images` in a gallery, show a second component, hide a third.

    Wired in this file to the outputs [uploaded_files, clear_button, files].
    """
    show_gallery = gr.update(value=images, visible=True)
    show_second = gr.update(visible=True)
    hide_third = gr.update(visible=False)
    return show_gallery, show_second, hide_third
def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Return the same three updates as an upload, for a clicked example row.

    Only `images` is used; `prompt`, `style` and `negative_prompt` are accepted
    because the examples component passes all of its inputs to this callback.
    """
    show_gallery = gr.update(value=images, visible=True)
    show_second = gr.update(visible=True)
    hide_third = gr.update(visible=False)
    return show_gallery, show_second, hide_third
def remove_back_to_files():
    """Return three updates: hide the first two outputs and show the third.

    Wired in this file to the outputs [uploaded_files, clear_button, files].
    """
    hide_first = gr.update(visible=False)
    hide_second = gr.update(visible=False)
    show_third = gr.update(visible=True)
    return hide_first, hide_second, show_third
def get_image_path_list(folder_name):
    """Return the sorted paths of every entry directly inside `folder_name`."""
    paths = [os.path.join(folder_name, entry) for entry in os.listdir(folder_name)]
    paths.sort()
    return paths
def file_to_base64(file_path):
    """Read the file at `file_path` as bytes and return them base64-encoded as text."""
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
def get_example():
    """Return the example rows: [image paths, prompt, style, negative prompt] each."""
    # Both examples use the same negative prompt.
    shared_negative = "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth"
    woman_example = [
        get_image_path_list('./examples/scarletthead_woman'),
        "instagram photo, portrait photo of a woman img , colorful, perfect face, natural skin, hard shadows, film grain",
        None,
        shared_negative,
    ]
    man_example = [
        get_image_path_list('./examples/newton_man'),
        "sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain",
        None,
        shared_negative,
    ]
    return [woman_example, man_example]
# Heading passed to gr.Markdown at the top of the page.
title = (
    "\n"
    '<h1 align="center">PhotoMaker: Generate images with facial consistency to input images</h1>'
    "\n"
)

# Stylesheet passed to gr.Blocks.
css = "\n.gradio-container {width: 85% !important}\n"
# Build the demo page: inputs on the left, generated image on the right.
# NOTE(review): the original indentation was lost in extraction; the nesting
# below is reconstructed from the statement order — confirm against the
# deployed layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    with gr.Row():
        with gr.Column():
            # File picker for the reference photos; hidden by swap_to_gallery
            # once files are uploaded.
            files = gr.File(
                label="Drag (Select) 1 or more photos of your face",
                file_types=["image"],
                file_count="multiple"
            )
            # Preview of the uploaded photos; starts hidden.
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
            # Wrapper column so the clear button can be shown/hidden as a unit.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
                                placeholder="A photo of a [man/woman img]...")
            style = gr.Dropdown(label="Style template", choices=STYLE_PRESETS, value=None)
            submit = gr.Button("Submit")
            with gr.Accordion(open=False, label="Advanced Options"):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality",
                    value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
                )
                with gr.Row():
                    steps = gr.Slider(
                        label="Number of sample steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=40,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG Scale",
                        minimum=5,
                        maximum=20,
                        value=7,
                    )
                with gr.Row():
                    # NOTE(review): this value (15-50, labelled as a percentage)
                    # is forwarded unchanged as `strength` — confirm the API
                    # expects a percentage rather than a 0-1 fraction.
                    strength_ratio = gr.Slider(
                        label="Strength (%)",
                        minimum=15,
                        maximum=50,
                        step=1,
                        value=20,
                    )
                    # A seed of 0 makes generate_image draw a random seed.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
        with gr.Column():
            result_image = gr.Image(label="Generated Image")
        # Uploading swaps the file picker for the gallery + clear button;
        # clicking the clear button swaps them back.
        files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
        remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
        # Input order must match generate_image's positional parameters.
        submit.click(
            fn=generate_image,
            inputs=[files, prompt, negative_prompt, style, steps, cfg_scale, strength_ratio, seed],
            outputs=[result_image]
        )
    # Clicking an example fills the inputs and runs upload_example_to_gallery.
    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, style, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )
# Queue holds at most 20 pending requests; the API page is disabled.
demo.queue(max_size=20).launch(show_api=False)