OysterQAQ committed on
Commit
67a7508
1 Parent(s): 00b432e

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +134 -1
README.md CHANGED
@@ -1,3 +1,136 @@
1
  ---
2
- license: mit
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ tags:
3
+ - vision
4
+ widget:
5
+ - src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat-dog-music.png
6
+ candidate_labels: playing music, playing sports
7
+ example_title: Cat & Dog
8
  ---
9
+
10
+ ### 介绍
11
+
12
+ 使用danbooru2021数据集对CLIP (ViT-L/14)模型进行微调。
13
+
14
+ 0-3 epoch学习率为4e-6,权重衰减为1e-3
15
+
16
+ 4-8 epoch学习率为1e-6,权重衰减为1e-3
17
+
18
+ 标签预处理过程:
19
+
20
+ ```python
21
+ for i in range(length):
22
+ # 加载并且缩放图片
23
+ if not is_image(data_from_db.path[i]):
24
+ continue
25
+
26
+ try:
27
+ img = self.preprocess(
28
+ Image.open(data_from_db.path[i].replace("./", "/mnt/lvm/danbooru2021/danbooru2021/")))
29
+ except Exception as e:
30
+ #print(e)
31
+ continue
32
+ # 处理标签
33
+ tags = json.loads(data_from_db.tags[i])
34
+ # 优先选择人物和作品标签
35
+ category_group = {}
36
+ for tag in tags:
37
+ category_group.setdefault(tag["category"], []).append(tag)
38
+
39
+ # category_group=groupby(tags, key=lambda x: (x["category"]))
40
+ character_list = category_group[4] if 4 in category_group else []
41
+ # 作品需要过滤以bad开头的
42
+
43
+ work_list = list(filter(
44
+ lambda e:
45
+ e["name"] != "original"
46
+ , category_group[3])) if 3 in category_group else []
47
+ # work_list= category_group[5] if 5 in category_group else []
48
+ general_list = category_group[0] if 0 in category_group else []
49
+ caption = ""
50
+ caption_2 = None
51
+ for character in character_list:
52
+ if len(work_list) != 0:
53
+ # 去除括号内作品内容
54
+ character["name"] = re.sub(u"\\(.*?\\)", "", character["name"])
55
+ caption += character["name"].replace("_", " ")
56
+ caption += ","
57
+ caption = caption[:-1]
58
+ caption += " "
59
+ if len(work_list) != 0:
60
+ caption += "from "
61
+ for work in work_list:
62
+ caption += work["name"].replace("_", " ")
63
+ caption += " "
64
+ # 普通标签
65
+ if len(general_list) != 0:
66
+ caption += "with "
67
+ if len(general_list) > 20:
68
+ general_list_1 = general_list[:int(len(general_list) / 2)]
69
+ general_list_2 = general_list[int(len(general_list) / 2):]
70
+ caption_2 = caption
71
+ for general in general_list_1:
72
+ if general["name"].find("girl") == -1 and general["name"].find("boy") == -1 and len(
73
+ re.findall(is_contain, general["name"])) != 0:
74
+ caption_2 += general["name"].replace("_", " ")
75
+ caption_2 += ","
76
+ caption_2 = caption_2[:-1]
77
+ for general in general_list_2:
78
+ if general["name"].find("girl") == -1 and general["name"].find("boy") == -1 and len(
79
+ re.findall(is_contain, general["name"])) != 0:
80
+ caption += general["name"].replace("_", " ")
81
+ caption += ","
82
+ caption = caption[:-1]
83
+ else:
84
+ for general in general_list:
85
+ # 如果标签数据目大于20 则拆分成两个caption
86
+ if general["name"].find("girl") == -1 and general["name"].find("boy") == -1 and len(
87
+ re.findall(is_contain, general["name"])) != 0:
88
+ caption += general["name"].replace("_", " ")
89
+ caption += ","
90
+ caption = caption[:-1]
91
+
92
+ # 标签汇总成语句
93
+ # tokenize语句
94
+ # 返回
95
+ # 过长截断 不行的话用huggingface的
96
+ text_1 = clip.tokenize(texts=caption, truncate=True)
97
+ text_2= None
98
+ if caption_2 is not None:
99
+ text_2 = clip.tokenize(texts=caption_2, truncate=True)
100
+ # 处理逻辑
101
+
102
+ # print(img)
103
+ yield img, text_1[0]
104
+ if text_2 is not None:
105
+ yield img, text_2[0]
106
+ ```
107
+
108
+
109
+ ### 使用
110
+
111
+ ```python
112
+ from PIL import Image
113
+ import requests
114
+
115
+ from transformers import CLIPProcessor, CLIPModel
116
+
117
+ model = CLIPModel.from_pretrained("OysterQAQ/DanbooruCLIP")
118
+ processor = CLIPProcessor.from_pretrained("OysterQAQ/DanbooruCLIP")
119
+
120
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
121
+ image = Image.open(requests.get(url, stream=True).raw)
122
+
123
+ inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
124
+
125
+ outputs = model(**inputs)
126
+ logits_per_image = outputs.logits_per_image # this is the image-text similarity score
127
+ probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
128
+ ```
129
+
130
+
131
+
132
+ ## Feedback
133
+
134
+ ### Where to send questions or comments about the model
135
+
136
+ Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9) to send questions or comments about the model.