IdlecloudX commited on
Commit
4412065
·
verified ·
1 Parent(s): 4b8ada2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -24
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import onnxruntime as rt
6
  import pandas as pd
7
  from PIL import Image
 
8
 
9
  # 模型配置
10
  MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
@@ -12,6 +13,11 @@ MODEL_FILENAME = "model.onnx"
12
  LABEL_FILENAME = "selected_tags.csv"
13
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
14
 
 
 
 
 
 
15
  # 标签处理配置
16
  kaomojis = [
17
  "0_0",
@@ -40,34 +46,45 @@ class Tagger:
40
  self.model = None
41
  self.tag_names = []
42
  self.model_size = None
 
43
  self._init_model()
44
 
45
  def _init_model(self):
46
  """初始化模型和标签"""
47
- # 下载模型文件
48
- label_path = huggingface_hub.hf_hub_download(
49
- MODEL_REPO,
50
- LABEL_FILENAME,
51
- token=HF_TOKEN
52
- )
53
- model_path = huggingface_hub.hf_hub_download(
54
- MODEL_REPO,
55
- MODEL_FILENAME,
56
- token=HF_TOKEN
57
- )
58
-
59
- # 加载标签
60
- tags_df = pd.read_csv(label_path)
61
- self.tag_names = tags_df["name"].tolist()
62
- self.categories = {
63
- "rating": np.where(tags_df["category"] == 9)[0],
64
- "general": np.where(tags_df["category"] == 0)[0],
65
- "character": np.where(tags_df["category"] == 4)[0]
66
- }
67
-
68
- # 加载ONNX模型
69
- self.model = rt.InferenceSession(model_path)
70
- self.model_size = self.model.get_inputs()[0].shape[1]
 
 
 
 
 
 
 
 
 
 
71
 
72
  def _preprocess(self, img):
73
  """图像预处理"""
 
5
  import onnxruntime as rt
6
  import pandas as pd
7
  from PIL import Image
8
+ from huggingface_hub import login
9
 
10
  # 模型配置
11
  MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3" # 默认模型
 
13
  LABEL_FILENAME = "selected_tags.csv"
14
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
15
 
16
+ if not os.environ.get("HF_TOKEN"):
17
+ print("⚠️ 警告:未检测到HF_TOKEN,部分模型可能需要认证")
18
+ else:
19
+ login(token=os.environ.get("HF_TOKEN"))
20
+
21
  # 标签处理配置
22
  kaomojis = [
23
  "0_0",
 
46
  self.model = None
47
  self.tag_names = []
48
  self.model_size = None
49
+ self.hf_token = os.environ.get("HF_TOKEN", "") # 从环境变量获取
50
  self._init_model()
51
 
52
  def _init_model(self):
53
  """初始化模型和标签"""
54
+ try:
55
+ label_path = huggingface_hub.hf_hub_download(
56
+ MODEL_REPO,
57
+ LABEL_FILENAME,
58
+ token=self.hf_token
59
+ )
60
+ model_path = huggingface_hub.hf_hub_download(
61
+ MODEL_REPO,
62
+ MODEL_FILENAME,
63
+ token=self.hf_token
64
+ )
65
+
66
+ # 加载标签
67
+ tags_df = pd.read_csv(label_path)
68
+ self.tag_names = tags_df["name"].tolist()
69
+ self.categories = {
70
+ "rating": np.where(tags_df["category"] == 9)[0],
71
+ "general": np.where(tags_df["category"] == 0)[0],
72
+ "character": np.where(tags_df["category"] == 4)[0]
73
+ }
74
+
75
+ # 加载ONNX模型
76
+ self.model = rt.InferenceSession(model_path)
77
+ self.model_size = self.model.get_inputs()[0].shape[1]
78
+ except huggingface_hub.utils.HfHubHTTPError as e:
79
+ if "401" in str(e):
80
+ raise RuntimeError(
81
+ "模型下载认证失败,请:\n"
82
+ "1. 访问https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3\n"
83
+ "2. 点击Agree and continue\n"
84
+ "3. 确保HF_TOKEN已正确设置"
85
+ )
86
+ else:
87
+ raise
88
 
89
  def _preprocess(self, img):
90
  """图像预处理"""