Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from torchvision.models import resnet18, ResNet18_Weights | |
from PIL import Image | |
# Configuration: class names in the same order as the model's output logits
# (predict pairs them index-by-index), the accent colour shared with the CSS,
# and the HTML introduction rendered in the page header.
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
theme_color = "#6C5B7B"  # primary accent colour (an elegant purple)

# Header description; only the <h2> title is tinted with the theme colour.
description = f"""<div style="padding: 20px; background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); border-radius: 10px;">
<h2 style="color: {theme_color}; margin-bottom: 15px;">🎨 NSFW 图片分类器</h2>
<p>该模型使用深度神经网络对图片内容进行分类,支持以下类别:</p>
<ul style="list-style-type: circle; padding-left: 25px;">
<li><span style="color: #4B4453;">Drawings</span> - 艺术绘画作品</li>
<li><span style="color: #845EC2;">Hentai</span> - 二次元成人内容</li>
<li><span style="color: #008F7A;">Neutral</span> - 日常安全内容</li>
<li><span style="color: #D65DB1;">Porn</span> - 露骨成人内容</li>
<li><span style="color: #FF9671;">Sexy</span> - 性感但不露骨内容</li>
</ul>
<p style="margin-top: 15px;">🖼️ 请上传图片或点击下方示例体验</p>
</div>"""
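# Model definition and preprocessing: the Classifier network, the input
# transforms and the checkpoint loading are defined further below, after the CSS.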
# Advanced CSS styles for the Blocks layout. The stylesheet is written with
# ordinary CSS braces and a literal "{theme_color}" placeholder that is
# substituted at the end of this statement, so rule braces need no doubling.
# NOTE(review): the .example-card and .prob-bar rules are not attached to any
# component via elem_classes in this file; confirm whether they are still used.
advanced_css = """
.gradio-container {
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
    min-height: 100vh;
}
.header-section {
    background: white;
    padding: 2rem;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05);
    margin-bottom: 2rem;
}
.result-card {
    background: white !important;
    padding: 1.5rem !important;
    border-radius: 12px !important;
    box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important;
}
.custom-button {
    background: {theme_color} !important;
    color: white !important;
    border: none !important;
    padding: 12px 28px !important;
    border-radius: 25px !important;
    transition: all 0.3s ease !important;
}
.custom-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important;
}
.upload-box {
    border: 2px dashed {theme_color} !important;
    border-radius: 15px !important;
    background: rgba(255,255,255,0.9) !important;
}
.example-card {
    cursor: pointer;
    transition: all 0.3s ease;
    border-radius: 12px;
    overflow: hidden;
}
.example-card:hover {
    transform: scale(1.02);
    box-shadow: 0 4px 12px rgba(108,91,123,0.2);
}
.prob-bar {
    height: 8px;
    border-radius: 4px;
    background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
}
""".replace("{theme_color}", theme_color)
# Define CNN model
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head (5 logits).

    The backbone's 1000-dimensional output is fed into the head, which reduces
    it to one logit per class. The attribute names ``cnn_layers`` /
    ``fc_layers`` and the order of the head's layers determine the
    ``state_dict`` keys, so they must match the trained checkpoint that this
    file restores with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # ImageNet-pretrained backbone; its parameters are replaced when the
        # fine-tuned checkpoint is loaded with load_state_dict.
        self.cnn_layers = resnet18(weights=ResNet18_Weights.DEFAULT)
        # NOTE(review): there is no activation between the first two Linear
        # layers (only Dropout). Kept as-is: inserting a layer would shift the
        # Sequential indices and invalidate the trained checkpoint.
        head = [
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        ]
        self.fc_layers = nn.Sequential(*head)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images ``x``."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)
# Pre-process: resize so the shorter edge is 224 px (aspect ratio kept),
# convert the PIL image to a float tensor in [0, 1], then normalise each
# channel with the standard ImageNet mean/std used by torchvision backbones.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load model: build the architecture, restore the fine-tuned weights from
# classify_nsfw_v3.0.pth (mapped onto the CPU so no GPU is required), then
# switch to eval mode so Dropout is inactive during inference.
# torch.load unpickles the file, so only point it at a trusted checkpoint.
model = Classifier()
model.load_state_dict(
    torch.load("classify_nsfw_v3.0.pth", map_location="cpu")
)
model.eval()
def predict(image_path):
    """Classify one image and return per-class probabilities.

    Args:
        image_path: Path to the image file, as produced by the ``gr.Image``
            component created with ``type="filepath"``. May be ``None`` (or
            empty) when the user submits without uploading an image.

    Returns:
        A ``{label: probability}`` dict built from the softmax of the model
        logits (values sum to 1), keyed by the names in ``labels``; or
        ``None`` when no image was given, which leaves the result panel empty
        instead of raising.
    """
    # Without this guard, Image.open(None) raises deep inside PIL when the
    # analyse button is clicked with an empty upload box.
    if not image_path:
        return None

    # Image.open is lazy and keeps the file open; decode inside a context
    # manager so the handle is released. convert() returns a new image, so
    # `img` stays valid after the source is closed.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")

    batch = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)[0]
        probs = torch.nn.functional.softmax(logits, dim=0)

    # Pair class names with probabilities directly instead of indexing a
    # hard-coded range(5), so the mapping always follows `labels`.
    return {label: float(p) for label, p in zip(labels, probs.tolist())}
with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
    # Header: page title plus the HTML description of the supported classes.
    with gr.Column(elem_classes="header-section"):
        gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
        gr.HTML(description)

    # Main area: image input on the left, classification result on the right.
    with gr.Row():
        # Input column.
        with gr.Column(scale=2):
            upload_box = gr.Image(
                type="filepath",
                label="📤 上传图片",
                elem_id="upload-box",
                elem_classes="upload-box",
                height=400,
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "✨ 开始分析",
                    elem_classes="custom-button",
                    size="lg",
                )
                clear_btn = gr.Button(
                    "🔄 重新上传",
                    variant="secondary",
                    size="lg",
                )

        # Output column. gr.Label renders the most likely class as its heading
        # above the top-3 probabilities, so no separate "top class" text field
        # is needed (no handler would ever fill one in).
        with gr.Column(scale=1):
            with gr.Column(elem_classes="result-card"):
                gr.Markdown("### 🔍 分析结果")
                result_display = gr.Label(
                    label="分类概率分布",
                    num_top_classes=3,
                    show_label=False,
                )

    # Examples section: clicking an example loads it into the upload box.
    with gr.Column():
        gr.Markdown("### 🖼️ 示例图片")
        examples = gr.Examples(
            examples=["./example/anime.jpg", "./example/real.jpg"],
            inputs=upload_box,
            examples_per_page=2,
            label="点击使用示例",
            elem_id="example-gallery",
        )

    # Event wiring.
    # "Re-upload" resets both the image and the previous result, so a stale
    # classification is not left next to an empty input.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[upload_box, result_display],
    )
    submit_btn.click(
        fn=predict,
        inputs=upload_box,
        outputs=result_display,
        api_name="predict",
    )
# Launch the interface only when this file is executed as a script
# (`python <this file>`), so importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()