|
|
|
|
|
from __future__ import annotations |
|
|
|
import os |
|
import random |
|
from typing import Tuple, Optional |
|
|
|
import gradio as gr |
|
from huggingface_hub import HfApi |
|
|
|
from inf import InferencePipeline |
|
|
|
# Curated B-LoRA repos from the `lora-library` organisation. demo_init and
# toggle_column pick random entries from this list, and get_choices lists
# them right after the 'None' sentinel in the model dropdowns.
SAMPLE_MODEL_IDS = [
    f'lora-library/B-LoRA-{suffix}'
    for suffix in (
        'teddybear',
        'bull',
        'wolf_plushie',
        'pen_sketch',
        'cartoon_line',
        'multi-dog2',
    )
]
|
# Custom stylesheet passed to gr.Blocks(css=...). The `.gr-image`,
# `.lora-column` and `.gr-row` classes are attached to components through
# `elem_classes` in create_inference_demo; the `body` rule sets the page font
# size to 30px.
# NOTE(review): in a `flex-direction: column` container, `align-items` acts on
# the horizontal (cross) axis and `justify-content` on the vertical (main)
# axis, so the two inline CSS comments inside `.lora-column` appear swapped.
css = """

body {

font-size: 30px;

}

.gr-image {

width: 512px;

height: 512px;

object-fit: contain;

margin: auto;

}



.lora-column {

display: flex;

flex-direction: column;

align-items: center; /* Center align content vertically in columns */

justify-content: center; /* Center content horizontally in columns */

}

.gr-row {

align-items: center;

justify-content: center;

margin-top: 5px;

}

"""
|
|
|
|
|
def get_choices(hf_token, author='lora-library'):
    """Build the option list for the B-LoRA model dropdowns.

    Args:
        hf_token: Hugging Face token passed to ``HfApi``; may be None.
        author: Hub user/organisation whose models are appended after the
            curated samples. Defaults to ``'lora-library'``.

    Returns:
        A list starting with the ``'None'`` sentinel (meaning "no model"),
        followed by SAMPLE_MODEL_IDS and then every model id listed for
        *author*. Duplicates are removed while keeping first-seen order.
    """
    api = HfApi(token=hf_token)
    # `id` is the documented ModelInfo field; `modelId` is a legacy name.
    remote_ids = [info.id for info in api.list_models(author=author)]
    # The curated samples are themselves `lora-library/...` repos, so the
    # remote listing presumably repeats them; dict.fromkeys de-duplicates
    # while preserving order so each model appears once in the dropdown.
    return list(dict.fromkeys(['None', *SAMPLE_MODEL_IDS, *remote_ids]))
|
|
|
|
|
def get_image_from_card(card, model_id) -> Optional[str]:
    """Return the URL of the first widget output image declared in a model card.

    Reads ``card.data['widget'][0]['output']['url']``. A relative URL is
    resolved against ``https://huggingface.co/<model_id>/resolve/main/``;
    an absolute http(s) URL is returned unchanged.

    Args:
        card: Model card whose ``data`` supports ``.get`` (e.g. the card
            returned by ``InferencePipeline.get_model_card``).
        model_id: Hub repo id, e.g. ``'lora-library/B-LoRA-bull'``.

    Returns:
        The image URL, or None when the card has no usable widget image.
    """
    try:
        widget = card.data.get('widget')
    except (AttributeError, TypeError):
        # Missing card or metadata: treat it as "no image" (best effort).
        return None

    # Widget metadata is optional and user-authored, so validate each level
    # instead of assuming its shape.
    if not isinstance(widget, (list, tuple)) or not widget:
        return None
    first = widget[0]
    output = first.get('output') if isinstance(first, dict) else None
    url = output.get('url') if isinstance(output, dict) else None
    if not isinstance(url, str) or not url:
        return None

    if url.startswith(('http://', 'https://')):
        return url
    return f"https://huggingface.co/{model_id}/resolve/main/{url.lstrip('/')}"
