"""Gradio front end for Prodia's PhotoMaker endpoint.

The user uploads one or more photos and writes a prompt containing the
trigger word ``img``. The app base64-encodes the photos, submits a
PhotoMaker job to the Prodia API, polls the job until it reaches a terminal
state and displays the resulting image URL.
"""

import base64
import json
import os
import random
import time

import gradio as gr
import numpy as np
import requests
from prodiapy import Prodia  # NOTE: shadowed by the local ``Prodia`` class below.

STYLE_PRESETS = [
    None,
    "3d-model",
    "analog-film",
    "anime",
    "cinematic",
    "comic-book",
    "digital-art",
    "enhance",
    "fantasy-art",
    "isometric",
    "line-art",
    "low-poly",
    "neon-punk",
    "origami",
    "photographic",
    "pixel-art",
    "texture",
    "craft-clay",
]

MAX_SEED = np.iinfo(np.int32).max

# Seconds a single HTTP request may take before ``requests`` gives up.
# Without a timeout a stalled connection would block a queue worker forever.
REQUEST_TIMEOUT = 60


class Prodia:
    """Minimal HTTP client for the Prodia REST API used by this demo."""

    def __init__(self, api_key=None, base=None, timeout=REQUEST_TIMEOUT):
        """Store the endpoint base URL, auth header and request timeout.

        Args:
            api_key: Prodia API key. When ``None`` the ``PRODIA_API_KEY``
                environment variable is read at construction time.
            base: API base URL; defaults to ``https://api.prodia.com/v1``.
            timeout: Timeout in seconds passed to every ``requests`` call.
        """
        # Resolve the key here rather than in the signature: a default of
        # ``os.getenv(...)`` is evaluated only once, when the class is defined.
        if api_key is None:
            api_key = os.getenv("PRODIA_API_KEY")
        self.base = base or "https://api.prodia.com/v1"
        self.headers = {"X-Prodia-Key": api_key}
        self.timeout = timeout

    def photomaker(self, params):
        """POST ``params`` to the ``/photomaker`` endpoint and return the JSON reply."""
        # Log the request without the base64 image payload, which is large
        # and contains the user's photos.
        print({key: value for key, value in params.items() if key != "imageData"})
        response = self._post(f"{self.base}/photomaker", params)
        return response.json()

    def get_job(self, job_id):
        """Fetch the current state of job ``job_id`` and return the JSON reply."""
        response = self._get(f"{self.base}/job/{job_id}")
        return response.json()

    def wait(self, job):
        """Poll ``job`` every 0.25 s until its status is 'succeeded' or 'failed'.

        Args:
            job: Job dict as returned by :meth:`photomaker`; must contain the
                keys ``'status'`` and ``'job'`` (the job id).

        Returns:
            The last job dict fetched, whose status is terminal.
        """
        job_result = job
        while job_result["status"] not in ["succeeded", "failed"]:
            time.sleep(0.25)
            job_result = self.get_job(job["job"])
        return job_result

    @staticmethod
    def _check(response):
        """Return ``response`` if its status is 200, otherwise raise ``Exception``."""
        if response.status_code != 200:
            raise Exception(f"Bad Prodia Response: {response.status_code}")
        return response

    def _post(self, url, params):
        """POST ``params`` as a JSON body to ``url``; raise on a non-200 status."""
        headers = {**self.headers, "Content-Type": "application/json"}
        response = requests.post(
            url,
            headers=headers,
            data=json.dumps(params),
            timeout=self.timeout,
        )
        return self._check(response)

    def _get(self, url):
        """GET ``url`` with the auth header; raise on a non-200 status."""
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        return self._check(response)


client = Prodia()


