import base64
import hashlib
import json
import os
import queue
import time

import gradio as gr
import zstandard as zstd


# Pending jobs, stored as JSON strings and handed out in FIFO order.
q = queue.Queue()

# Job record format (the dicts kept in `models` and serialised into `q`):
#
# {
#     "status": "queue" | "progress" | "done"
#     "prompt": "Some ducks..."
#     "id": "abc123"
# }

# Every job submitted through enqueue(), in submission order.
models = []

# Worker slots; worker() expects dicts with "status", "prompt" and "id" keys.
WORKERS = []


def enqueue(prompt: str) -> None:
    """Register a new generation job and put it on the pending queue.

    The job id is the SHA-256 hex digest of the prompt joined with the
    submission timestamp by a dot.

    Args:
        prompt: Text prompt to create a job for.
    """
    submitted_at = time.time()
    digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    job = {
        "status": "queue",
        "prompt": prompt,
        "id": f"{digest}.{submitted_at}",
    }
    # `models` keeps the live record; the queue carries a JSON snapshot of it.
    models.append(job)
    q.put(json.dumps(job))


def dequeue() -> str:
    """Pop the next pending job, if any, and describe it as a JSON string.

    Returns:
        ``{"status": "ok", "prompt": ..., "id": ...}`` serialised as JSON when
        a job was available, otherwise ``{"status": "empty"}``.
    """
    # EAFP: testing q.empty() first is racy, because another consumer can take
    # the last item between the check and the get.
    try:
        job = json.loads(q.get_nowait())
    except queue.Empty:
        return json.dumps({"status": "empty"})
    return json.dumps({
        "status": "ok",
        "prompt": job["prompt"],
        "id": job["id"],
    })


def complete(data):
    """Store the result files sent for a job and mark that job as done.

    Args:
        data: JSON string carrying the job id under "_id" (or "id") and a
            "files" list. Each entry has a "path" relative to ``files/`` and
            "data" holding the base64-encoded, zstd-compressed contents.

    Returns:
        The JSON string ``{"status": "ok"}``; this is also returned when no
        known job matches the id, in which case nothing is written.

    Raises:
        ValueError: If any file path would land outside ``files/``.
    """
    payload = json.loads(data)
    # dequeue() hands the id out as "id"; accept that spelling as well as "_id".
    job_id = payload["_id"] if "_id" in payload else payload["id"]
    job = next((record for record in models if record["id"] == job_id), None)
    if job is not None:
        files = payload["files"]
        # Validate every path up front so one bad entry rejects the whole
        # upload before anything is written.
        targets = [_result_path(entry["path"]) for entry in files]
        for target, entry in zip(targets, files):
            contents = zstd.decompress(base64.b64decode(entry["data"]))
            # Create only the parent directories; creating and then removing
            # the target itself breaks when the file already exists.
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with open(target, "wb") as out:
                out.write(contents)
        # Flip the status only once every file has been written.
        job["status"] = "done"
    return json.dumps({"status": "ok"})


def _result_path(rel_path):
    """Map an uploaded relative path to a location strictly below ``files/``."""
    # The path comes from the request body, so ".." segments and absolute
    # paths must not be allowed to escape the output directory.
    cleaned = str(rel_path).lstrip("/\\")
    target = os.path.normpath(os.path.join("files", cleaned))
    if not target.startswith("files" + os.sep):
        raise ValueError(f"unsafe result path: {rel_path!r}")
    return target


class ValueClass:
    """Mutable holder for a single value.

    Attributes:
        value: The wrapped object; it can be reassigned after construction.
    """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"{type(self).__name__}({self.value!r})"


def worker():
    """Dispatch queued jobs to idle worker slots, polling once per second.

    Runs forever. Each pass takes at most one job from ``q`` and hands it to
    the first entry of ``WORKERS`` whose status is "idle"; when no slot is
    idle the job is put back on the queue.
    """
    while True:
        # EAFP: q.empty() followed by get_nowait() can race with dequeue().
        try:
            raw = q.get_nowait()
        except queue.Empty:
            raw = None
        if raw is not None:
            job = json.loads(raw)
            slot = next((w for w in WORKERS if w["status"] == "idle"), None)
            if slot is None:
                # Put the untouched payload back so a job that is still
                # waiting is not labelled "progress".
                q.put(raw)
            else:
                slot["status"] = "busy"
                slot["prompt"] = job["prompt"]
                slot["id"] = job["id"]
                # Record the transition on the live job record; changing only
                # the decoded copy would be lost.
                for record in models:
                    if record["id"] == job["id"]:
                        record["status"] = "progress"
                        break
        time.sleep(1)


with gr.Blocks() as bl:
    # Hidden plumbing components: they give the "dequeue" and "complete"
    # handlers wired up below something to read from and write to.
    with gr.Row():
        shad_out = gr.Textbox(visible=False)
        shad_dequeue = gr.Button(value="Dequeue", visible=False)
        shad_in = gr.Textbox(visible=False)
        shad_complete = gr.Button(value="Complete", visible=False)
        shad_submitted = gr.Checkbox(visible=False)
    # Row takes keyword-only arguments, so the variant must be named.
    with gr.Row(variant="panel"):
        gr.Label("Enter a prompt to generate an image. This is a work-in-progress. Release 1.")
    with gr.Row(variant="panel"):
        # Was visible=shad_complete.value, i.e. the truthy string "Complete".
        prompt_input = gr.Textbox(placeholder="Enter a prompt here", visible=True)
    with gr.Row(variant="panel"):
        submit_button = gr.Button(value="Generate", visible=True)
    with gr.Row(variant="panel"):
        image_output = gr.Image(label="Generated Image")

    # Event listeners have to be registered inside the Blocks context.
    submit_button.click(enqueue, inputs=[prompt_input], api_name="enqueue")
    shad_dequeue.click(dequeue, outputs=[shad_out], api_name="dequeue")
    shad_complete.click(complete, inputs=[shad_in], outputs=[shad_out], api_name="complete")


def kill_the_noise_skrillex(btn):
    """Clear the prompt box, flag the prompt controls hidden, echo the input.

    Args:
        btn: Value received from the component this handler is attached to.

    Returns:
        ``btn`` unchanged.
    """
    # NOTE(review): this mutates the module-level component objects directly;
    # confirm that the rendered UI actually picks these attribute changes up.
    prompt_input.visible = False
    prompt_input.value = ""
    submit_button.visible = False
    return btn


# These rebind the names used inside the `bl` layout to fresh components, so
# kill_the_noise_skrillex() acts on these objects when it is called.
prompt_input = gr.Textbox(label="Prompt")
submit_button = gr.Button("Submit")

# NOTE(review): only `demo` is launched; the `bl` app that registers the
# enqueue/dequeue/complete endpoints is never started here - confirm intent.
demo = gr.Interface(
    fn=kill_the_noise_skrillex,
    inputs=[submit_button],
    outputs=[submit_button],
)

demo.launch()