Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import numpy as np
|
|
5 |
import onnxruntime as rt
|
6 |
import pandas as pd
|
7 |
from PIL import Image
|
|
|
8 |
|
9 |
# 模型配置
|
10 |
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
|
@@ -12,6 +13,11 @@ MODEL_FILENAME = "model.onnx"
|
|
12 |
LABEL_FILENAME = "selected_tags.csv"
|
13 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
14 |
|
|
|
|
|
|
|
|
|
|
|
15 |
# 标签处理配置
|
16 |
kaomojis = [
|
17 |
"0_0",
|
@@ -40,34 +46,45 @@ class Tagger:
|
|
40 |
self.model = None
|
41 |
self.tag_names = []
|
42 |
self.model_size = None
|
|
|
43 |
self._init_model()
|
44 |
|
45 |
def _init_model(self):
|
46 |
"""初始化模型和标签"""
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def _preprocess(self, img):
|
73 |
"""图像预处理"""
|
|
|
5 |
import onnxruntime as rt
|
6 |
import pandas as pd
|
7 |
from PIL import Image
|
8 |
+
from huggingface_hub import login
|
9 |
|
10 |
# 模型配置
|
11 |
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
|
|
|
13 |
LABEL_FILENAME = "selected_tags.csv"
|
14 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
15 |
|
16 |
+
if not os.environ.get("HF_TOKEN"):
|
17 |
+
print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
|
18 |
+
else:
|
19 |
+
login(token=os.environ.get("HF_TOKEN"))
|
20 |
+
|
21 |
# 标签处理配置
|
22 |
kaomojis = [
|
23 |
"0_0",
|
|
|
46 |
self.model = None
|
47 |
self.tag_names = []
|
48 |
self.model_size = None
|
49 |
+
self.hf_token = os.environ.get("HF_TOKEN", "") # 从环境变量获取
|
50 |
self._init_model()
|
51 |
|
52 |
def _init_model(self):
|
53 |
"""初始化模型和标签"""
|
54 |
+
try:
|
55 |
+
label_path = huggingface_hub.hf_hub_download(
|
56 |
+
MODEL_REPO,
|
57 |
+
LABEL_FILENAME,
|
58 |
+
token=self.hf_token
|
59 |
+
)
|
60 |
+
model_path = huggingface_hub.hf_hub_download(
|
61 |
+
MODEL_REPO,
|
62 |
+
MODEL_FILENAME,
|
63 |
+
token=self.hf_token
|
64 |
+
)
|
65 |
+
|
66 |
+
# 加载标签
|
67 |
+
tags_df = pd.read_csv(label_path)
|
68 |
+
self.tag_names = tags_df["name"].tolist()
|
69 |
+
self.categories = {
|
70 |
+
"rating": np.where(tags_df["category"] == 9)[0],
|
71 |
+
"general": np.where(tags_df["category"] == 0)[0],
|
72 |
+
"character": np.where(tags_df["category"] == 4)[0]
|
73 |
+
}
|
74 |
+
|
75 |
+
# 加载ONNX模型
|
76 |
+
self.model = rt.InferenceSession(model_path)
|
77 |
+
self.model_size = self.model.get_inputs()[0].shape[1]
|
78 |
+
except huggingface_hub.utils.HfHubHTTPError as e:
|
79 |
+
if "401" in str(e):
|
80 |
+
raise RuntimeError(
|
81 |
+
"模型下载认证失败,请:\n"
|
82 |
+
"1. 访问https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3\n"
|
83 |
+
"2. 点击Agree and continue\n"
|
84 |
+
"3. 确保HF_TOKEN已正确设置"
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
raise
|
88 |
|
89 |
def _preprocess(self, img):
|
90 |
"""图像预处理"""
|