def generate_image(upload_images, prompt, negative_prompt, style_preset, steps,
                   cfg_scale, strength, seed,
                   progress=gr.Progress(track_tqdm=True)):
    """Run one PhotoMaker job and return the URL of the generated image.

    Args:
        upload_images: Iterable of file paths of the uploaded photos.
        prompt: Text prompt; must contain the trigger word ``img``.
        negative_prompt: Negative prompt forwarded unchanged.
        style_preset: One of ``STYLE_PRESETS``; ``None`` or an unknown value
            means no ``style_preset`` key is sent.
        steps: Forwarded as ``steps``.
        cfg_scale: Forwarded as ``cfg_scale``.
        strength: Forwarded as ``strength``.
            NOTE(review): the UI slider is labelled in percent (15-50) and
            the value is sent as-is; confirm this is the unit the API expects.
        seed: Seed to use; ``0`` means "pick a random seed in [1, MAX_SEED]".
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The ``imageUrl`` field of the finished job.

    Raises:
        gr.Error: If the prompt lacks ``img``, no image was uploaded, or the
            job finished with status ``failed``.
    """
    error_if_no_img(prompt)
    # The file component yields None/empty when nothing is uploaded; iterating
    # it below would otherwise fail with an opaque TypeError.
    if not upload_images:
        raise gr.Error("Please upload at least one image")

    params = {
        "imageData": [file_to_base64(img) for img in upload_images],
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "strength": strength,
        "seed": seed if seed != 0 else random.randint(1, MAX_SEED),
    }

    if style_preset is not None and style_preset in STYLE_PRESETS:
        params["style_preset"] = style_preset

    job = client.photomaker(params)
    res = client.wait(job)

    # Surface the failure to the user instead of silently returning nothing.
    if res["status"] == "failed":
        raise gr.Error("Image generation failed, please try again")

    return res["imageUrl"]


def error_if_no_img(prompt):
    """Raise ``gr.Error`` unless ``prompt`` contains the trigger word ``img``."""
    # ``not prompt`` also covers None, for which the ``in`` test would raise.
    if not prompt or "img" not in prompt:
        raise gr.Error("Prompt must contain 'img'")


def swap_to_gallery(images):
    """Show the gallery (filled with ``images``) and clear button; hide the uploader."""
    return (
        gr.update(value=images, visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
    )


def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Same updates as :func:`swap_to_gallery`, with the example's input signature.

    ``prompt``, ``style`` and ``negative_prompt`` are accepted only because
    ``gr.Examples`` passes every example input to ``fn``; they are not used.
    """
    return (
        gr.update(value=images, visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
    )


def remove_back_to_files():
    """Hide the gallery and clear button; show the uploader again."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=True),
    )


def get_image_path_list(folder_name):
    """Return the sorted paths of all directory entries in ``folder_name``."""
    return sorted(
        os.path.join(folder_name, basename) for basename in os.listdir(folder_name)
    )


def file_to_base64(file_path):
    """Read ``file_path`` as bytes and return its base64 encoding as ``str``."""
    with open(file_path, "rb") as file:
        file_data = file.read()
    return base64.b64encode(file_data).decode("utf-8")


def get_example():
    """Return the example rows: ``[images, prompt, style, negative_prompt]``."""
    case = [
        [
            get_image_path_list('./examples/scarletthead_woman'),
            "instagram photo, portrait photo of a woman img, colorful, perfect face, natural skin, hard shadows, film grain",
            None,
            "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth",
        ],
        [
            get_image_path_list('./examples/newton_man'),
            "sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain",
            None,
            "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth",
        ],
    ]
    return case


title = r"""

PhotoMaker: Generate images with facial consistency to input images

"""

css = ''' .gradio-container {width: 85% !important} '''

# NOTE(review): the container nesting below is reconstructed from the order of
# the statements; confirm the layout against the running app.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    with gr.Row():
        with gr.Column():
            files = gr.File(
                label="Drag (Select) 1 or more photos of your face",
                file_types=["image"],
                file_count="multiple"
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(
                label="Prompt",
                info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
                placeholder="A photo of a [man/woman img]...",
            )
            style = gr.Dropdown(label="Style template", choices=STYLE_PRESETS, value=None)
            submit = gr.Button("Submit")

            with gr.Accordion(open=False, label="Advanced Options"):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality",
                    value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
                )
                with gr.Row():
                    steps = gr.Slider(
                        label="Number of sample steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=40,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG Scale",
                        minimum=5,
                        maximum=20,
                        value=7,
                    )
                with gr.Row():
                    strength_ratio = gr.Slider(
                        label="Strength (%)",
                        minimum=15,
                        maximum=50,
                        step=1,
                        value=20,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
        with gr.Column():
            result_image = gr.Image(label="Generated Image")

    # Uploading swaps the uploader for the gallery; the clear button swaps back.
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])

    submit.click(
        fn=generate_image,
        inputs=[files, prompt, negative_prompt, style, steps, cfg_scale, strength_ratio, seed],
        outputs=[result_image]
    )

    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, style, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )

demo.queue(max_size=20).launch(show_api=False)