"""Client helpers for queueing a workflow on a remote server and fetching its images.

The HTTP endpoints used here (``/prompt``, ``/view``, ``/history``) and the
``/ws`` websocket look like the ComfyUI API -- NOTE(review): confirm.

Configuration is read from the environment at import time:

* ``URL_API``  -- address of the server, interpolated as ``http://{URL_API}/...``.
* ``JSON_API`` -- JSON text of the workflow template edited by ``generate_images``.
"""
import base64
import io
import json
import os
import random
import urllib.parse
import urllib.request
import uuid

import gradio as gr
import websocket  # websocket-client
from PIL import Image

server_address = os.environ.get("URL_API")
json_data = os.environ.get("JSON_API")
# One id per process; sent with every queued prompt and in the websocket URL.
client_id = str(uuid.uuid4())


def queue_prompt(prompt):
    """POST *prompt* to the server's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable workflow to queue.

    Returns:
        The server's reply decoded from JSON (``get_images`` reads its
        ``'prompt_id'`` key).
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    req = urllib.request.Request(f"http://{server_address}/prompt", data=payload)
    # Context manager so the HTTP response is always closed.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())


def get_image(filename, subfolder, folder_type):
    """Download one image through the ``/view`` endpoint and return its raw bytes."""
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    with urllib.request.urlopen(f"http://{server_address}/view?{query}") as response:
        return response.read()


def get_history(prompt_id):
    """Return the decoded JSON body of ``/history/{prompt_id}``."""
    with urllib.request.urlopen(f"http://{server_address}/history/{prompt_id}") as response:
        return json.loads(response.read())


def get_images(ws, prompt, progress):
    """Queue *prompt*, wait for it to finish, and download every output image.

    Args:
        ws: Connected websocket whose ``recv()`` yields the server's messages.
        prompt: Workflow passed on to ``queue_prompt``.
        progress: Currently unused; see the review note below.

    Returns:
        Dict mapping each output node id that produced images to a list of
        raw image bytes.
    """
    # NOTE(review): the caller-supplied ``progress`` is discarded and replaced
    # by a fresh tracker. Kept as-is because the Gradio wiring is not visible
    # from this file -- confirm whether the argument should be honoured.
    progress = gr.Progress(track_tqdm=True)
    prompt_id = queue_prompt(prompt)['prompt_id']
    output_images = {}
    last_reported_percentage = 0

    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # Previews are binary data

        message = json.loads(out)
        if message['type'] == 'progress':
            data = message['data']
            max_progress = data['max']
            if not max_progress:
                continue  # nothing meaningful to report; avoids ZeroDivisionError
            percentage = int((data['value'] / max_progress) * 100)
            # Only report once progress has advanced by at least 10 points.
            if percentage >= last_reported_percentage + 10:
                last_reported_percentage = percentage
                progress(percentage / 100)
        elif message['type'] == 'executing':
            data = message['data']
            if data['node'] is None and data['prompt_id'] == prompt_id:
                break  # Execution is done

    history = get_history(prompt_id)[prompt_id]
    # Single pass over the outputs: each image is downloaded exactly once.
    for node_id, node_output in history['outputs'].items():
        if 'images' in node_output:
            output_images[node_id] = [
                get_image(entry['filename'], entry['subfolder'], entry['type'])
                for entry in node_output['images']
            ]
    return output_images


def pil_to_base64(image):
    """Encode *image* as a PNG ``data:`` URI string.

    *image* must provide ``save(fileobj, format=...)``, as PIL images do.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{base64_string}"


def generate_images(positive_prompt, image, progress):
    """Run the workflow template for *positive_prompt* and return the result image.

    Args:
        positive_prompt: Text written to node "49"'s ``text`` input.
        image: Optional input image. When truthy it is sent base64-encoded to
            node "90"; otherwise nodes "90" and "68" are removed and node
            "62" is rewired to take output 0 of node "61".
        progress: Forwarded to ``get_images``.

    Returns:
        The last image produced by the workflow, opened with PIL. If the
        workflow produced no images, *image* is returned unchanged.
    """
    ws = websocket.WebSocket()
    ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
    try:
        data = json.loads(json_data)
        data["49"]["inputs"]["text"] = positive_prompt

        if image:
            data["90"]["inputs"]["images"]["base64"] = [pil_to_base64(image)]
        else:
            data.pop("90", None)
            data.pop("68", None)
            data["62"]["inputs"]["images"] = ["61", 0]

        # Fresh random seed for every request.
        data["47"]["inputs"]["noise_seed"] = random.randint(1, 1000000000)

        images = get_images(ws, data, progress)
    finally:
        # Close the socket even when queueing or waiting fails.
        ws.close()

    result = image
    for node_images in images.values():
        for image_data in node_images:
            result = Image.open(io.BytesIO(image_data))
    return result