|
|
|
|
|
def demo_init():
    """Populate both B-LoRA panels and the combined prompt on page load.

    Picks a random content and a random style model from SAMPLE_MODEL_IDS and
    fetches their instance prompts and preview images through the
    module-level ``app`` (the InferenceUtil created in the ``__main__``
    block).

    Returns:
        A 7-tuple of ``gr.update`` objects in the order wired to ``demo.load``:
        content model id, content prompt, content image, style model id,
        style prompt, style image, combined prompt.

    Raises:
        The same exception type as the underlying failure, with a
        "failed to demo_init" message and the original error as ``__cause__``.
    """
    try:
        choices = get_choices(app.hf_token)
        content_blora = random.choice(SAMPLE_MODEL_IDS)
        style_blora = random.choice(SAMPLE_MODEL_IDS)
        content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
        style_blora_prompt, style_blora_image = app.load_model_info(style_blora)

        # load_model_info returns '' when a card cannot be fetched, so only
        # build "<content> in <style> style" when both prompts exist; this
        # avoids indexing an empty style prompt.
        if content_blora_prompt and style_blora_prompt:
            style_text = style_blora_prompt[0].lower() + style_blora_prompt[1:]
            combined_prompt = f'{content_blora_prompt} in {style_text} style'
        else:
            combined_prompt = content_blora_prompt or style_blora_prompt

        content_lora_model_id = gr.update(choices=choices, value=content_blora)
        content_prompt = gr.update(value=content_blora_prompt)
        content_image = gr.update(value=content_blora_image)

        style_lora_model_id = gr.update(choices=choices, value=style_blora)
        style_prompt = gr.update(value=style_blora_prompt)
        style_image = gr.update(value=style_blora_image)

        prompt = gr.update(value=combined_prompt)

        return content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt, style_image, prompt

    except Exception as e:
        # Keep the original exception type, and chain it so the traceback
        # still shows where the failure happened.
        raise type(e)(f'failed to demo_init, due to: {e}') from e
|
|
|
|
|
def toggle_column(is_checked):
    """Return the model id for the opposite panel's dropdown.

    Checking a "Use ... Only" box selects the ``'None'`` sentinel for the
    other panel; unchecking it picks a random entry of SAMPLE_MODEL_IDS.

    Raises:
        The same exception type as the underlying failure (e.g. IndexError
        for an empty SAMPLE_MODEL_IDS), with the original as ``__cause__``.
    """
    try:
        if is_checked:
            return 'None'
        return random.choice(SAMPLE_MODEL_IDS)
    except Exception as e:
        raise type(e)(f'failed to toggle_column, due to: {e}') from e
|
|
|
|
|
class InferenceUtil:
    """Fetch B-LoRA model-card info (instance prompt, preview image) for the UI.

    Holds the Hugging Face token used when reading model cards.
    """

    def __init__(self, hf_token: str | None):
        # Token forwarded to InferencePipeline.get_model_card; may be None.
        self.hf_token = hf_token

    def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
        """Return ``(instance_prompt, image_url)`` for *lora_model_id*.

        If the model card cannot be fetched, returns ``('', None)``. A missing
        or null ``instance_prompt`` in the card metadata is reported as ``''``
        so callers always receive a string prompt.

        Raises:
            The same exception type as any failure while reading the fetched
            card, with a "failed to load_model_info" message and the original
            as ``__cause__``.
        """
        try:
            card = InferencePipeline.get_model_card(lora_model_id,
                                                    self.hf_token)
        except Exception:
            # Best effort: an unreachable or missing card just leaves the
            # panel empty instead of breaking the UI callback.
            return '', None
        try:
            instance_prompt = getattr(card.data, 'instance_prompt', None) or ''
            image_url = get_image_from_card(card, lora_model_id)
        except Exception as e:
            raise type(e)(f'failed to load_model_info, due to: {e}') from e
        return instance_prompt, image_url

    def update_model_info(self, model_source: str):
        """Gradio callback: return ``(prompt, image_url)`` for a dropdown value.

        The ``'None'`` sentinel, or an empty/cleared selection, clears both
        outputs without touching the Hub.
        """
        if not model_source or model_source == 'None':
            return '', None
        try:
            new_prompt, new_image = self.load_model_info(model_source)
        except Exception as e:
            raise type(e)(f'failed to update_model_info, due to: {e}') from e
        return new_prompt, new_image
