neggles committed on
Commit
680a411
·
1 Parent(s): 97e33a3
Files changed (2) hide show
  1. pyproject.toml +0 -1
  2. tagger/common.py +4 -32
pyproject.toml CHANGED
@@ -76,7 +76,6 @@ docstring-code-format = true
76
  [tool.ruff.lint.isort]
77
  combine-as-imports = true
78
  force-wrap-aliases = true
79
- known-local-folder = ["pi-tagger"]
80
  known-first-party = ["pi-tagger"]
81
 
82
 
 
76
  [tool.ruff.lint.isort]
77
  combine-as-imports = true
78
  force-wrap-aliases = true
 
79
  known-first-party = ["pi-tagger"]
80
 
81
 
tagger/common.py CHANGED
@@ -1,9 +1,7 @@
1
- import json
2
- from dataclasses import asdict, dataclass
3
  from functools import lru_cache
4
- from os import PathLike
5
  from pathlib import Path
6
- from typing import Any, Optional
7
 
8
  import numpy as np
9
  import pandas as pd
@@ -12,16 +10,8 @@ from huggingface_hub.utils import HfHubHTTPError
12
  from PIL import Image
13
 
14
 
15
- class DictJsonMixin:
16
- def asdict(self, *args, **kwargs) -> dict[str, Any]:
17
- return asdict(self, *args, **kwargs)
18
-
19
- def asjson(self, *args, **kwargs):
20
- return json.dumps(asdict(self, *args, **kwargs))
21
-
22
-
23
  @dataclass
24
- class LabelData(DictJsonMixin):
25
  names: list[str]
26
  rating: list[np.int64]
27
  general: list[np.int64]
@@ -29,7 +19,7 @@ class LabelData(DictJsonMixin):
29
 
30
 
31
  @dataclass
32
- class ImageLabels(DictJsonMixin):
33
  caption: str
34
  booru: str
35
  rating: dict[str, float]
@@ -37,24 +27,6 @@ class ImageLabels(DictJsonMixin):
37
  character: dict[str, float]
38
 
39
 
40
- @lru_cache(maxsize=5)
41
- def load_labels(version: str = "v3", data_dir: PathLike = "./data") -> LabelData:
42
- data_dir = Path(data_dir).resolve()
43
- csv_path = data_dir.joinpath(f"selected_tags_{version}.csv")
44
- if not csv_path.is_file():
45
- raise FileNotFoundError(f"{csv_path.name} not found in {data_dir}")
46
-
47
- df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
48
- tag_data = LabelData(
49
- names=df["name"].tolist(),
50
- rating=list(np.where(df["category"] == 9)[0]),
51
- general=list(np.where(df["category"] == 0)[0]),
52
- character=list(np.where(df["category"] == 4)[0]),
53
- )
54
-
55
- return tag_data
56
-
57
-
58
  @lru_cache(maxsize=5)
59
  def load_labels_hf(
60
  repo_id: str,
 
1
+ from dataclasses import dataclass
 
2
  from functools import lru_cache
 
3
  from pathlib import Path
4
+ from typing import Optional
5
 
6
  import numpy as np
7
  import pandas as pd
 
10
  from PIL import Image
11
 
12
 
 
 
 
 
 
 
 
 
13
  @dataclass
14
+ class LabelData:
15
  names: list[str]
16
  rating: list[np.int64]
17
  general: list[np.int64]
 
19
 
20
 
21
  @dataclass
22
+ class ImageLabels:
23
  caption: str
24
  booru: str
25
  rating: dict[str, float]
 
27
  character: dict[str, float]
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  @lru_cache(maxsize=5)
31
  def load_labels_hf(
32
  repo_id: str,