apiTest / app.py
kenfoo's picture
Update app.py
80d18f3 verified
import gradio as gr
import random
import os
import requests
import base64
from PIL import Image
from io import BytesIO
# Access token for the remote Space, read from the "girlToken" environment
# variable; None when the variable is not set.
HF_TOKEN = os.getenv("girlToken")

# Base URL of the upstream Space and the endpoints derived from it.
API_BASE = "https://prithivmlmods-qwen-image-edit-2511-loras-fast.hf.space"
# NOTE(review): INFER_URL is not referenced anywhere in the visible code.
INFER_URL = API_BASE + "/gradio/infer"
NAMED_API_URL = API_BASE + "/gradio_api/call/v2/infer"

# Style names offered in the "style" dropdown and sent as `lora_adapter`.
LORA_STYLES = [
    'Multiple-Angles',
    'Photo-to-Anime',
    'Anime-V2',
    'Light-Migration',
    'Upscaler',
    'Style-Transfer',
    'Manga-Tone',
    'Anything2Real',
    'Fal-Multiple-Angles',
    'Polaroid-Photo',
    'Unblur-Anything',
    'Midnight-Noir-Eyes-Spotlight',
    'Hyper-Realistic-Portrait',
    'Ultra-Realistic-Portrait',
    'Pixar-Inspired-3D',
    'Noir-Comic-Book',
    'Any-light',
    'Studio-DeLight',
    'Cinematic-FlatLog',
]

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = (1 << 31) - 1
def encode_image_file_to_b64_payload(image_path):
    """Read an image file and build the base64 image payload for the API.

    The payload is a Python list (not a JSON string) holding one dict with
    the keys ``data`` (base64 text), ``mime_type`` and ``orig_name``.

    Non-JPEG images are transcoded to JPEG. Images whose mode JPEG cannot
    store (for example RGBA or palette images) are converted to RGB first,
    so the transcode no longer fails for them. Files that are already JPEG
    are sent unchanged to avoid a second lossy encode.

    If the image cannot be decoded or transcoded, the original bytes are
    sent as a best effort and ``mime_type`` is guessed from the file name,
    so the label is not blindly "image/jpeg" for non-JPEG bytes.

    Args:
        image_path: Path of the image file on local disk.

    Returns:
        A one-element list containing the payload dict.

    Raises:
        RuntimeError: If the file cannot be read or encoded; the original
            exception is attached as ``__cause__``.
    """
    import mimetypes  # local import: only needed for the fallback MIME guess

    try:
        with open(image_path, "rb") as f:
            image_bytes = f.read()

        # Label used when the bytes are sent untouched. "image/jpeg" stays
        # the last resort because it was the previously hard-coded value.
        mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

        try:
            img = Image.open(BytesIO(image_bytes))
            if img.format == "JPEG":
                # Already JPEG: keep the original bytes.
                mime_type = "image/jpeg"
            else:
                # JPEG cannot store alpha or palette modes; normalise first.
                if img.mode != "RGB":
                    img = img.convert("RGB")
                buffered = BytesIO()
                img.save(buffered, format="JPEG")
                image_bytes = buffered.getvalue()
                mime_type = "image/jpeg"
        except Exception as e:
            # Deliberate best effort: fall back to the raw bytes, but say so
            # instead of hiding the failure.
            print(f"[图片转码失败,使用原始字节] {e}")

        im_b64 = base64.b64encode(image_bytes).decode("utf-8")
        return [
            {
                "data": im_b64,
                "mime_type": mime_type,
                "orig_name": os.path.basename(image_path),
            }
        ]
    except Exception as e:
        raise RuntimeError(f"图片编码失败: {e}") from e
