Upload README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,136 @@
|
|
1 |
---
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
tags:
|
3 |
+
- vision
|
4 |
+
widget:
|
5 |
+
- src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat-dog-music.png
|
6 |
+
candidate_labels: playing music, playing sports
|
7 |
+
example_title: Cat & Dog
|
8 |
---
|
9 |
+
|
10 |
+
### 介绍
|
11 |
+
|
12 |
+
使用danbooru2021数据集对clip(ViT-L/14)模型进行微调。
|
13 |
+
|
14 |
+
0-3 epoch学习率为4e-6,权重衰减为1e-3
|
15 |
+
|
16 |
+
4-8 epoch学习率为1e-6,权重衰减为1e-3
|
17 |
+
|
18 |
+
标签预处理过程:
|
19 |
+
|
20 |
+
```python
|
21 |
+
for i in range(length):
|
22 |
+
# 加载并且缩放图片
|
23 |
+
if not is_image(data_from_db.path[i]):
|
24 |
+
continue
|
25 |
+
|
26 |
+
try:
|
27 |
+
img = self.preprocess(
|
28 |
+
Image.open(data_from_db.path[i].replace("./", "/mnt/lvm/danbooru2021/danbooru2021/")))
|
29 |
+
except Exception as e:
|
30 |
+
#print(e)
|
31 |
+
continue
|
32 |
+
# 处理标签
|
33 |
+
tags = json.loads(data_from_db.tags[i])
|
34 |
+
# 优先选择人物和作品标签
|
35 |
+
category_group = {}
|
36 |
+
for tag in tags:
|
37 |
+
category_group.setdefault(tag["category"], []).append(tag)
|
38 |
+
|
39 |
+
# category_group=groupby(tags, key=lambda x: (x["category"]))
|
40 |
+
character_list = category_group[4] if 4 in category_group else []
|
41 |
+
# 作品需要过滤以bad开头的
|
42 |
+
|
43 |
+
work_list = list(filter(
|
44 |
+
lambda e:
|
45 |
+
e["name"] != "original"
|
46 |
+
, category_group[3])) if 3 in category_group else []
|
47 |
+
# work_list= category_group[5] if 5 in category_group else []
|
48 |
+
general_list = category_group[0] if 0 in category_group else []
|
49 |
+
caption = ""
|
50 |
+
caption_2 = None
|
51 |
+
for character in character_list:
|
52 |
+
if len(work_list) != 0:
|
53 |
+
# 去除括号内作品内容
|
54 |
+
character["name"] = re.sub(u"\\(.*?\\)", "", character["name"])
|
55 |
+
caption += character["name"].replace("_", " ")
|
56 |
+
caption += ","
|
57 |
+
caption = caption[:-1]
|
58 |
+
caption += " "
|
59 |
+
if len(work_list) != 0:
|
60 |
+
caption += "from "
|
61 |
+
for work in work_list:
|
62 |
+
caption += work["name"].replace("_", " ")
|
63 |
+
caption += " "
|
64 |
+
# 普通标签
|
65 |
+
if len(general_list) != 0:
|
66 |
+
caption += "with "
|
67 |
+
if len(general_list) > 20:
|
68 |
+
general_list_1 = general_list[:int(len(general_list) / 2)]
|
69 |
+
general_list_2 = general_list[int(len(general_list) / 2):]
|
70 |
+
caption_2 = caption
|
71 |
+
for general in general_list_1:
|
72 |
+
if general["name"].find("girl") == -1 and general["name"].find("boy") == -1 and len(
|
73 |
+
re.findall(is_contain, general["name"])) != 0:
|
74 |
+
caption_2 += general["name"].replace("_", " ")
|
75 |
+
caption_2 += ","
|
76 |
+
caption_2 = caption_2[:-1]
|
77 |
+
for general in general_list_2:
|
78 |
+
if general["name"].find("girl") == -1 and general["name"].find("boy") == -1 and len(
|
79 |
+
re.findall(is_contain, general["name"])) != 0:
|
80 |
+
caption += general["name"].replace("_", " ")
|
81 |
+
caption += ","
|
82 |
+
caption = caption[:-1]
|
83 |
+
else:
|
84 |
+
for general in general_list:
|
85 |
+
# 如果标签数目大于20 则拆分成两个caption
|
86 |
+
if general["name"].find("girl") == -1 and general["name"].find("boy") == -1 and len(
|
87 |
+
re.findall(is_contain, general["name"])) != 0:
|
88 |
+
caption += general["name"].replace("_", " ")
|
89 |
+
caption += ","
|
90 |
+
caption = caption[:-1]
|
91 |
+
|
92 |
+
# 标签汇总成语句
|
93 |
+
# tokenize语句
|
94 |
+
# 返回
|
95 |
+
# 过长截断 不行的话用huggingface的
|
96 |
+
text_1 = clip.tokenize(texts=caption, truncate=True)
|
97 |
+
text_2 = None
|
98 |
+
if caption_2 is not None:
|
99 |
+
text_2 = clip.tokenize(texts=caption_2, truncate=True)
|
100 |
+
# 处理逻辑
|
101 |
+
|
102 |
+
# print(img)
|
103 |
+
yield img, text_1[0]
|
104 |
+
if text_2 is not None:
|
105 |
+
yield img, text_2[0]
|
106 |
+
```
|
107 |
+
|
108 |
+
|
109 |
+
### 使用
|
110 |
+
|
111 |
+
```python
|
112 |
+
from PIL import Image
|
113 |
+
import requests
|
114 |
+
|
115 |
+
from transformers import CLIPProcessor, CLIPModel
|
116 |
+
|
117 |
+
model = CLIPModel.from_pretrained("OysterQAQ/DanbooruCLIP")
|
118 |
+
processor = CLIPProcessor.from_pretrained("OysterQAQ/DanbooruCLIP")
|
119 |
+
|
120 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
121 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
122 |
+
|
123 |
+
inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
|
124 |
+
|
125 |
+
outputs = model(**inputs)
|
126 |
+
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
127 |
+
probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
128 |
+
```
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
## Feedback
|
133 |
+
|
134 |
+
### Where to send questions or comments about the model
|
135 |
+
|
136 |
+
Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9)
|