|
import gradio as gr |
|
from model_pipelines import load_pipelines, generate_all |
|
from grace_eval import compute_sample_scores, plot_radar |
|
import torch |
|
import time |
|
from functools import partial |
|
|
|
|
|
# Inference-only app: disable autograd globally to save memory/compute.
torch.set_grad_enabled(False)

# Force CPU execution by hiding CUDA from downstream libraries.
# NOTE: the canonical availability check is torch.cuda.is_available();
# the previous patch of torch.backends.cuda.is_available targeted an
# attribute that callers do not consult, so it had no effect.
torch.cuda.is_available = lambda: False
|
|
|
class ModelLoader:
    """Process-wide singleton that lazily loads and caches the model pipelines."""

    _instance = None

    def __new__(cls):
        # Only the very first construction builds the singleton; every
        # subsequent call hands back the same cached object.
        if cls._instance is None:
            created = super().__new__(cls)
            created.models = None  # pipelines not loaded yet
            cls._instance = created
        return cls._instance

    def load(self):
        """Return the cached pipelines, loading them on first use."""
        if self.models is not None:
            return self.models
        print("🔄 Initializing models...")
        start = time.time()
        self.models = load_pipelines()
        print(f"✅ Models loaded in {time.time()-start:.1f}s")
        return self.models
|
|
|
def create_interface():
    """Build the Gradio Blocks UI with Arena, Leaderboard and Report tabs.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """

    def _generate(prompt_text):
        # Defer pipeline loading to the first click so constructing the UI
        # stays fast; ModelLoader caches the result, so every later call
        # reuses the already-loaded pipelines.
        return generate_all(ModelLoader().load(), prompt_text)

    with gr.Blocks(title="🖼️ AI Image Generator Comparison", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🏆 图像生成模型对比实验 (CPU模式)")

        with gr.Tab("🆚 Arena"):
            with gr.Row():
                prompt = gr.Textbox(label="✨ 输入提示词", placeholder="描述您想生成的图像...")
            with gr.Row():
                generate_btn = gr.Button("🚀 生成图像", variant="primary")
            with gr.Row():
                # One output image per compared model, in fixed order.
                outputs = [
                    gr.Image(label="Stable Diffusion v1.5", type="pil"),
                    gr.Image(label="Openjourney v4", type="pil"),
                    gr.Image(label="LDM 256", type="pil")
                ]
            generate_btn.click(
                _generate,
                inputs=prompt,
                outputs=outputs
            )

        with gr.Tab("📊 Leaderboard"):
            with gr.Column():
                eval_prompt = gr.Textbox(label="评估用提示词")
                eval_btn = gr.Button("生成雷达图")
                radar_img = gr.Image(label="GRACE评估结果")
                eval_btn.click(
                    # NOTE(review): plot_radar presumably saves to "radar.png"
                    # and returns None, hence the `or` fallback — confirm
                    # against grace_eval.
                    lambda p: (plot_radar(compute_sample_scores(None, p)) or "radar.png"),
                    inputs=eval_prompt,
                    outputs=radar_img
                )

        with gr.Tab("📝 Report"):
            try:
                with open("report.md", "r", encoding="utf-8") as f:
                    gr.Markdown(f.read())
            except OSError:
                # Narrowed from a bare `except:`: only file-system problems
                # (missing/unreadable report.md) fall back to the placeholder;
                # programming errors now surface instead of being swallowed.
                gr.Markdown("## 实验报告\n报告加载失败")

    return demo
|
|
|
if __name__ == "__main__": |
|
create_interface().launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
show_error=True, |
|
enable_queue=True |
|
) |