def call_named_infer(
    images_b64_payload,
    prompt,
    lora_adapter,
    seed,
    randomize_seed,
    guidance_scale,
    steps
):
    """Submit an inference job to the remote Space and return its event id.

    The arguments are sent as one JSON object to ``NAMED_API_URL``; the
    numeric and boolean fields are coerced to plain ``int``/``float``/``bool``
    before sending.

    Args:
        images_b64_payload: Image payload list (sent as-is, not as a JSON
            string) under the ``images_b64_json`` key.
        prompt: Edit instruction text.
        lora_adapter: Name of the style to apply.
        seed: Seed value; converted with ``int``.
        randomize_seed: Whether the server should randomise the seed.
        guidance_scale: Guidance strength; converted with ``float``.
        steps: Number of inference steps; converted with ``int``.

    Returns:
        The ``event_id`` string from the server's JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        RuntimeError: If the response carries no ``event_id``.
    """
    headers = {'Content-Type': 'application/json'}
    # Only authenticate when a token is configured; otherwise the header
    # would literally read "Bearer None".
    if HF_TOKEN:
        headers['Authorization'] = f'Bearer {HF_TOKEN}'

    payload = {
        "images_b64_json": images_b64_payload,  # passed directly as a list
        "prompt": prompt,
        "lora_adapter": lora_adapter,
        "seed": int(seed),
        "randomize_seed": bool(randomize_seed),
        "guidance_scale": float(guidance_scale),
        "steps": int(steps),
    }
    # Log the request without the (large) base64 image body.
    print("准备调用/infer:", {**payload, "images_b64_json": "[payload omitted for brevity]"})

    # An explicit timeout keeps a stalled connection from blocking forever.
    resp = requests.post(NAMED_API_URL, json=payload, headers=headers, timeout=60)
    resp.raise_for_status()

    job = resp.json()
    event_id = job.get("event_id") if isinstance(job, dict) else None
    if not event_id:
        # Without an id the caller would end up polling ".../None".
        raise RuntimeError(f"API response did not contain an event_id: {job!r}")
    return event_id
def _parse_sse_result(text):
    """Return ``(event, data)`` for the first complete/error SSE event in *text*.

    ``data`` is the JSON-decoded ``data:`` line when it parses, otherwise the
    raw string. Returns ``(None, None)`` when no terminal event is present.
    """
    import json

    event = None
    for line in text.splitlines():
        if line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:") and event in ("complete", "error"):
            raw = line[len("data:"):].strip()
            try:
                return event, json.loads(raw)
            except ValueError:
                return event, raw
    return None, None


def poll_infer(event_id, max_attempts=60, interval=2):
    """Wait for a submitted job to finish and return its result.

    Repeatedly GETs the job URL until the response reports completion or an
    error, sleeping ``interval`` seconds between attempts.

    Two response shapes are understood:

    * a JSON object with ``status`` of ``"complete"`` or ``"error"``;
    * a Server-Sent-Events body with an ``event: complete`` / ``event: error``
      line followed by a ``data:`` line.
      NOTE(review): Gradio documents the ``/call/<api>/<event_id>`` endpoint
      as an SSE stream; confirm which shape this Space actually returns.

    NOTE(review): the job is submitted to ``/gradio_api/call/v2/infer``
    (``NAMED_API_URL``) but polled under ``/gradio_api/call/infer/`` —
    confirm both paths address the same API.

    Args:
        event_id: Identifier returned when the job was submitted.
        max_attempts: Maximum number of polling requests (default 60).
        interval: Seconds to sleep between requests (default 2).

    Returns:
        A ``(data, outputs)`` tuple. For JSON responses these are the
        ``data`` and ``outputs`` fields; for SSE responses ``data`` is the
        decoded event payload and ``outputs`` is ``None``.

    Raises:
        RuntimeError: If the server reports an error for the job.
        requests.HTTPError: On 401/403/404, which cannot recover by waiting.
        TimeoutError: If no terminal status is seen within ``max_attempts``.
    """
    import time

    url = f"{API_BASE}/gradio_api/call/infer/{event_id}"
    # Skip the header when no token is configured instead of "Bearer None".
    headers = {'Authorization': f'Bearer {HF_TOKEN}'} if HF_TOKEN else {}

    for attempt in range(1, max_attempts + 1):
        # An explicit timeout keeps a stalled connection from blocking forever.
        resp = requests.get(url, headers=headers, timeout=60)
        if resp.status_code in (401, 403, 404):
            # Auth failures and unknown jobs will not fix themselves.
            resp.raise_for_status()

        try:
            result = resp.json()
        except ValueError:
            event, sse_data = _parse_sse_result(resp.text)
            if event == "complete":
                return sse_data, None
            if event == "error":
                raise RuntimeError(sse_data)
            print(f"[轮询第{attempt}次] 响应无法decode,返回内容:{resp.text[:200]}")
            time.sleep(interval)
            continue

        # Guard against JSON bodies that are not objects (e.g. a bare list).
        if isinstance(result, dict):
            status = result.get("status")
            if status == "complete":
                return result.get("data"), result.get("outputs")
            if status == "error":
                raise RuntimeError(result.get("error"))
        time.sleep(interval)

    raise TimeoutError("等候API返回超时")