|
|
|
|
|
def create_inference_demo(pipe,
                          hf_token: str | None = None) -> gr.Blocks:
    """Build the Gradio UI that combines a content B-LoRA with a style B-LoRA.

    Args:
        pipe: Object whose ``run`` method is bound to the prompt box's submit
            and the Generate button. It receives (content model id, style
            model id, prompt, content alpha, style alpha, seed, number of
            steps, CFG scale), and its return value is shown in the Result
            image.
        hf_token: Hugging Face token used to read model cards for the
            dropdown callbacks; may be None.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    # Dropdown callbacks get their own helper built from `hf_token`, so the
    # demo does not depend on a global that exists only under __main__.
    model_info = InferenceUtil(hf_token)

    def combine_prompts(content_blora_prompt, style_blora_prompt):
        # Both present: "<content> in <style, first letter lowercased> style".
        # Otherwise fall back to whichever prompt exists. This avoids
        # indexing an empty style prompt, and avoids a dangling " in ... style"
        # when one panel is set to 'None'.
        if content_blora_prompt and style_blora_prompt:
            style_text = style_blora_prompt[0].lower() + style_blora_prompt[1:]
            return f'{content_blora_prompt} in {style_text} style'
        return content_blora_prompt or style_blora_prompt

    with gr.Blocks(css=css) as demo:
        with gr.Row(elem_classes="gr-row"):
            with gr.Column():
                with gr.Group(elem_classes="lora-column"):
                    gr.Markdown('## Content B-LoRA')
                    content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
                    content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                    content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
                    content_image = gr.Image(label='Content Image', elem_classes="gr-image")
            with gr.Column():
                with gr.Group(elem_classes="lora-column"):
                    gr.Markdown('## Style B-LoRA')
                    style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
                    style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                    style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
                    style_image = gr.Image(label='Style Image', elem_classes="gr-image")
        with gr.Row(elem_classes="gr-row"):
            with gr.Column():
                with gr.Group():
                    prompt = gr.Textbox(
                        label='Prompt',
                        max_lines=1,
                        placeholder='Example: "A [c] in [s] style"'
                    )
                    result = gr.Image(label='Result')
                with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
                    content_alpha = gr.Slider(label='Content B-LoRA alpha',
                                              minimum=0,
                                              maximum=2,
                                              step=0.05,
                                              value=1)
                    style_alpha = gr.Slider(label='Style B-LoRA alpha',
                                            minimum=0,
                                            maximum=2,
                                            step=0.05,
                                            value=1)
                    seed = gr.Slider(label='Seed',
                                     minimum=0,
                                     maximum=100000,
                                     step=1,
                                     value=8888)
                    num_steps = gr.Slider(label='Number of Steps',
                                          minimum=0,
                                          maximum=100,
                                          step=1,
                                          value=50)
                    guidance_scale = gr.Slider(label='CFG Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=7.5)

                run_button = gr.Button('Generate')

        # Page load: pick random models and fill both panels plus the prompt.
        demo.load(demo_init, inputs=[],
                  outputs=[content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt,
                           style_image, prompt], queue=False, show_progress="hidden")

        content_lora_model_id.change(
            fn=model_info.update_model_info,
            inputs=content_lora_model_id,
            outputs=[
                content_prompt,
                content_image,
            ])

        style_lora_model_id.change(
            fn=model_info.update_model_info,
            inputs=style_lora_model_id,
            outputs=[
                style_prompt,
                style_image,
            ])

        # Either instance prompt changing rebuilds the combined prompt with
        # the same rule, so the two handlers cannot disagree.
        style_prompt.change(
            fn=combine_prompts,
            inputs=[content_prompt, style_prompt],
            outputs=prompt,
        )

        content_prompt.change(
            fn=combine_prompts,
            inputs=[content_prompt, style_prompt],
            outputs=prompt,
        )

        # "Use X Only" clears the *other* panel (or re-randomises it on uncheck).
        content_checkbox.change(toggle_column, inputs=[content_checkbox],
                                outputs=[style_lora_model_id])

        style_checkbox.change(toggle_column, inputs=[style_checkbox],
                              outputs=[content_lora_model_id])

        inputs = [
            content_lora_model_id,
            style_lora_model_id,
            prompt,
            content_alpha,
            style_alpha,
            seed,
            num_steps,
            guidance_scale,
        ]

        prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
        run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

    return demo
|
|
|
|
|
if __name__ == '__main__':
    # Optional Hub token; os.environ.get yields None when HF_TOKEN is unset.
    hf_token = os.environ.get('HF_TOKEN')

    pipe = InferencePipeline(hf_token)
    # Must remain a module-level global: demo_init() reads `app` directly.
    app = InferenceUtil(hf_token)

    demo = create_inference_demo(pipe, hf_token)
    # Cap the request queue at 10 entries and serve without a public share link.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(share=False)
|
|