fast-stable-diffusion / inference.py
zenafey's picture
Update inference.py
6e4bf3d
raw
history blame
No virus
2.6 kB
import os
from concurrent.futures import ThreadPoolExecutor
from threading import Thread

import gradio as gr
import gradio_user_history as gr_user_history
from PIL import Image
from prodiapy import Custom
from prodiapy.util import load

from utils import image_to_base64
# Shared Prodia API client used by the request handlers in this module.
# The key comes from the PRODIA_API_KEY environment variable; if it is unset
# the client receives None (NOTE(review): how prodiapy.Custom handles that is
# not visible here — confirm).
pipe = Custom(os.environ.get("PRODIA_API_KEY"))
def txt2img(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed, batch_count, profile: gr.OAuthProfile | None):
    """Generate ``batch_count`` images from a text prompt via Prodia ``/sd/generate``.

    All requests use identical parameters and run concurrently. Each resulting
    image is then saved through ``gr_user_history.save_image`` with the prompt
    as its label.

    Args:
        prompt: Positive text prompt; also used as the history label.
        negative_prompt: Negative text prompt.
        model, steps, sampler, cfg_scale, width, height, seed: Forwarded
            unchanged to the Prodia request.
        batch_count: Number of images to request. Coerced with ``int()``
            because UI sliders may deliver a float (e.g. ``2.0``); values
            below 1 produce no requests.
        profile: OAuth profile of the current user, or ``None``; passed
            through to ``gr_user_history.save_image``.

    Returns:
        A ``gr.update`` setting the output's value to the list of result image
        URLs, in request order, with ``preview=False``.

    Raises:
        Any exception raised while creating or waiting on a Prodia job is
        re-raised here instead of being printed and lost inside a worker thread.
    """
    count = int(batch_count)

    def generate_one_image(_index):
        # The index only drives the batch size; every request is identical.
        result = pipe.create(
            "/sd/generate",
            prompt=prompt,
            negative_prompt=negative_prompt,
            model=model,
            steps=steps,
            cfg_scale=cfg_scale,
            sampler=sampler,
            width=width,
            height=height,
            seed=seed
        )
        job = pipe.wait_for(result)
        return job['imageUrl']

    # One worker per image keeps the original fully-concurrent fan-out.
    # Executor.map yields results in submission order and re-raises worker
    # exceptions; max(..., 1) because a pool of 0 workers is a ValueError.
    with ThreadPoolExecutor(max_workers=max(count, 1)) as executor:
        total_images = list(executor.map(generate_one_image, range(count)))

    for image in total_images:
        # `load` (prodiapy.util) presumably fetches the URL into something
        # PIL can open — verify against prodiapy if this path changes.
        gr_user_history.save_image(label=prompt, image=Image.open(load(image)), profile=profile)
    return gr.update(value=total_images, preview=False)
def img2img(input_image, denoising, prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed,
            batch_count):
    """Transform ``input_image`` ``batch_count`` times via Prodia ``/sd/transform``.

    All requests use identical parameters and run concurrently.

    Args:
        input_image: Source image, encoded with ``image_to_base64``. If
            ``None``, nothing is requested and ``None`` is returned.
        denoising: Sent to Prodia as ``denoising_strength``.
        prompt, negative_prompt, model, steps, sampler, cfg_scale, width,
            height, seed: Forwarded unchanged to the Prodia request.
        batch_count: Number of images to request. Coerced with ``int()``
            because UI sliders may deliver a float; values below 1 produce
            no requests.

    Returns:
        ``None`` when ``input_image`` is ``None``; otherwise a ``gr.update``
        setting the output's value to the list of result image URLs, in
        request order, with ``preview=False``.

    Raises:
        Any exception from encoding the image or from a Prodia request is
        re-raised to the caller rather than lost inside a worker thread.
    """
    if input_image is None:
        return None
    # NOTE(review): unlike txt2img and upscale, results are not saved to
    # gr_user_history here — confirm whether that is intentional.
    count = int(batch_count)
    # Encode once up front: the same payload is sent with every request, so
    # re-encoding per worker (as before) was redundant work.
    image_data = image_to_base64(input_image)

    def generate_one_image(_index):
        # The index only drives the batch size; every request is identical.
        result = pipe.create(
            "/sd/transform",
            imageData=image_data,
            denoising_strength=denoising,
            prompt=prompt,
            negative_prompt=negative_prompt,
            model=model,
            steps=steps,
            cfg_scale=cfg_scale,
            sampler=sampler,
            width=width,
            height=height,
            seed=seed
        )
        job = pipe.wait_for(result)
        return job['imageUrl']

    # One worker per image keeps the original fully-concurrent fan-out.
    # Executor.map yields results in submission order and re-raises worker
    # exceptions; max(..., 1) because a pool of 0 workers is a ValueError.
    with ThreadPoolExecutor(max_workers=max(count, 1)) as executor:
        total_images = list(executor.map(generate_one_image, range(count)))
    return gr.update(value=total_images, preview=False)
def upscale(image, scale, profile: gr.OAuthProfile | None):
    """Upscale ``image`` by ``scale`` through the Prodia ``/upscale`` endpoint.

    The result is saved through ``gr_user_history.save_image`` with the label
    ``"upscale by <scale>"``.

    Args:
        image: Source image, encoded with ``image_to_base64``. If ``None``,
            nothing is requested.
        scale: Sent to Prodia as ``resize``; also used in the history label.
        profile: OAuth profile of the current user, or ``None``; passed
            through to ``gr_user_history.save_image``.

    Returns:
        The ``imageUrl`` of the finished Prodia job, or ``None`` when
        ``image`` is ``None``.
    """
    if image is None:
        return None

    request = pipe.create(
        '/upscale',
        imageData=image_to_base64(image),
        resize=scale
    )
    upscaled_url = pipe.wait_for(request)['imageUrl']

    history_label = f'upscale by {scale}'
    upscaled_picture = Image.open(load(upscaled_url))
    gr_user_history.save_image(label=history_label, image=upscaled_picture, profile=profile)
    return upscaled_url