def _extract_image_and_seed(result, default_seed):
    """Pull an image reference and a seed out of one API result value.

    Handles a dict (``url``/``path`` and optional ``seed`` keys), a list or
    tuple, or any other value, which is returned unchanged as the image
    reference. Empty values yield ``(None, default_seed)``.
    """
    if not result:
        return None, default_seed
    if isinstance(result, (list, tuple)):
        # Assumes positional order (image, seed), mirroring this app's own
        # outputs — TODO confirm against the upstream Space's API.
        image_ref, seed_used = _extract_image_and_seed(result[0], default_seed)
        if len(result) > 1:
            candidate = result[1]
            if isinstance(candidate, (int, float)) and not isinstance(candidate, bool):
                seed_used = candidate
        return image_ref, seed_used
    if isinstance(result, dict):
        return result.get("url") or result.get("path"), result.get("seed", default_seed)
    return result, default_seed


def infer(
    image,
    prompt,
    lora_adapter,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one image edit through the remote API for the Gradio UI.

    Encodes the uploaded file, submits the job, waits for the result and
    turns it into a value for the result image plus the seed to display.
    Every failure is logged with ``print`` and reported to the UI as
    ``(None, seed)`` rather than raised.

    Args:
        image: Path of the uploaded image file, or ``None``.
        prompt: Edit instruction text.
        lora_adapter: Name of the style to apply.
        seed: Seed from the slider; replaced by a random value in
            ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh random seed.
        guidance_scale: Guidance strength forwarded to the API.
        steps: Number of inference steps forwarded to the API.
        progress: Gradio progress tracker; not used in the body.

    Returns:
        ``(image_reference, seed)``. ``image_reference`` is ``None`` on
        failure; relative string paths are made absolute against
        ``API_BASE``.
    """
    if image is None:
        print("未上传图片")
        return None, seed
    if not os.path.exists(image):
        print(f"图片路径不存在: {image}")
        return None, seed

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        images_b64_payload = encode_image_file_to_b64_payload(image)
    except Exception as e:
        print(f"[图片 base64编码失败] {e}")
        return None, seed

    try:
        event_id = call_named_infer(
            images_b64_payload,
            prompt,
            lora_adapter,
            seed,
            randomize_seed,
            guidance_scale,
            steps
        )
        print("API返回event_id:", event_id)
        data, outputs = poll_infer(event_id)
        print("[API 完成] data:", data, "outputs:", outputs)

        # `outputs` takes precedence; `data` is only used when it is empty.
        img_out, seed_used = _extract_image_and_seed(outputs or data, seed)

        if isinstance(img_out, str) and img_out and not img_out.startswith("http"):
            # Join with exactly one "/" so "foo.png" and "/foo.png" both work.
            img_out = f"{API_BASE}/{img_out.lstrip('/')}"
        return img_out, int(seed_used)
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"[API 调用异常] {e}")
        return None, seed
# Page stylesheet: centres the main column and caps its width at 640px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# Build the UI: one centred column with the inputs, a run button, the result
# image and a collapsed accordion of advanced settings.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "# 图像编辑 Demo\n"
            "基于 prithivMLmods/Qwen-Image-Edit-2511-LoRAs-Fast (新版API)"
        )

        # Uploaded image is handed to `infer` as a file path.
        image = gr.Image(type="filepath", sources=["upload"], label="上传图片")
        prompt = gr.Text(placeholder="请输入图片编辑描述...", label="编辑描述(Prompt)")
        lora_adapter = gr.Dropdown(
            choices=LORA_STYLES,
            value="Photo-to-Anime",
            label="编辑风格(Style)",
        )

        run_button = gr.Button("执行编辑", variant="primary")
        result = gr.Image(show_label=True, label="结果图片")

        with gr.Accordion("高级设置", open=False):
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="随机种子")
            randomize_seed = gr.Checkbox(value=True, label="随机化种子")
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                step=0.1,
                value=1.0,
                label="引导强度 (Guidance Scale)",
            )
            steps = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=4,
                label="推理步数 (Steps)",
            )

    # Clicking the button or submitting the prompt both run `infer`; the seed
    # slider is also an output so the seed that was used is shown afterwards.
    gr.on(
        fn=infer,
        triggers=[run_button.click, prompt.submit],
        inputs=[
            image,
            prompt,
            lora_adapter,
            seed,
            randomize_seed,
            guidance_scale,
            steps,
        ],
        outputs=[result, seed],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)