# infnapitoggle / app.py
# charliebaby2023 - "Update app.py" (commit d36ccaa, verified)
import gradio as gr
import os
from huggingface_hub import InferenceClient, list_models
from diffusers import StableDiffusionXLPipeline
import torch
from PIL import Image
import traceback
# API token for the Hugging Face inference client, read from the environment
# (None when the variable is not set).
HF_TOKEN = os.environ.get("ohgoddamn")

# Initial state of the "use local diffusers" checkbox; False selects the API.
USE_LOCAL = False

# InferenceClient instances keyed by model id, filled lazily by generate().
client_cache = {}

# Model repositories offered in the dropdown (replace with your own).
all_models = [
    "stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "Uthar/John6666_epicrealism-xl-v8kiss-sdxl",
]
# Local model loading (simplified, for demo)
def load_local_pipeline(model_id):
    """Load a text-to-image diffusers pipeline for ``model_id``.

    The pipeline class is resolved from the checkpoint itself instead of being
    hard-coded to ``StableDiffusionXLPipeline``, so non-SDXL repositories
    (e.g. Stable Diffusion 1.5) can be loaded as well as SDXL ones.

    Args:
        model_id: Repository id of the checkpoint to load.

    Returns:
        The pipeline in float16 on CUDA when a GPU is available, otherwise in
        float32 on the CPU.
    """
    # Imported locally so this function does not rely on the file's import
    # header being changed.
    from diffusers import AutoPipelineForText2Image

    # Query the device once so dtype and placement cannot disagree.
    has_cuda = torch.cuda.is_available()
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
    )
    return pipe.to("cuda" if has_cuda else "cpu")
# Main generation logic
def generate(model_id, prompt, use_local):
    """Generate one image for ``prompt`` with the selected model and backend.

    Args:
        model_id: Repository id of the model to run.
        prompt: Text prompt handed to the pipeline or the inference API.
        use_local: When truthy, run a local diffusers pipeline; otherwise call
            the API through an ``InferenceClient`` cached per model id.

    Returns:
        A ``(image, debug_log)`` tuple. On failure ``image`` is ``None`` and
        ``debug_log`` ends with an error section: the formatted traceback, or
        a short message when no model id was supplied.
    """
    if not model_id:
        # Without this guard an empty selection would be forwarded to the
        # backends as if it were a real repository id.
        return None, "\n❌ Error:\nNo model selected."

    debug_log = ""
    try:
        if use_local:
            debug_log += f"🔧 Using local pipeline for: {model_id}\n"
            # NOTE: the pipeline is re-instantiated on every call; only the
            # API clients below are cached.
            pipe = load_local_pipeline(model_id)
            image = pipe(prompt).images[0]
        else:
            debug_log += f"🌐 Using InferenceClient for: {model_id}\n"
            # client_cache is only mutated, never rebound, so no ``global``
            # declaration is required.
            if model_id not in client_cache:
                client_cache[model_id] = InferenceClient(model=model_id, token=HF_TOKEN)
            image = client_cache[model_id].text_to_image(prompt)
    except Exception:
        # UI boundary: surface the failure in the debug box instead of raising.
        return None, debug_log + f"\n❌ Error:\n{traceback.format_exc()}"
    return image, debug_log + "\n✅ Success."
# Model API self-check
def check_model_access(models):
    """Probe each model through the inference API and report the outcome.

    A fresh ``InferenceClient`` is created per model and asked to render a
    fixed test prompt; any exception marks that model as failed without
    stopping the remaining checks.

    Args:
        models: Iterable of model repository ids. ``None`` is treated as empty.

    Returns:
        A string with one newline-terminated line per model:
        ``✅ <model> is working.`` or ``❌ <model> failed: <reason>``, where
        ``<reason>`` is the first line of the exception message, or the
        exception class name when the message is empty.
    """
    report = []
    for model in models or ():
        try:
            InferenceClient(model=model, token=HF_TOKEN).text_to_image("test prompt")
        except Exception as exc:
            # An exception can have an empty message; indexing [0] on the
            # empty splitlines() result would raise inside this handler and
            # abort the whole check.
            message_lines = str(exc).splitlines()
            reason = message_lines[0] if message_lines else type(exc).__name__
            report.append(f"❌ {model} failed: {reason}\n")
        else:
            report.append(f"✅ {model} is working.\n")
    return "".join(report)
# Gradio UI
# NOTE(review): the source had its indentation stripped; the Row/Accordion
# nesting below is the conventional reading - confirm against the live layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🧪 Stable Diffusion API Tester")

    # Model and backend selection.
    with gr.Row():
        model = gr.Dropdown(
            choices=all_models,
            label="Model",
            value=all_models[0],
        )
        use_local = gr.Checkbox(
            label="Use Local Diffusers Instead of API",
            value=USE_LOCAL,
        )

    # Prompt entry and generation outputs.
    prompt = gr.Textbox(
        label="Prompt",
        value="a cyberpunk cat playing guitar in Tokyo",
    )
    generate_btn = gr.Button("Generate")
    image_out = gr.Image(label="Generated Image")
    debug_out = gr.Textbox(label="Debug Output", lines=10)

    # Collapsed panel that probes every listed model through the API.
    with gr.Accordion("Self-Check: API Model Access", open=False):
        check_btn = gr.Button("Check All Models")
        check_results = gr.Textbox(label="Model API Status", lines=10)

    # Event wiring.
    generate_btn.click(
        fn=generate,
        inputs=[model, prompt, use_local],
        outputs=[image_out, debug_out],
    )
    # The model list reaches the handler through a State component.
    model_list_state = gr.State(all_models)
    check_btn.click(
        fn=check_model_access,
        inputs=[model_list_state],
        outputs=[check_results],
    )

demo.launch()