pyx9913
commited on
Commit
โข
aa60bbf
1
Parent(s):
4f1e38f
feat: ๐ธ add chat model code
Browse files- README.md +124 -69
- README_en.md +162 -0
- beit3.py +108 -0
- config.json +27 -0
- configuration_viscpmchatbee.py +133 -0
- feature_extraction_viscpmchatbee.py +17 -0
- generation_config.json +12 -0
- modeling_cpmbee.py +0 -0
- preprocessor_config.json +10 -0
- processing_viscpmchatbee.py +428 -0
- tokenization_viscpmchatbee.py +1007 -0
- tokenizer_config.json +10 -0
- utils.py +730 -0
- vocab.txt +0 -0
README.md
CHANGED
@@ -1,58 +1,45 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
- en
|
4 |
-
- zh
|
5 |
-
---
|
6 |
-
<div align="center">
|
7 |
-
|
8 |
-
**VisCPM**
|
9 |
-
|
10 |
-
**Chinese-English bilingual multi-modal large model series based on CPM (Chinese Pretrained Models) basic model**
|
11 |
|
12 |
<p align="center">
|
13 |
-
|
14 |
-
|
|
|
15 |
</p>
|
16 |
|
17 |
-
|
18 |
|
19 |
-
`VisCPM
|
20 |
-
|
21 |
-
- **๐ Open-source Usage**: VisCPM is free to be used for personal and research purposes. By open-sourcing the VisCPM model family, we hope to promote the development of the open-source community of large multimodal models and related research.
|
22 |
-
- **๐ Image and text generation coverage**: VisCPM models provide relatively comprehensive support for image and text multimodal capabilities, covering both multimodal conversation (image-to-text generation) capabilities and text-to-image generation capabilities.
|
23 |
-
- **๐ซ Excellent bilingual performance**: Thanks to the excellent bilingual capability of the base language model CPM-Bee, VisCPM achieves outstanding results in both bilingual multimodal conversation and text-to-image generation.
|
24 |
|
25 |
## VisCPM-Chat
|
26 |
-
`VisCPM-Chat
|
27 |
|
28 |
-
*
|
29 |
|
30 |
-
*
|
31 |
|
32 |
-
|
33 |
|
34 |
<table>
|
35 |
<tr>
|
36 |
-
<td align="center" rowspan="2" colspan="2"
|
37 |
-
<td align="center"
|
38 |
-
<td align="center" colspan="4"
|
39 |
-
<td align="center" colspan="4">Chinese</td>
|
40 |
</tr>
|
41 |
<tr>
|
42 |
-
<td align="center"
|
43 |
-
<td align="center"
|
44 |
-
<td align="center"
|
45 |
-
<td align="center"
|
46 |
-
<td align="center"
|
47 |
-
<td align="center"
|
48 |
-
<td align="center"
|
49 |
-
<td align="center"
|
50 |
</tr>
|
51 |
<tr>
|
52 |
-
<td align="center" rowspan="3"
|
53 |
<td align="center">MiniGPT4</td>
|
54 |
-
<td align="center">
|
55 |
-
<td align="center">65.0</td>
|
56 |
<td align="center">67.3</td>
|
57 |
<td align="center">76.6</td>
|
58 |
<td align="center">69.7</td>
|
@@ -63,9 +50,8 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
|
|
63 |
</tr>
|
64 |
<tr>
|
65 |
<td align="center">InstructBLIP</td>
|
66 |
-
<td align="center">Vicuna-13B</td>
|
67 |
<td align="center">81.9</td>
|
68 |
-
<td align="center">68
|
69 |
<td align="center">91.2</td>
|
70 |
<td align="center">80.5</td>
|
71 |
<td align="center">-</td>
|
@@ -75,20 +61,18 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
|
|
75 |
</tr>
|
76 |
<tr>
|
77 |
<td align="center">LLaVA</td>
|
78 |
-
<td align="center">
|
79 |
-
<td align="center"
|
80 |
-
<td align="center"
|
81 |
-
<td align="center"
|
82 |
-
<td align="center"><b>85.6</b></td>
|
83 |
<td align="center">-</td>
|
84 |
<td align="center">-</td>
|
85 |
<td align="center">-</td>
|
86 |
<td align="center">-</td>
|
87 |
</tr>
|
88 |
<tr>
|
89 |
-
<td align="center" rowspan="
|
90 |
<td align="center">mPLUG-Owl </td>
|
91 |
-
<td align="center">LLaMA-7B</td>
|
92 |
<td align="center">64.6</td>
|
93 |
<td align="center">47.7</td>
|
94 |
<td align="center">80.1</td>
|
@@ -96,61 +80,132 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
|
|
96 |
<td align="center">76.3</td>
|
97 |
<td align="center">61.2</td>
|
98 |
<td align="center">77.8</td>
|
99 |
-
<td align="center">72
|
100 |
</tr>
|
101 |
<tr>
|
102 |
<td align="center">VisualGLM</td>
|
103 |
-
<td align="center">ChatGLM-6B</td>
|
104 |
<td align="center">62.4</td>
|
105 |
-
<td align="center">63
|
106 |
<td align="center">80.6</td>
|
107 |
<td align="center">68.7</td>
|
108 |
<td align="center">76.6</td>
|
109 |
-
<td align="center"
|
110 |
<td align="center">83.6</td>
|
111 |
<td align="center">82.7</td>
|
112 |
</tr>
|
113 |
<tr>
|
114 |
-
<td align="center">Ziya
|
115 |
-
<td align="center">Ziya-LLaMA-13B-v1</td>
|
116 |
<td align="center">82.7</td>
|
117 |
<td align="center">69.9</td>
|
118 |
<td align="center">92.1</td>
|
119 |
<td align="center">81.7</td>
|
120 |
-
<td align="center">85
|
121 |
<td align="center">74.7</td>
|
122 |
<td align="center">82.4</td>
|
123 |
<td align="center">80.8</td>
|
124 |
</tr>
|
125 |
<tr>
|
126 |
-
<td align="center">VisCPM-Chat
|
127 |
-
<td align="center">CPMBee-10B</td>
|
128 |
<td align="center">83.3</td>
|
129 |
<td align="center">68.9</td>
|
130 |
<td align="center">90.5</td>
|
131 |
<td align="center">81.1</td>
|
132 |
-
<td align="center"
|
133 |
<td align="center">76.1</td>
|
134 |
<td align="center">89.2</td>
|
135 |
<td align="center">86.3</td>
|
136 |
</tr>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
<tr>
|
138 |
-
<td align="center">
|
139 |
-
<td align="center">
|
140 |
-
<td align="center">
|
141 |
-
<td align="center">
|
142 |
-
|
143 |
-
|
144 |
-
<td align="center">
|
145 |
-
<td align="center">
|
146 |
-
<td align="center"
|
147 |
-
<td align="center"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
</tr>
|
149 |
</table>
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
|
|
|
|
153 |
|
154 |
-
|
|
|
|
|
155 |
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VisCPM
|
2 |
+
็ฎไฝไธญๆ | [English](README_en.md)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
<p align="center">
|
5 |
+
<p align="left">
|
6 |
+
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
|
7 |
+
<a href=""><img src="https://img.shields.io/badge/python-3.8+-aff.svg"></a>
|
8 |
</p>
|
9 |
|
10 |
+
`VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. `VisCPM` is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (Q-Former) and visual decoder (Diffusion-UNet) to support visual inputs and outputs. Thanks to the good bilingual capability of CPM-Bee, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
|
11 |
|
12 |
+
`VisCPM`ๆฏไธไธชๅผๆบ็ๅคๆจกๆๅคงๆจกๅ็ณปๅ๏ผๆฏๆไธญ่ฑๅ่ฏญ็ๅคๆจกๆๅฏน่ฏ่ฝๅ๏ผ`VisCPM-Chat`ๆจกๅ๏ผๅๆๅฐๅพ็ๆ่ฝๅ๏ผ`VisCPM-Paint`ๆจกๅ๏ผ๏ผๅจไธญๆๅคๆจกๆๅผๆบๆจกๅไธญ่พพๅฐๆไฝณๆฐดๅนณใ`VisCPM`ๅบไบ็พไบฟๅๆฐ้่ฏญ่จๅคงๆจกๅ[CPM-Bee](https://github.com/OpenBMB/CPM-Bee)๏ผ10B๏ผ่ฎญ็ป๏ผ่ๅ่ง่ง็ผ็ ๅจ๏ผ`Q-Former`๏ผๅ่ง่ง่งฃ็ ๅจ๏ผ`Diffusion-UNet`๏ผไปฅๆฏๆ่ง่งไฟกๅท็่พๅ
ฅๅ่พๅบใๅพ็ไบ`CPM-Bee`ๅบๅบงไผ็ง็ๅ่ฏญ่ฝๅ๏ผ`VisCPM`ๅฏไปฅไป
้่ฟ่ฑๆๅคๆจกๆๆฐๆฎ้ข่ฎญ็ป๏ผๆณๅๅฎ็ฐไผ็ง็ไธญๆๅคๆจกๆ่ฝๅใ
|
|
|
|
|
|
|
|
|
13 |
|
14 |
## VisCPM-Chat
|
15 |
+
`VisCPM-Chat`ๆฏๆ้ขๅๅพๅ่ฟ่กไธญ่ฑๅ่ฏญๅคๆจกๆๅฏน่ฏใ่ฏฅๆจกๅไฝฟ็จ`Q-Former`ไฝไธบ่ง่ง็ผ็ ๅจ๏ผไฝฟ็จCPM-Bee๏ผ10B๏ผไฝไธบ่ฏญ่จไบคไบๅบๅบๆจกๅ๏ผๅนถ้่ฟ่ฏญ่จๅปบๆจก่ฎญ็ป็ฎๆ ่ๅ่ง่งๅ่ฏญ่จๆจกๅใๆจกๅ่ฎญ็ปๅ
ๆฌ้ข่ฎญ็ปๅๆไปค็ฒพ่ฐไธค้ถๆฎต๏ผ
|
16 |
|
17 |
+
* ้ข่ฎญ็ป๏ผๆไปฌไฝฟ็จ็บฆ100M้ซ่ดจ้่ฑๆๅพๆๅฏนๆฐๆฎๅฏน`VisCPM-Chat`่ฟ่กไบ้ข่ฎญ็ป๏ผๆฐๆฎๅ
ๆฌCC3MใCC12MใCOCOใVisual GenomeใLaion็ญใๅจ้ข่ฎญ็ป้ถๆฎต๏ผ่ฏญ่จๆจกๅๅๆฐไฟๆๅบๅฎ๏ผไป
ๆดๆฐ`Q-Former`้จๅๅๆฐ๏ผไปฅๆฏๆๅคง่งๆจก่ง่ง-่ฏญ่จ่กจ็คบ็้ซๆๅฏน้ฝใ
|
18 |
|
19 |
+
* ๆไปค็ฒพ่ฐ๏ผๆไปฌ้็จ[LLaVA-150K](https://llava-vl.github.io/)่ฑๆๆไปค็ฒพ่ฐๆฐๆฎ๏ผๅนถๆททๅ็ธๅบ็ฟป่ฏๅ็ไธญๆๆฐๆฎๅฏนๆจกๅ่ฟ่กๆไปค็ฒพ่ฐ๏ผไปฅๅฏน้ฝๆจกๅๅคๆจกๆๅบ็ก่ฝๅๅ็จๆทไฝฟ็จๆๅพใๅจๆไปค็ฒพ่ฐ้ถๆฎต๏ผๆไปฌๆดๆฐๅ
จ้จๆจกๅๅๆฐ๏ผไปฅๆๅๆไปค็ฒพ่ฐๆฐๆฎ็ๅฉ็จๆ็ใๆ่ถฃ็ๆฏ๏ผๆไปฌๅ็ฐๅณไฝฟไป
้็จ่ฑๆๆไปคๆฐๆฎ่ฟ่กๆไปค็ฒพ่ฐ๏ผๆจกๅไนๅฏไปฅ็่งฃไธญๆ้ฎ้ข๏ผไฝไป
่ฝ็จ่ฑๆๅ็ญใ่ฟ่กจๆๆจกๅ็ๅค่ฏญ่จๅคๆจกๆ่ฝๅๅทฒ็ปๅพๅฐ่ฏๅฅฝ็ๆณๅใๅจๆไปค็ฒพ่ฐ้ถๆฎต่ฟไธๆญฅๅ ๅ
ฅๅฐ้ไธญๆ็ฟป่ฏๆฐๆฎ๏ผๅฏไปฅๅฐๆจกๅๅๅค่ฏญ่จๅ็จๆท้ฎ้ข่ฏญ่จๅฏน้ฝใ
|
20 |
|
21 |
+
ๆไปฌๅจLLaVA่ฑๆๆต่ฏ้ๅ็ฟป่ฏ็ไธญๆๆต่ฏ้ๅฏนๆจกๅ่ฟ่กไบ่ฏๆต๏ผ่ฏฅ่ฏๆตๅบๅ่ๅฏๆจกๅๅจๅผๆพๅๅฏน่ฏใๅพๅ็ป่ๆ่ฟฐใๅคๆๆจ็ๆน้ข็่กจ็ฐ๏ผๅนถไฝฟ็จGPT-4่ฟ่กๆๅใๅฏไปฅ่งๅฏๅฐ๏ผ`VisCPM-Chat`ๅจไธญๆๅคๆจกๆ่ฝๅๆน้ขๅๅพไบๆไฝณ็ๅนณๅๆง่ฝ๏ผๅจ้็จๅๅฏน่ฏๅๅคๆๆจ็่กจ็ฐๅบ่ฒ๏ผๅๆถไน่กจ็ฐๅบไบไธ้็่ฑๆๅคๆจกๆ่ฝๅใ
|
22 |
|
23 |
<table>
|
24 |
<tr>
|
25 |
+
<td align="center" rowspan="2" colspan="2">ๆจกๅ</td>
|
26 |
+
<td align="center" colspan="4">่ฑๆ</td>
|
27 |
+
<td align="center" colspan="4">ไธญๆ</td>
|
|
|
28 |
</tr>
|
29 |
<tr>
|
30 |
+
<td align="center">ๅคๆจกๆๅฏน่ฏ</td>
|
31 |
+
<td align="center">็ป่ๆ่ฟฐ</td>
|
32 |
+
<td align="center">ๅคๆๆจ็</td>
|
33 |
+
<td align="center">ๅนณๅ</td>
|
34 |
+
<td align="center">ๅคๆจกๆๅฏน่ฏ</td>
|
35 |
+
<td align="center">็ป่ๆ่ฟฐ</td>
|
36 |
+
<td align="center">ๅคๆๆจ็</td>
|
37 |
+
<td align="center">ๅนณๅ</td>
|
38 |
</tr>
|
39 |
<tr>
|
40 |
+
<td align="center" rowspan="3">่ฑๆๆจกๅ</td>
|
41 |
<td align="center">MiniGPT4</td>
|
42 |
+
<td align="center">65</td>
|
|
|
43 |
<td align="center">67.3</td>
|
44 |
<td align="center">76.6</td>
|
45 |
<td align="center">69.7</td>
|
|
|
50 |
</tr>
|
51 |
<tr>
|
52 |
<td align="center">InstructBLIP</td>
|
|
|
53 |
<td align="center">81.9</td>
|
54 |
+
<td align="center">68</td>
|
55 |
<td align="center">91.2</td>
|
56 |
<td align="center">80.5</td>
|
57 |
<td align="center">-</td>
|
|
|
61 |
</tr>
|
62 |
<tr>
|
63 |
<td align="center">LLaVA</td>
|
64 |
+
<td align="center">89.5</td>
|
65 |
+
<td align="center">70.4</td>
|
66 |
+
<td align="center">96.2</td>
|
67 |
+
<td align="center">85.6</td>
|
|
|
68 |
<td align="center">-</td>
|
69 |
<td align="center">-</td>
|
70 |
<td align="center">-</td>
|
71 |
<td align="center">-</td>
|
72 |
</tr>
|
73 |
<tr>
|
74 |
+
<td align="center" rowspan="4">ไธญ่ฑๅ่ฏญ</td>
|
75 |
<td align="center">mPLUG-Owl </td>
|
|
|
76 |
<td align="center">64.6</td>
|
77 |
<td align="center">47.7</td>
|
78 |
<td align="center">80.1</td>
|
|
|
80 |
<td align="center">76.3</td>
|
81 |
<td align="center">61.2</td>
|
82 |
<td align="center">77.8</td>
|
83 |
+
<td align="center">72</td>
|
84 |
</tr>
|
85 |
<tr>
|
86 |
<td align="center">VisualGLM</td>
|
|
|
87 |
<td align="center">62.4</td>
|
88 |
+
<td align="center">63</td>
|
89 |
<td align="center">80.6</td>
|
90 |
<td align="center">68.7</td>
|
91 |
<td align="center">76.6</td>
|
92 |
+
<td align="center">87.8</td>
|
93 |
<td align="center">83.6</td>
|
94 |
<td align="center">82.7</td>
|
95 |
</tr>
|
96 |
<tr>
|
97 |
+
<td align="center">Ziya (LLaMA 13B)</td>
|
|
|
98 |
<td align="center">82.7</td>
|
99 |
<td align="center">69.9</td>
|
100 |
<td align="center">92.1</td>
|
101 |
<td align="center">81.7</td>
|
102 |
+
<td align="center">85</td>
|
103 |
<td align="center">74.7</td>
|
104 |
<td align="center">82.4</td>
|
105 |
<td align="center">80.8</td>
|
106 |
</tr>
|
107 |
<tr>
|
108 |
+
<td align="center">VisCPM-Chat</td>
|
|
|
109 |
<td align="center">83.3</td>
|
110 |
<td align="center">68.9</td>
|
111 |
<td align="center">90.5</td>
|
112 |
<td align="center">81.1</td>
|
113 |
+
<td align="center">92.7</td>
|
114 |
<td align="center">76.1</td>
|
115 |
<td align="center">89.2</td>
|
116 |
<td align="center">86.3</td>
|
117 |
</tr>
|
118 |
+
</table>
|
119 |
+
|
120 |
+
## VisCPM-Paint
|
121 |
+
`VisCPM-Paint`ๆฏๆไธญ่ฑๅ่ฏญ็ๆๅฐๅพ็ๆใ่ฏฅๆจกๅไฝฟ็จCPM-Bee๏ผ10B๏ผไฝไธบๆๆฌ็ผ็ ๅจ๏ผไฝฟ็จ`UNet`ไฝไธบๅพๅ่งฃ็ ๅจ๏ผๅนถ้่ฟๆฉๆฃๆจกๅ่ฎญ็ป็ฎๆ ่ๅ่ฏญ่จๅ่ง่งๆจกๅใๅจ่ฎญ็ป่ฟ็จไธญ๏ผ่ฏญ่จๆจกๅๅๆฐๅง็ปไฟๆๅบๅฎใๆไปฌไฝฟ็จ[Stable Diffusion 2.1](https://github.com/Stability-AI/stablediffusion)็UNetๅๆฐๅๅงๅ่ง่ง่งฃ็ ๅจ๏ผๅนถ้่ฟ้ๆญฅ่งฃๅปๅ
ถไธญๅ
ณ้ฎ็ๆกฅๆฅๅๆฐๅฐๅ
ถไธ่ฏญ่จๆจกๅ่ๅ๏ผ้ฆๅ
่ฎญ็ปๆๆฌ่กจ็คบๆ ๅฐๅฐ่ง่งๆจกๅ็็บฟๆงๅฑ๏ผ็ถๅ่ฟไธๆญฅ่งฃๅป`UNet`็ไบคๅๆณจๆๅๅฑใ่ฏฅๆจกๅๅจ[LAION 2B](https://laion.ai/)่ฑๆๅพๆๅฏนๆฐๆฎไธ่ฟ่กไบ่ฎญ็ปใ
|
122 |
+
|
123 |
+
ไธ`VisCPM-Chat`็ฑปไผผ๏ผๆไปฌๅ็ฐๅพ็ไบCPM-Bee็ๅ่ฏญ่ฝๅ๏ผ`VisCPM-Paint`ๅฏไปฅไป
้่ฟ่ฑๆๅพๆๅฏน่ฎญ็ป๏ผๆณๅๅฎ็ฐ่ฏๅฅฝ็ไธญๆๆๅฐๅพ็ๆ่ฝๅ๏ผ่พพๅฐไธญๆๅผๆบๆจกๅ็ๆไฝณๆๆใ้่ฟ่ฟไธๆญฅๅ ๅ
ฅ20Mๆธ
ๆดๅ็ๅ็ไธญๆๅพๆๅฏนๆฐๆฎ๏ผไปฅๅ120M็ฟป่ฏๅฐไธญๆ็ๅพๆๅฏนๆฐๆฎ๏ผๆจกๅ็ไธญๆๆๅฐๅพ็ๆ่ฝๅๅฏไปฅ่ทๅพ่ฟไธๆญฅๆๅใๆไปฌๅจMSCOCOไธ้ๆ ทไบ3ไธๅผ ๅพ็๏ผ่ฎก็ฎไบFID(Frรฉchet Inception Distance)ๅClip Score๏ผๅ่
็จไบ่ฏไผฐ็ๆๅพ็็่ดจ้๏ผๅ้ข็จไบ่ฏไผฐ็ๆ็ๅพ็ไธ่พๅ
ฅ็ๅน้
็จๅบฆใ
|
124 |
+
|
125 |
+
<table>
|
126 |
+
<tr>
|
127 |
+
<td align="center" rowspan="2">ๆจกๅ</td>
|
128 |
+
<td align="center" colspan="2">่ฑๆ</td>
|
129 |
+
<td align="center" colspan="2">ไธญๆ</td>
|
130 |
+
</tr>
|
131 |
<tr>
|
132 |
+
<td align="center">FIDโ</td>
|
133 |
+
<td align="center">CLIP Scoreโ</td>
|
134 |
+
<td align="center">FIDโ</td>
|
135 |
+
<td align="center">CLIP Scoreโ</td>
|
136 |
+
</tr>
|
137 |
+
<tr>
|
138 |
+
<td align="center">AltDiffusion</td>
|
139 |
+
<td align="center">17.16</td>
|
140 |
+
<td align="center">25.24</td>
|
141 |
+
<td align="center">16.09</td>
|
142 |
+
<td align="center">24.05</td>
|
143 |
+
</tr>
|
144 |
+
<tr>
|
145 |
+
<td align="center">TaiyiDiffusion</td>
|
146 |
+
<td align="center">-</td>
|
147 |
+
<td align="center">-</td>
|
148 |
+
<td align="center">15.58</td>
|
149 |
+
<td align="center">22.69</td>
|
150 |
+
</tr>
|
151 |
+
<tr>
|
152 |
+
<td align="center">Stable Diffusion</td>
|
153 |
+
<td align="center">9.08</td>
|
154 |
+
<td align="center">26.22</td>
|
155 |
+
<td align="center">-</td>
|
156 |
+
<td align="center">-</td>
|
157 |
+
</tr>
|
158 |
+
<tr>
|
159 |
+
<td align="center">VisCPM-Paint-en</td>
|
160 |
+
<td align="center">9.51</td>
|
161 |
+
<td align="center">25.35</td>
|
162 |
+
<td align="center">10.86</td>
|
163 |
+
<td align="center">23.38</td>
|
164 |
+
</tr>
|
165 |
+
<tr>
|
166 |
+
<td align="center">VisCPM-Paint-zh</td>
|
167 |
+
<td align="center">9.98</td>
|
168 |
+
<td align="center">25.04</td>
|
169 |
+
<td align="center">9.65</td>
|
170 |
+
<td align="center">24.17</td>
|
171 |
</tr>
|
172 |
</table>
|
173 |
|
174 |
+
# ๅฎ่ฃ
|
175 |
+
|
176 |
+
```Shell
|
177 |
+
conda create -n viscpm python=3.10 -y
|
178 |
+
conda activate viscpm
|
179 |
+
pip install setuptools
|
180 |
+
pip install diffusers jieba matplotlib numpy opencv_python
|
181 |
+
pip install pandas Pillow psutil pydantic scipy
|
182 |
+
pip install torch==1.13.1 torchscale==0.2.0 torchvision==0.14.1 timm
|
183 |
+
pip install transformers==4.28.0
|
184 |
+
pip install tqdm typing_extensions
|
185 |
+
pip install git+https://github.com/thunlp/OpenDelta.git
|
186 |
+
pip install git+https://github.com/OpenBMB/CPM-Bee.git#egg=cpm-live&subdirectory=src
|
187 |
+
```
|
188 |
+
|
189 |
+
VisCPM้่ฆๅๅก40GBไปฅไธ็GPU่ฟ่ก๏ผๆไปฌไผๅจๅฐฝๅฟซๆดๆฐๆดๅ ่็ๆพๅญ็ๆจ็ๆนๅผใ
|
190 |
+
|
191 |
+
## ไฝฟ็จ
|
192 |
|
193 |
+
```python
|
194 |
+
>>> from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
|
195 |
+
>>> from PIL import Image
|
196 |
|
197 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('viscpm', trust_remote_code=True)
|
198 |
+
>>> processor = AutoImageProcessor.from_pretrained('viscpm', trust_remote_code=True)
|
199 |
+
>>> model = AutoModel.from_pretrained('viscpm', trust_remote_code=True).to('cuda')
|
200 |
|
201 |
+
>>> data = [{
|
202 |
+
>>> 'context': '',
|
203 |
+
>>> 'question': 'describe this image in detail.',
|
204 |
+
>>> 'image': tokenizer.unk_token * model.query_num,
|
205 |
+
>>> '<ans>': ''
|
206 |
+
>>> }]
|
207 |
+
>>> image = Image.open('case.jpg')
|
208 |
+
>>> result = model.generate(data, tokenizer, processor, image)
|
209 |
+
>>> print(result[0]['<ans>'])
|
210 |
+
่ฟๅน
ๅพ็ๆพ็คบไบไธ็พค็ญๆฐ็ๅจๅคฉ็ฉบไธญ้ฃ่กใ่ฟไบ็ญๆฐ็ๆผๆตฎๅจไธๅ็ๅฐๆน๏ผๅ
ๆฌๅฑฑ่ใๅๅธๅไนกๆๅฐๅบใ
|
211 |
+
```
|
README_en.md
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VisCPM
|
2 |
+
[็ฎไฝไธญๆ](README.md) | English
|
3 |
+
|
4 |
+
<p align="center">
|
5 |
+
<p align="left">
|
6 |
+
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
|
7 |
+
<a href=""><img src="https://img.shields.io/badge/python-3.8+-aff.svg"></a>
|
8 |
+
</p>
|
9 |
+
|
10 |
+
`VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. `VisCPM` is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (`Q-Former`) and visual decoder (`Diffusion-UNet`) to support visual inputs and outputs. Thanks to the good bilingual capability of `CPM-Bee`, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
|
11 |
+
|
12 |
+
## VisCPM-Chat
|
13 |
+
`VisCPM-Chat` supports bilingual multimodal conversations involving images in both Chinese and English. The model utilizes `Q-Former` as the visual encoder and CPM-Bee (10B) as the base LLM. It combines visual and language models through language modeling training objectives. The model training consists of two stages: pretraining and instruction fine-tuning.
|
14 |
+
|
15 |
+
* Pretrain: `VisCPM-Chat` was pretrained using approximately 100 million high-quality English multimodal data pairs. The data sources include CC3M, CC12M, COCO, Visual Genome, Laion, and others. In this stage, the language model parameters remain fixed, and only the parameters of the `Q-Former` are updated to enable efficient alignment of large-scale visual-language representations.
|
16 |
+
|
17 |
+
* Instruction fine-tuning: We utilized the [LLaVA-150K](https://llava-vl.github.io/) dataset, which consists of English multimodal instruction-following dataset. We mixed this data with corresponding translated Chinese data to fine-tune the model and align its multimodal capabilities with user intents. In this phase, we updated all model parameters to improve the utilization efficiency of the instruction fine-tuning data. Interestingly, we observed that even when using only English instruction data for fine-tuning, the model can comprehend Chinese questions but can only respond in English. This indicates that the model has achieved good generalization in terms of its multilingual and multimodal capabilities. By incorporating a small amount of translated Chinese data during the instruction fine-tuning phase, we can align the model's response language with the user's question language.
|
18 |
+
|
19 |
+
We evaluated the model on the LLaVA English test set and the translated Chinese test set. The evaluation benchmark examined the model's performance in open-domain conversations, image detail descriptions, and complex reasoning tasks, using GPT-4 for scoring. It is evident that `VisCPM-Chat` achieved the best average performance in Chinese multimodal capabilities, excelling in general-domain conversations and complex reasoning. It also demonstrated commendable English multimodal abilities.
|
20 |
+
|
21 |
+
<table>
|
22 |
+
<tr>
|
23 |
+
<td align="center" rowspan="2" colspan="2">Model</td>
|
24 |
+
<td align="center" colspan="4">English</td>
|
25 |
+
<td align="center" colspan="4">Chinese</td>
|
26 |
+
</tr>
|
27 |
+
<tr>
|
28 |
+
<td align="center">Conversation</td>
|
29 |
+
<td align="center">Detailed Description</td>
|
30 |
+
<td align="center">Complex Reasoning</td>
|
31 |
+
<td align="center">All</td>
|
32 |
+
<td align="center">Conversation</td>
|
33 |
+
<td align="center">Detailed Description</td>
|
34 |
+
<td align="center">Complex Reasoning</td>
|
35 |
+
<td align="center">All</td>
|
36 |
+
</tr>
|
37 |
+
<tr>
|
38 |
+
<td align="center" rowspan="3">English Model</td>
|
39 |
+
<td align="center">MiniGPT4</td>
|
40 |
+
<td align="center">65</td>
|
41 |
+
<td align="center">67.3</td>
|
42 |
+
<td align="center">76.6</td>
|
43 |
+
<td align="center">69.7</td>
|
44 |
+
<td align="center">-</td>
|
45 |
+
<td align="center">-</td>
|
46 |
+
<td align="center">-</td>
|
47 |
+
<td align="center">-</td>
|
48 |
+
</tr>
|
49 |
+
<tr>
|
50 |
+
<td align="center">InstructBLIP</td>
|
51 |
+
<td align="center">81.9</td>
|
52 |
+
<td align="center">68</td>
|
53 |
+
<td align="center">91.2</td>
|
54 |
+
<td align="center">80.5</td>
|
55 |
+
<td align="center">-</td>
|
56 |
+
<td align="center">-</td>
|
57 |
+
<td align="center">-</td>
|
58 |
+
<td align="center">-</td>
|
59 |
+
</tr>
|
60 |
+
<tr>
|
61 |
+
<td align="center">LLaVA</td>
|
62 |
+
<td align="center">89.5</td>
|
63 |
+
<td align="center">70.4</td>
|
64 |
+
<td align="center">96.2</td>
|
65 |
+
<td align="center">85.6</td>
|
66 |
+
<td align="center">-</td>
|
67 |
+
<td align="center">-</td>
|
68 |
+
<td align="center">-</td>
|
69 |
+
<td align="center">-</td>
|
70 |
+
</tr>
|
71 |
+
<tr>
|
72 |
+
<td align="center" rowspan="4">En-Zh Bilingual Model</td>
|
73 |
+
<td align="center">mPLUG-Owl </td>
|
74 |
+
<td align="center">64.6</td>
|
75 |
+
<td align="center">47.7</td>
|
76 |
+
<td align="center">80.1</td>
|
77 |
+
<td align="center">64.2</td>
|
78 |
+
<td align="center">76.3</td>
|
79 |
+
<td align="center">61.2</td>
|
80 |
+
<td align="center">77.8</td>
|
81 |
+
<td align="center">72</td>
|
82 |
+
</tr>
|
83 |
+
<tr>
|
84 |
+
<td align="center">VisualGLM</td>
|
85 |
+
<td align="center">62.4</td>
|
86 |
+
<td align="center">63</td>
|
87 |
+
<td align="center">80.6</td>
|
88 |
+
<td align="center">68.7</td>
|
89 |
+
<td align="center">76.6</td>
|
90 |
+
<td align="center">87.8</td>
|
91 |
+
<td align="center">83.6</td>
|
92 |
+
<td align="center">82.7</td>
|
93 |
+
</tr>
|
94 |
+
<tr>
|
95 |
+
<td align="center">Ziya (LLaMA 13B)</td>
|
96 |
+
<td align="center">82.7</td>
|
97 |
+
<td align="center">69.9</td>
|
98 |
+
<td align="center">92.1</td>
|
99 |
+
<td align="center">81.7</td>
|
100 |
+
<td align="center">85</td>
|
101 |
+
<td align="center">74.7</td>
|
102 |
+
<td align="center">82.4</td>
|
103 |
+
<td align="center">80.8</td>
|
104 |
+
</tr>
|
105 |
+
<tr>
|
106 |
+
<td align="center">VisCPM-Chat</td>
|
107 |
+
<td align="center">83.3</td>
|
108 |
+
<td align="center">68.9</td>
|
109 |
+
<td align="center">90.5</td>
|
110 |
+
<td align="center">81.1</td>
|
111 |
+
<td align="center">92.7</td>
|
112 |
+
<td align="center">76.1</td>
|
113 |
+
<td align="center">89.2</td>
|
114 |
+
<td align="center">86.3</td>
|
115 |
+
</tr>
|
116 |
+
</table>
|
117 |
+
|
118 |
+
# Install
|
119 |
+
|
120 |
+
1. Clone this repository and navigate to source folder
|
121 |
+
```bash
|
122 |
+
git clone <github repo URL>
|
123 |
+
cd viscpm
|
124 |
+
```
|
125 |
+
|
126 |
+
2. Install Package
|
127 |
+
```Shell
|
128 |
+
conda create -n viscpm python=3.10 -y
|
129 |
+
conda activate viscpm
|
130 |
+
pip install setuptools
|
131 |
+
pip install diffusers jieba matplotlib numpy opencv_python
|
132 |
+
pip install pandas Pillow psutil pydantic scipy
|
133 |
+
pip install torch==1.13.1 torchscale==0.2.0 torchvision==0.14.1 timm
|
134 |
+
pip install transformers==4.28.0
|
135 |
+
pip install tqdm typing_extensions
|
136 |
+
pip install git+https://github.com/thunlp/OpenDelta.git
|
137 |
+
pip install git+https://github.com/OpenBMB/CPM-Bee.git#egg=cpm-live&subdirectory=src
|
138 |
+
```
|
139 |
+
|
140 |
+
`VisCPM` require GPUs with more than 40GB memory. We will soon update more memory-friendly inference methods.
|
141 |
+
|
142 |
+
## How to use
|
143 |
+
|
144 |
+
```python
|
145 |
+
>>> from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
|
146 |
+
>>> from PIL import Image
|
147 |
+
|
148 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('viscpm', trust_remote_code=True)
|
149 |
+
>>> processor = AutoImageProcessor.from_pretrained('viscpm', trust_remote_code=True)
|
150 |
+
>>> model = AutoModel.from_pretrained('viscpm', trust_remote_code=True).to('cuda')
|
151 |
+
|
152 |
+
>>> data = [{
|
153 |
+
>>> 'context': '',
|
154 |
+
>>> 'question': 'describe this image in detail.',
|
155 |
+
>>> 'image': tokenizer.unk_token * model.query_num,
|
156 |
+
>>> '<ans>': ''
|
157 |
+
>>> }]
|
158 |
+
>>> image = Image.open('case.jpg')
|
159 |
+
>>> result = model.generate(data, tokenizer, processor, image)
|
160 |
+
>>> print(result[0]['<ans>'])
|
161 |
+
่ฟๅน
ๅพ็ๆพ็คบไบไธ็พค็ญๆฐ็ๅจๅคฉ็ฉบไธญ้ฃ่กใ่ฟไบ็ญๆฐ็ๆผๆตฎๅจไธๅ็ๅฐๆน๏ผๅ
ๆฌๅฑฑ่ใๅๅธๅไนกๆๅฐๅบใ
|
162 |
+
```
|
beit3.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
|
4 |
+
# Copyright (c) 2023 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# --------------------------------------------------------'
|
7 |
+
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
|
12 |
+
from timm.models.registry import register_model
|
13 |
+
|
14 |
+
from torchscale.model.BEiT3 import BEiT3
|
15 |
+
from torchscale.architecture.config import EncoderConfig
|
16 |
+
|
17 |
+
|
18 |
+
def trunc_normal_(tensor, mean=0., std=1.):
|
19 |
+
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
|
20 |
+
|
21 |
+
|
22 |
+
def _get_base_config(
|
23 |
+
img_size=224, patch_size=16, drop_path_rate=0,
|
24 |
+
checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
|
25 |
+
):
|
26 |
+
return EncoderConfig(
|
27 |
+
img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
|
28 |
+
layernorm_embedding=False, normalize_output=True, no_output_layer=True,
|
29 |
+
drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12,
|
30 |
+
encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12,
|
31 |
+
checkpoint_activations=checkpoint_activations,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def _get_large_config(
|
36 |
+
img_size=224, patch_size=16, drop_path_rate=0,
|
37 |
+
checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
|
38 |
+
):
|
39 |
+
return EncoderConfig(
|
40 |
+
img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
|
41 |
+
layernorm_embedding=False, normalize_output=True, no_output_layer=True,
|
42 |
+
drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16,
|
43 |
+
encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24,
|
44 |
+
checkpoint_activations=checkpoint_activations,
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
class BEiT3Wrapper(nn.Module):
|
49 |
+
def __init__(self, args, **kwargs):
|
50 |
+
super().__init__()
|
51 |
+
self.args = args
|
52 |
+
self.beit3 = BEiT3(args)
|
53 |
+
self.apply(self._init_weights)
|
54 |
+
self.mim_head = nn.Linear(1024, 8192)
|
55 |
+
self.num_img_patches = self.beit3.vision_embed.num_position_embeddings()
|
56 |
+
self.hidden_size = args.encoder_embed_dim
|
57 |
+
|
58 |
+
def fix_init_weight(self):
|
59 |
+
def rescale(param, layer_id):
|
60 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
61 |
+
|
62 |
+
for layer_id, layer in enumerate(self.blocks):
|
63 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
64 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
65 |
+
|
66 |
+
def get_num_layers(self):
|
67 |
+
return self.beit3.encoder.num_layers
|
68 |
+
|
69 |
+
@torch.jit.ignore
|
70 |
+
def no_weight_decay(self):
|
71 |
+
return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
|
72 |
+
|
73 |
+
def _init_weights(self, m):
|
74 |
+
if isinstance(m, nn.Linear):
|
75 |
+
trunc_normal_(m.weight, std=.02)
|
76 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
77 |
+
nn.init.constant_(m.bias, 0)
|
78 |
+
elif isinstance(m, nn.LayerNorm):
|
79 |
+
nn.init.constant_(m.bias, 0)
|
80 |
+
nn.init.constant_(m.weight, 1.0)
|
81 |
+
|
82 |
+
def forward(self, pixel_values, query_embed=None):
|
83 |
+
B = pixel_values.size(0)
|
84 |
+
dtype = self.beit3.vision_embed.proj.weight.dtype
|
85 |
+
pixel_values = pixel_values.to(dtype)
|
86 |
+
token_embeddings = self.beit3.vision_embed(pixel_values)
|
87 |
+
multiway_split_position = -1
|
88 |
+
if query_embed is not None:
|
89 |
+
query_embed = torch.stack([query_embed] * B)
|
90 |
+
multiway_split_position = token_embeddings.size(1)
|
91 |
+
token_embeddings = torch.cat([token_embeddings, query_embed], dim=1)
|
92 |
+
|
93 |
+
outputs = self.beit3.encoder(
|
94 |
+
src_tokens=None,
|
95 |
+
token_embeddings=token_embeddings,
|
96 |
+
multiway_split_position=multiway_split_position
|
97 |
+
)
|
98 |
+
vision_hidden_states = outputs["encoder_out"]
|
99 |
+
if query_embed is not None:
|
100 |
+
vision_hidden_states = vision_hidden_states[:, self.num_img_patches:]
|
101 |
+
return vision_hidden_states
|
102 |
+
|
103 |
+
|
104 |
+
@register_model
|
105 |
+
def beit3_large_patch16_224(pretrained=False, **kwargs):
|
106 |
+
args = _get_large_config(img_size=224, **kwargs)
|
107 |
+
model = BEiT3Wrapper(args, **kwargs)
|
108 |
+
return model
|
config.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"_name_or_path": "openbmb/viscpmchat-bee-10b",
|
4 |
+
"architectures": [
|
5 |
+
"VisCpmBeeForCausalLM"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_viscpmchatbee.VisCpmChatBeeConfig",
|
9 |
+
"AutoModel": "modeling_cpmbee.VisCpmBeeForCausalLM",
|
10 |
+
"AutoModelForCausalLM": "modeling_cpmbee.VisCpmBeeForCausalLM"
|
11 |
+
},
|
12 |
+
"vocab_size": 86583,
|
13 |
+
"hidden_size": 4096,
|
14 |
+
"dim_ff" : 10240,
|
15 |
+
"num_hidden_layers" : 48,
|
16 |
+
"num_attention_heads": 32,
|
17 |
+
"dim_head" : 128,
|
18 |
+
"dropout_p" : 0.0,
|
19 |
+
"position_bias_num_buckets" : 256,
|
20 |
+
"position_bias_num_segment_buckets": 256,
|
21 |
+
"position_bias_max_distance" : 2048,
|
22 |
+
"vision_dim": 1024,
|
23 |
+
"query_num": 64,
|
24 |
+
"eps" : 1e-6,
|
25 |
+
"half" : true,
|
26 |
+
"model_type": "viscpmchatbee"
|
27 |
+
}
|
configuration_viscpmchatbee.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" CpmBee model configuration"""
|
16 |
+
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
26 |
+
"openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/resolve/main/config.json",
|
27 |
+
# See all VisCpmBee models at https://huggingface.co/models?filter=viscpmbee
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
class VisCpmChatBeeConfig(PretrainedConfig):
|
32 |
+
r"""
|
33 |
+
This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instbeeiate an
|
34 |
+
CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
35 |
+
with the defaults will yield a similar configuration to that of the CPMBee
|
36 |
+
[openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.
|
37 |
+
|
38 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
39 |
+
documentation from [`PretrainedConfig`] for more information.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 30720):
|
43 |
+
Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
|
44 |
+
`input` passed when calling [`CpmBeeModel`].
|
45 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
46 |
+
Dimension of the encoder layers.
|
47 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
48 |
+
Number of attention heads in the Transformer encoder.
|
49 |
+
dim_head (`int`, *optional*, defaults to 128):
|
50 |
+
Dimension of attention heads for each attention layer in the Transformer encoder.
|
51 |
+
dim_ff (`int`, *optional*, defaults to 10240):
|
52 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
53 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
54 |
+
Number of layers of the Transformer encoder.
|
55 |
+
dropout_p (`float`, *optional*, defaults to 0.1):
|
56 |
+
The dropout probabilitiy for all fully connected layers in the embeddings, encoder.
|
57 |
+
position_bias_num_buckets (`int`, *optional*, defaults to 512):
|
58 |
+
The number of position_bias buckets.
|
59 |
+
position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
|
60 |
+
The number of segment buckets.
|
61 |
+
position_bias_max_distance (`int`, *optional*, defaults to 2048):
|
62 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
63 |
+
just in case (e.g., 512 or 1024 or 2048).
|
64 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
65 |
+
The epsilon used by the layer normalization layers.
|
66 |
+
init_std (`float`, *optional*, defaults to 1.0):
|
67 |
+
Initialize parameters with std = init_std.
|
68 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
69 |
+
Whether to use cache.
|
70 |
+
distance_scale (`float` or `int`, *optional*, defaults to 16):
|
71 |
+
Scale the rotary embedding.
|
72 |
+
mask_modules (`list` or `tuple`, *optional*, defaults to None):
|
73 |
+
Decides which feedforward block or attention block is pruned.
|
74 |
+
half (`bool`, *optional*, defaults to `False`):
|
75 |
+
Decides the model parameters are half-precision or not.
|
76 |
+
|
77 |
+
Example:
|
78 |
+
|
79 |
+
```python
|
80 |
+
>>> from transformers import CpmBeeModel, CpmBeeConfig
|
81 |
+
|
82 |
+
>>> # Initializing a CPMBee cpm-bee-10b style configuration
|
83 |
+
>>> configuration = CpmBeeConfig()
|
84 |
+
|
85 |
+
>>> # Initializing a model from the cpm-bee-10b style configuration
|
86 |
+
>>> model = CpmBeeModel(configuration)
|
87 |
+
|
88 |
+
>>> # Accessing the model configuration
|
89 |
+
>>> configuration = model.config
|
90 |
+
```"""
|
91 |
+
model_type = "viscpmchatbee"
|
92 |
+
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
vocab_size: int = 30720,
|
96 |
+
hidden_size: int = 4096,
|
97 |
+
num_attention_heads: int = 64,
|
98 |
+
dim_head: int = 64,
|
99 |
+
dim_ff: int = 10240,
|
100 |
+
num_hidden_layers: int = 32,
|
101 |
+
dropout_p: int = 0.0,
|
102 |
+
position_bias_num_buckets: int = 256,
|
103 |
+
position_bias_num_segment_buckets: int = 32,
|
104 |
+
position_bias_max_distance: int = 2048,
|
105 |
+
eps: int = 1e-6,
|
106 |
+
init_std: float = 1.0,
|
107 |
+
use_cache: bool = True,
|
108 |
+
distance_scale: Union[int, float] = 16,
|
109 |
+
mask_modules: Optional[Union[List, Tuple]] = None,
|
110 |
+
half: bool = False,
|
111 |
+
vision_dim: int = 1024,
|
112 |
+
query_num: int = 64,
|
113 |
+
**kwargs,
|
114 |
+
):
|
115 |
+
super().__init__(**kwargs)
|
116 |
+
self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
|
117 |
+
self.hidden_size = hidden_size
|
118 |
+
self.num_attention_heads = num_attention_heads
|
119 |
+
self.dim_head = dim_head
|
120 |
+
self.dim_ff = dim_ff
|
121 |
+
self.num_hidden_layers = num_hidden_layers
|
122 |
+
self.position_bias_num_buckets = position_bias_num_buckets
|
123 |
+
self.position_bias_max_distance = position_bias_max_distance
|
124 |
+
self.dropout_p = dropout_p
|
125 |
+
self.eps = eps
|
126 |
+
self.use_cache = use_cache
|
127 |
+
self.vocab_size = vocab_size
|
128 |
+
self.init_std = init_std
|
129 |
+
self.distance_scale = distance_scale
|
130 |
+
self.half = half
|
131 |
+
self.mask_modules = mask_modules
|
132 |
+
self.vision_dim = vision_dim
|
133 |
+
self.query_num = query_num
|
feature_extraction_viscpmchatbee.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
from transformers.utils import logging
|
4 |
+
from processing_viscpmchatbee import VisCpmChatBeeImageProcessor
|
5 |
+
|
6 |
+
|
7 |
+
logger = logging.get_logger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
class VisCpmChatBeeFeatureExtractor(VisCpmChatBeeImageProcessor):
|
11 |
+
def __init__(self, *args, **kwargs) -> None:
|
12 |
+
warnings.warn(
|
13 |
+
"The class VisCpmBeeFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
|
14 |
+
" use CLIPImageProcessor instead.",
|
15 |
+
FutureWarning,
|
16 |
+
)
|
17 |
+
super().__init__(*args, **kwargs)
|
generation_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"num_beams": 3,
|
3 |
+
"num_beam_groups": 1,
|
4 |
+
"do_sample": false,
|
5 |
+
"is_constraint_gen_mode": false,
|
6 |
+
"is_contrastive_search_gen_mode": false,
|
7 |
+
"pad_token_id": 0,
|
8 |
+
"eos_token_id": 7,
|
9 |
+
"bos_token_id": 6,
|
10 |
+
"max_new_tokens": 100,
|
11 |
+
"vocab_size": 86583
|
12 |
+
}
|
modeling_cpmbee.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
preprocessor_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"image_processor_type": "VisCpmChatBeeImageProcessor",
|
3 |
+
"is_train": false,
|
4 |
+
"randaug": false,
|
5 |
+
"input_size": 224,
|
6 |
+
"interpolation": "bicubic",
|
7 |
+
"auto_map": {
|
8 |
+
"AutoImageProcessor": "processing_viscpmchatbee.VisCpmChatBeeImageProcessor"
|
9 |
+
}
|
10 |
+
}
|
processing_viscpmchatbee.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
5 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
6 |
+
from torchvision import transforms
|
7 |
+
import urllib
|
8 |
+
from tqdm import tqdm
|
9 |
+
from cpm_live.tokenizers import CPMBeeTokenizer
|
10 |
+
from torch.utils.data import default_collate
|
11 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
12 |
+
from typing_extensions import TypedDict
|
13 |
+
from numpy.typing import NDArray
|
14 |
+
import importlib.machinery
|
15 |
+
import importlib.util
|
16 |
+
import types
|
17 |
+
import random
|
18 |
+
from transformers.image_utils import make_list_of_images
|
19 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
20 |
+
from transformers import TensorType
|
21 |
+
import json
|
22 |
+
|
23 |
+
|
24 |
+
# aug functions
|
25 |
+
def identity_func(img):
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
def autocontrast_func(img, cutoff=0):
|
30 |
+
'''
|
31 |
+
same output as PIL.ImageOps.autocontrast
|
32 |
+
'''
|
33 |
+
n_bins = 256
|
34 |
+
|
35 |
+
def tune_channel(ch):
|
36 |
+
n = ch.size
|
37 |
+
cut = cutoff * n // 100
|
38 |
+
if cut == 0:
|
39 |
+
high, low = ch.max(), ch.min()
|
40 |
+
else:
|
41 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
42 |
+
low = np.argwhere(np.cumsum(hist) > cut)
|
43 |
+
low = 0 if low.shape[0] == 0 else low[0]
|
44 |
+
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
45 |
+
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
46 |
+
if high <= low:
|
47 |
+
table = np.arange(n_bins)
|
48 |
+
else:
|
49 |
+
scale = (n_bins - 1) / (high - low)
|
50 |
+
table = np.arange(n_bins) * scale - low * scale
|
51 |
+
table[table < 0] = 0
|
52 |
+
table[table > n_bins - 1] = n_bins - 1
|
53 |
+
table = table.clip(0, 255).astype(np.uint8)
|
54 |
+
return table[ch]
|
55 |
+
|
56 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
57 |
+
out = cv2.merge(channels)
|
58 |
+
return out
|
59 |
+
|
60 |
+
|
61 |
+
def equalize_func(img):
|
62 |
+
'''
|
63 |
+
same output as PIL.ImageOps.equalize
|
64 |
+
PIL's implementation is different from cv2.equalize
|
65 |
+
'''
|
66 |
+
n_bins = 256
|
67 |
+
|
68 |
+
def tune_channel(ch):
|
69 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
70 |
+
non_zero_hist = hist[hist != 0].reshape(-1)
|
71 |
+
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
72 |
+
if step == 0:
|
73 |
+
return ch
|
74 |
+
n = np.empty_like(hist)
|
75 |
+
n[0] = step // 2
|
76 |
+
n[1:] = hist[:-1]
|
77 |
+
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
78 |
+
return table[ch]
|
79 |
+
|
80 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
81 |
+
out = cv2.merge(channels)
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
|
86 |
+
'''
|
87 |
+
like PIL, rotate by degree, not radians
|
88 |
+
'''
|
89 |
+
H, W = img.shape[0], img.shape[1]
|
90 |
+
center = W / 2, H / 2
|
91 |
+
M = cv2.getRotationMatrix2D(center, degree, 1)
|
92 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
93 |
+
return out
|
94 |
+
|
95 |
+
|
96 |
+
def solarize_func(img, thresh=128):
|
97 |
+
'''
|
98 |
+
same output as PIL.ImageOps.posterize
|
99 |
+
'''
|
100 |
+
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
101 |
+
table = table.clip(0, 255).astype(np.uint8)
|
102 |
+
out = table[img]
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
def color_func(img, factor):
|
107 |
+
'''
|
108 |
+
same output as PIL.ImageEnhance.Color
|
109 |
+
'''
|
110 |
+
# implementation according to PIL definition, quite slow
|
111 |
+
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
112 |
+
# out = blend(degenerate, img, factor)
|
113 |
+
# M = (
|
114 |
+
# np.eye(3) * factor
|
115 |
+
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
116 |
+
# )[np.newaxis, np.newaxis, :]
|
117 |
+
M = (
|
118 |
+
np.float32([
|
119 |
+
[0.886, -0.114, -0.114],
|
120 |
+
[-0.587, 0.413, -0.587],
|
121 |
+
[-0.299, -0.299, 0.701]]) * factor
|
122 |
+
+ np.float32([[0.114], [0.587], [0.299]])
|
123 |
+
)
|
124 |
+
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
125 |
+
return out
|
126 |
+
|
127 |
+
|
128 |
+
def contrast_func(img, factor):
|
129 |
+
"""
|
130 |
+
same output as PIL.ImageEnhance.Contrast
|
131 |
+
"""
|
132 |
+
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
133 |
+
table = np.array([(
|
134 |
+
el - mean) * factor + mean
|
135 |
+
for el in range(256)
|
136 |
+
]).clip(0, 255).astype(np.uint8)
|
137 |
+
out = table[img]
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
def brightness_func(img, factor):
|
142 |
+
'''
|
143 |
+
same output as PIL.ImageEnhance.Contrast
|
144 |
+
'''
|
145 |
+
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
|
146 |
+
out = table[img]
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
def sharpness_func(img, factor):
|
151 |
+
'''
|
152 |
+
The differences the this result and PIL are all on the 4 boundaries, the center
|
153 |
+
areas are same
|
154 |
+
'''
|
155 |
+
kernel = np.ones((3, 3), dtype=np.float32)
|
156 |
+
kernel[1][1] = 5
|
157 |
+
kernel /= 13
|
158 |
+
degenerate = cv2.filter2D(img, -1, kernel)
|
159 |
+
if factor == 0.0:
|
160 |
+
out = degenerate
|
161 |
+
elif factor == 1.0:
|
162 |
+
out = img
|
163 |
+
else:
|
164 |
+
out = img.astype(np.float32)
|
165 |
+
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
166 |
+
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
|
167 |
+
out = out.astype(np.uint8)
|
168 |
+
return out
|
169 |
+
|
170 |
+
|
171 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
172 |
+
H, W = img.shape[0], img.shape[1]
|
173 |
+
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
174 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
175 |
+
return out
|
176 |
+
|
177 |
+
|
178 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
179 |
+
'''
|
180 |
+
same output as PIL.Image.transform
|
181 |
+
'''
|
182 |
+
H, W = img.shape[0], img.shape[1]
|
183 |
+
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
184 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
185 |
+
return out
|
186 |
+
|
187 |
+
|
188 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
189 |
+
'''
|
190 |
+
same output as PIL.Image.transform
|
191 |
+
'''
|
192 |
+
H, W = img.shape[0], img.shape[1]
|
193 |
+
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
194 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
195 |
+
return out
|
196 |
+
|
197 |
+
|
198 |
+
def posterize_func(img, bits):
|
199 |
+
'''
|
200 |
+
same output as PIL.ImageOps.posterize
|
201 |
+
'''
|
202 |
+
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
203 |
+
return out
|
204 |
+
|
205 |
+
|
206 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
207 |
+
H, W = img.shape[0], img.shape[1]
|
208 |
+
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
209 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
210 |
+
return out
|
211 |
+
|
212 |
+
|
213 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
214 |
+
replace = np.array(replace, dtype=np.uint8)
|
215 |
+
H, W = img.shape[0], img.shape[1]
|
216 |
+
rh, rw = np.random.random(2)
|
217 |
+
pad_size = pad_size // 2
|
218 |
+
ch, cw = int(rh * H), int(rw * W)
|
219 |
+
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
220 |
+
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
221 |
+
out = img.copy()
|
222 |
+
out[x1:x2, y1:y2, :] = replace
|
223 |
+
return out
|
224 |
+
|
225 |
+
|
226 |
+
# level to args
|
227 |
+
def enhance_level_to_args(MAX_LEVEL):
|
228 |
+
def level_to_args(level):
|
229 |
+
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
230 |
+
return level_to_args
|
231 |
+
|
232 |
+
|
233 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
|
234 |
+
def level_to_args(level):
|
235 |
+
level = (level / MAX_LEVEL) * 0.3
|
236 |
+
if np.random.random() > 0.5:
|
237 |
+
level = -level
|
238 |
+
return (level, replace_value)
|
239 |
+
|
240 |
+
return level_to_args
|
241 |
+
|
242 |
+
|
243 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
244 |
+
def level_to_args(level):
|
245 |
+
level = (level / MAX_LEVEL) * float(translate_const)
|
246 |
+
if np.random.random() > 0.5:
|
247 |
+
level = -level
|
248 |
+
return (level, replace_value)
|
249 |
+
|
250 |
+
return level_to_args
|
251 |
+
|
252 |
+
|
253 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
254 |
+
def level_to_args(level):
|
255 |
+
level = int((level / MAX_LEVEL) * cutout_const)
|
256 |
+
return (level, replace_value)
|
257 |
+
|
258 |
+
return level_to_args
|
259 |
+
|
260 |
+
|
261 |
+
def solarize_level_to_args(MAX_LEVEL):
|
262 |
+
def level_to_args(level):
|
263 |
+
level = int((level / MAX_LEVEL) * 256)
|
264 |
+
return (level, )
|
265 |
+
return level_to_args
|
266 |
+
|
267 |
+
|
268 |
+
def none_level_to_args(level):
|
269 |
+
return ()
|
270 |
+
|
271 |
+
|
272 |
+
def posterize_level_to_args(MAX_LEVEL):
|
273 |
+
def level_to_args(level):
|
274 |
+
level = int((level / MAX_LEVEL) * 4)
|
275 |
+
return (level, )
|
276 |
+
return level_to_args
|
277 |
+
|
278 |
+
|
279 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
280 |
+
def level_to_args(level):
|
281 |
+
level = (level / MAX_LEVEL) * 30
|
282 |
+
if np.random.random() < 0.5:
|
283 |
+
level = -level
|
284 |
+
return (level, replace_value)
|
285 |
+
|
286 |
+
return level_to_args
|
287 |
+
|
288 |
+
|
289 |
+
func_dict = {
|
290 |
+
'Identity': identity_func,
|
291 |
+
'AutoContrast': autocontrast_func,
|
292 |
+
'Equalize': equalize_func,
|
293 |
+
'Rotate': rotate_func,
|
294 |
+
'Solarize': solarize_func,
|
295 |
+
'Color': color_func,
|
296 |
+
'Contrast': contrast_func,
|
297 |
+
'Brightness': brightness_func,
|
298 |
+
'Sharpness': sharpness_func,
|
299 |
+
'ShearX': shear_x_func,
|
300 |
+
'TranslateX': translate_x_func,
|
301 |
+
'TranslateY': translate_y_func,
|
302 |
+
'Posterize': posterize_func,
|
303 |
+
'ShearY': shear_y_func,
|
304 |
+
}
|
305 |
+
|
306 |
+
translate_const = 10
|
307 |
+
MAX_LEVEL = 10
|
308 |
+
replace_value = (128, 128, 128)
|
309 |
+
arg_dict = {
|
310 |
+
'Identity': none_level_to_args,
|
311 |
+
'AutoContrast': none_level_to_args,
|
312 |
+
'Equalize': none_level_to_args,
|
313 |
+
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
|
314 |
+
'Solarize': solarize_level_to_args(MAX_LEVEL),
|
315 |
+
'Color': enhance_level_to_args(MAX_LEVEL),
|
316 |
+
'Contrast': enhance_level_to_args(MAX_LEVEL),
|
317 |
+
'Brightness': enhance_level_to_args(MAX_LEVEL),
|
318 |
+
'Sharpness': enhance_level_to_args(MAX_LEVEL),
|
319 |
+
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
|
320 |
+
'TranslateX': translate_level_to_args(
|
321 |
+
translate_const, MAX_LEVEL, replace_value
|
322 |
+
),
|
323 |
+
'TranslateY': translate_level_to_args(
|
324 |
+
translate_const, MAX_LEVEL, replace_value
|
325 |
+
),
|
326 |
+
'Posterize': posterize_level_to_args(MAX_LEVEL),
|
327 |
+
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
|
328 |
+
}
|
329 |
+
|
330 |
+
|
331 |
+
class RandomAugment(object):
|
332 |
+
|
333 |
+
    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img


def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
    if is_train:
        t = [
            RandomResizedCropAndInterpolation(
                input_size, scale=(0.5, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
        ]
        if randaug:
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        t += [
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
        ]
        t = transforms.Compose(t)
    else:
        t = transforms.Compose([
            transforms.Resize((input_size, input_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
        ])

    return t


class VisCpmChatBeeImageProcessor(BaseImageProcessor):
    def __init__(self, is_train, randaug=True, input_size=224, interpolation='bicubic', **kwargs):
        super().__init__(**kwargs)
        self.is_train = is_train
        self.randaug = randaug
        self.input_size = input_size
        self.interpolation = interpolation
        self._transform = build_transform(is_train, randaug=randaug, input_size=input_size, interpolation=interpolation)

    def preprocess(self, images, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs) -> BatchFeature:
        images = make_list_of_images(images)
        images = [self._transform(image) for image in images]
        images = torch.tensor([image.numpy() for image in images])

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
        """
        dictionary = self.to_dict()

        for key, value in dictionary.items():
            if isinstance(value, np.ndarray):
                dictionary[key] = value.tolist()

        # make sure private name "_processor_class" is correctly
        # saved as "processor_class"
        _processor_class = dictionary.pop("_processor_class", None)
        if _processor_class is not None:
            dictionary["processor_class"] = _processor_class
        _transform = dictionary.pop("_transform", None)
        if _transform is not None:
            dictionary["_transform"] = str(type(_transform))

        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
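Below is a minimal usage sketch (not part of the committed files) showing how the image processor defined above could be applied to a single image; the image path is hypothetical and the eval-mode transform (`is_train=False`) is assumed.

```python
from PIL import Image

# eval-mode processor: plain resize + normalize, no RandAugment
processor = VisCpmChatBeeImageProcessor(is_train=False, input_size=224)

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])
```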
tokenization_viscpmchatbee.py
ADDED
@@ -0,0 +1,1007 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for CpmBee."""
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from numpy.typing import NDArray
|
22 |
+
from typing_extensions import TypedDict
|
23 |
+
|
24 |
+
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
25 |
+
from transformers.tokenization_utils_base import AddedToken, BatchEncoding, TextInput, TruncationStrategy
|
26 |
+
from transformers.utils import logging
|
27 |
+
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
32 |
+
|
33 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
34 |
+
"vocab_file": {
|
35 |
+
"openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/blob/main/vocab.txt",
|
36 |
+
},
|
37 |
+
}
|
38 |
+
|
39 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
40 |
+
"openbmb/viscpmchat-bee-10b": 4096,
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
class _PrevExtTableStates(TypedDict):
|
45 |
+
ext_table: Dict[int, str]
|
46 |
+
token_id_table: Dict[str, Dict[int, int]]
|
47 |
+
|
48 |
+
|
49 |
+
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
|
50 |
+
|
51 |
+
|
52 |
+
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
|
53 |
+
ret = n_up * max_depth + n_down
|
54 |
+
if ret == 0:
|
55 |
+
return ret
|
56 |
+
else:
|
57 |
+
# bucket 1 is reserved for incontext samples
|
58 |
+
return ret + 1
|
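As a quick sanity check of the bucket layout (illustrative only, using the default max_depth=8): bucket 0 means the two segments are the same node, bucket 1 is reserved for in-context samples, and every other (n_up, n_down) pair is shifted by one.

```python
assert rel_to_bucket(0, 0) == 0   # same segment
assert rel_to_bucket(0, 1) == 2   # direct child: 0 * 8 + 1, shifted past the reserved bucket
assert rel_to_bucket(1, 0) == 9   # direct parent: 1 * 8 + 0, shifted past the reserved bucket
```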
59 |
+
|
60 |
+
|
61 |
+
class _DictTree(TypedDict):
|
62 |
+
value: str
|
63 |
+
children: List["_DictTree"]
|
64 |
+
depth: int
|
65 |
+
segment_id: int
|
66 |
+
need_predict: bool
|
67 |
+
is_image: bool
|
68 |
+
|
69 |
+
|
70 |
+
class VisCpmChatBeeTokenizer(PreTrainedTokenizer):
|
71 |
+
"""
|
72 |
+
Construct a CPMBee tokenizer.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
vocab_file (`str`):
|
76 |
+
Path to the vocabulary file.
|
77 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
78 |
+
The beginning of sequence token.
|
79 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
80 |
+
The end of sequence token.
|
81 |
+
line_token (`str`, *optional*, defaults to `"\n"`):
|
82 |
+
The line token.
|
83 |
+
space_token (`str`, *optional*, defaults to `" "`):
|
84 |
+
The space token.
|
85 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
86 |
+
The unknown token.
|
87 |
+
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
88 |
+
The mask token.
|
89 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
90 |
+
The token used for padding.
|
91 |
+
padding_side (`str`, *optional*, defaults to `"left"`):
|
92 |
+
The padding side. CPM-Bee will use left padding by default.
|
93 |
+
"""
|
94 |
+
|
95 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
96 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
97 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
98 |
+
model_input_names: List[str] = [
|
99 |
+
"input_ids",
|
100 |
+
"attention_mask",
|
101 |
+
"input_id_sub",
|
102 |
+
"position",
|
103 |
+
"context",
|
104 |
+
"sample_ids",
|
105 |
+
"num_segments",
|
106 |
+
"segment",
|
107 |
+
"segment_rel_offset",
|
108 |
+
"segment_rel",
|
109 |
+
]
|
110 |
+
add_prefix_space = False
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
vocab_file,
|
115 |
+
bos_token="<s>",
|
116 |
+
eos_token="</s>",
|
117 |
+
line_token="\n",
|
118 |
+
space_token=" ",
|
119 |
+
unk_token="<unk>",
|
120 |
+
mask_token="<mask>",
|
121 |
+
pad_token="<pad>",
|
122 |
+
padding_side="left",
|
123 |
+
**kwargs,
|
124 |
+
):
|
125 |
+
super().__init__(
|
126 |
+
bos_token=bos_token,
|
127 |
+
eos_token=eos_token,
|
128 |
+
line_token=line_token,
|
129 |
+
space_token=space_token,
|
130 |
+
unk_token=unk_token,
|
131 |
+
mask_token=mask_token,
|
132 |
+
pad_token=pad_token,
|
133 |
+
padding_side=padding_side,
|
134 |
+
**kwargs,
|
135 |
+
)
|
136 |
+
|
137 |
+
self.encoder: Dict[str, int] = {}
|
138 |
+
|
139 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
140 |
+
for token in reader.readlines():
|
141 |
+
token = token.rstrip("\n")
|
142 |
+
if len(token) == 0:
|
143 |
+
continue
|
144 |
+
self.encoder[token] = len(self.encoder)
|
145 |
+
|
146 |
+
self.encoder[" "] = self.encoder["</_>"]
|
147 |
+
self.encoder["\n"] = self.encoder["</n>"]
|
148 |
+
del self.encoder["</_>"]
|
149 |
+
del self.encoder["</n>"]
|
150 |
+
|
151 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
152 |
+
|
153 |
+
self._max_word_len = max([len(x) for x in self.encoder.keys()])
|
154 |
+
self.cpmbee_special_tokens = {k: v for k, v in self.encoder.items() if k.startswith("<") and k.endswith(">")}
|
155 |
+
|
156 |
+
self.ext_table: Dict[int, str] = {}
|
157 |
+
self.ext_table_rev: Dict[str, int] = {}
|
158 |
+
|
159 |
+
self.token_id_table: Dict[str, Dict[int, int]] = {}
|
160 |
+
self.ext_special_tokens = []
|
161 |
+
|
162 |
+
self.ext_args_for_model = [
|
163 |
+
"input_id_subs",
|
164 |
+
"input_pos",
|
165 |
+
"context",
|
166 |
+
"segment_ids",
|
167 |
+
"segment_rel_offset",
|
168 |
+
"segment_rel",
|
169 |
+
"sample_ids",
|
170 |
+
"num_segments",
|
171 |
+
"predict_segments",
|
172 |
+
"answer_placeholders",
|
173 |
+
"ext_table",
|
174 |
+
"token_id_table",
|
175 |
+
"image_bound"
|
176 |
+
]
|
177 |
+
|
178 |
+
@property
|
179 |
+
def bod_token_id(self):
|
180 |
+
return self.encoder[self.bod_token]
|
181 |
+
|
182 |
+
@property
|
183 |
+
def eod_token_id(self):
|
184 |
+
return self.encoder[self.eod_token]
|
185 |
+
|
186 |
+
@property
|
187 |
+
def newline_id(self):
|
188 |
+
return self.encoder[self.line_token]
|
189 |
+
|
190 |
+
@property
|
191 |
+
def vocab_size(self) -> int:
|
192 |
+
return len(self.encoder)
|
193 |
+
|
194 |
+
def __len__(self):
|
195 |
+
"""
|
196 |
+
Size of the full vocabulary with the added tokens.
|
197 |
+
"""
|
198 |
+
return self.vocab_size + len(self.added_tokens_encoder)
|
199 |
+
|
200 |
+
def get_vocab(self):
|
201 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
202 |
+
|
203 |
+
def get_piece(self, text: str) -> str:
|
204 |
+
"""
|
205 |
+
Match with maximum length.
|
206 |
+
"""
|
207 |
+
len_text = len(text)
|
208 |
+
for i in range(len(text)):
|
209 |
+
sub = text[: len_text - i]
|
210 |
+
if (sub in self.encoder) or (sub in self.added_tokens_encoder):
|
211 |
+
return sub
|
212 |
+
return text[0]
|
213 |
+
|
214 |
+
def tokenize(self, text: TextInput, **kwargs) -> List[str]:
|
215 |
+
r"""
|
216 |
+
Override the `tokenize` to meet the needs of CPMBee:
|
217 |
+
1. Mark the special token with `<` and `>`. The `<>` will be ignored.
|
218 |
+
2. Split sentences by the marked special tokens.
|
219 |
+
3. Record the marked special token by `ext_table` and `ext_table_rev`.
|
220 |
+
4. Tokenize the sentence without special tokens.
|
221 |
+
"""
|
222 |
+
for_cpmbee = kwargs.get("for_cpmbee", False)
|
223 |
+
all_special_tokens_extended = {
|
224 |
+
str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
225 |
+
}
|
226 |
+
|
227 |
+
sentence_split = [""]
|
228 |
+
is_special_token = False
|
229 |
+
for i, c in enumerate(text):
|
230 |
+
if is_special_token:
|
231 |
+
if c == "<":
|
232 |
+
tail = sentence_split.pop(-1)
|
233 |
+
sentence_split[-1] += tail
|
234 |
+
sentence_split.append(c)
|
235 |
+
elif c == ">":
|
236 |
+
# end of special token
|
237 |
+
sentence_split[-1] += c
|
238 |
+
if sentence_split[-1] == "<>":
|
239 |
+
continue
|
240 |
+
is_special_token = False
|
241 |
+
sentence_split.append("")
|
242 |
+
else:
|
243 |
+
sentence_split[-1] += c
|
244 |
+
else:
|
245 |
+
if c == "<":
|
246 |
+
is_special_token = True
|
247 |
+
sentence_split.append(c)
|
248 |
+
else:
|
249 |
+
sentence_split[-1] += c
|
250 |
+
if is_special_token:
|
251 |
+
tail = sentence_split.pop(-1)
|
252 |
+
sentence_split[-1] += tail
|
253 |
+
|
254 |
+
output_tokens = []
|
255 |
+
for i, part in enumerate(sentence_split):
|
256 |
+
if (i & 1) == 1:
|
257 |
+
# special token
|
258 |
+
output_tokens.append(part)
|
259 |
+
if for_cpmbee and (part not in self.encoder) and (part not in self.ext_table_rev):
|
260 |
+
self.ext_table_rev[part] = len(self.ext_table_rev) + self.vocab_size
|
261 |
+
self.ext_table[self.ext_table_rev[part]] = part
|
262 |
+
else:
|
263 |
+
output_tokens.extend(self._tokenize(part, for_cpmbee=for_cpmbee))
|
264 |
+
|
265 |
+
# drop spaces
|
266 |
+
for i, token in enumerate(output_tokens):
|
267 |
+
if token in self.added_tokens_encoder:
|
268 |
+
token = all_special_tokens_extended.get(token, None)
|
269 |
+
left = output_tokens[i - 1] if i > 0 else None
|
270 |
+
right = output_tokens[i + 1] if i < len(output_tokens) - 1 else None
|
271 |
+
if isinstance(token, AddedToken):
|
272 |
+
if token.rstrip and right:
|
273 |
+
# A bit counter-intuitive but we strip the left of the string
|
274 |
+
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
275 |
+
output_tokens[i + 1] = right.lstrip()
|
276 |
+
# Strip white spaces on the left
|
277 |
+
if token.lstrip and left:
|
278 |
+
output_tokens[i - 1] = left.rstrip() # Opposite here
|
279 |
+
else:
|
280 |
+
if right:
|
281 |
+
output_tokens[i + 1] = right.lstrip()
|
282 |
+
if left:
|
283 |
+
output_tokens[i - 1] = left.rstrip()
|
284 |
+
|
285 |
+
skipped_tokens = []
|
286 |
+
for token in output_tokens:
|
287 |
+
if not token:
|
288 |
+
continue
|
289 |
+
else:
|
290 |
+
skipped_tokens.append(token)
|
291 |
+
|
292 |
+
return skipped_tokens
|
293 |
+
|
294 |
+
def _tokenize(self, text, **kwargs):
|
295 |
+
"""
|
296 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
297 |
+
vocabulary.
|
298 |
+
|
299 |
+
Do NOT take care of added tokens. Record the unk tokens and special tokens in `ext_table` and `ext_table_rev`.
|
300 |
+
"""
|
301 |
+
for_cpmbee = kwargs.get("for_cpmbee", False)
|
302 |
+
output_tokens = []
|
303 |
+
|
304 |
+
part_st = 0
|
305 |
+
last_unk = None
|
306 |
+
while part_st < len(text):
|
307 |
+
piece = self.get_piece(text[part_st:])
|
308 |
+
if piece in self.encoder or piece in self.added_tokens_encoder:
|
309 |
+
if last_unk is None:
|
310 |
+
output_tokens.append(piece)
|
311 |
+
else:
|
312 |
+
if for_cpmbee and (last_unk not in self.ext_table_rev):
|
313 |
+
self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
|
314 |
+
self.ext_table[self.ext_table_rev[last_unk]] = last_unk
|
315 |
+
output_tokens.append(last_unk)
|
316 |
+
output_tokens.append(piece)
|
317 |
+
last_unk = None
|
318 |
+
else:
|
319 |
+
if last_unk is None:
|
320 |
+
last_unk = piece
|
321 |
+
else:
|
322 |
+
last_unk += piece
|
323 |
+
part_st += len(piece)
|
324 |
+
if last_unk is not None:
|
325 |
+
# part end with UNK
|
326 |
+
if for_cpmbee and (last_unk not in self.ext_table_rev):
|
327 |
+
self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
|
328 |
+
self.ext_table[self.ext_table_rev[last_unk]] = last_unk
|
329 |
+
output_tokens.append(last_unk)
|
330 |
+
|
331 |
+
return output_tokens
|
332 |
+
|
333 |
+
def check(self, token):
|
334 |
+
return token in self.encoder
|
335 |
+
|
336 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
337 |
+
return "".join(tokens)
|
338 |
+
|
339 |
+
def _convert_token_to_id(self, token: str):
|
340 |
+
"""Converts a token (str) in an id using the vocab and ext_table."""
|
341 |
+
if token in self.encoder:
|
342 |
+
return self.encoder.get(token)
|
343 |
+
elif token in self.ext_table_rev:
|
344 |
+
return self.ext_table_rev[token]
|
345 |
+
elif token in self.added_tokens_encoder:
|
346 |
+
return self.added_tokens_encoder[token]
|
347 |
+
else:
|
348 |
+
return self.unk_token_id
|
349 |
+
|
350 |
+
def _convert_id_to_token(self, index):
|
351 |
+
"""Converts an index (integer) in a token (str) using the vocab and ext_table."""
|
352 |
+
if index in self.ext_table:
|
353 |
+
return self.ext_table[index]
|
354 |
+
elif index in self.added_tokens_decoder:
|
355 |
+
return self.added_tokens_decoder[index]
|
356 |
+
else:
|
357 |
+
if index >= 0:
|
358 |
+
return self.decoder[index]
|
359 |
+
|
360 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
361 |
+
if os.path.isdir(save_directory):
|
362 |
+
vocab_file = os.path.join(
|
363 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
367 |
+
index = 0
|
368 |
+
self.encoder["</n>"] = self.encoder["\n"]
|
369 |
+
del self.encoder["\n"]
|
370 |
+
self.encoder["</_>"] = self.encoder[" "]
|
371 |
+
del self.encoder[" "]
|
372 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
373 |
+
for token, token_index in sorted(self.encoder.items(), key=lambda x: x[1]):
|
374 |
+
if index != token_index:
|
375 |
+
logger.warning(
|
376 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
377 |
+
" Please check that the vocabulary is not corrupted!"
|
378 |
+
)
|
379 |
+
index = token_index
|
380 |
+
writer.write(token + "\n")
|
381 |
+
index += 1
|
382 |
+
return (vocab_file,)
|
383 |
+
|
384 |
+
def __call__(self, text, *args, **kwargs):
|
385 |
+
r"""
|
386 |
+
CPMBee `call` method will use `_tokenize_cpmbee` when the input type is dict.
|
387 |
+
"""
|
388 |
+
if isinstance(text, dict):
|
389 |
+
return self._batch_tokenize_cpmbee([text], *args, **kwargs)
|
390 |
+
elif isinstance(text, (list, tuple)):
|
391 |
+
if isinstance(text[0], dict):
|
392 |
+
return self._batch_tokenize_cpmbee(text, *args, **kwargs)
|
393 |
+
else:
|
394 |
+
return super().__call__(text, *args, **kwargs)
|
395 |
+
else:
|
396 |
+
return super().__call__(text, *args, **kwargs)
|
397 |
+
|
398 |
+
# tokenization for dict-format (CPM-Bee style) inputs
|
399 |
+
def _tokenize_cpmbee(self, data: TextInput, *args, **kwargs) -> List[str]:
|
400 |
+
"""
|
401 |
+
A tokenize method to process dict data. Exclusive for CPMBee.
|
402 |
+
"""
|
403 |
+
if isinstance(data, str):
|
404 |
+
data = json.loads(data)
|
405 |
+
if not isinstance(data, Dict):
|
406 |
+
raise TypeError(
|
407 |
+
"CpmBeeTokenizer input data should be dict or str in dict format, but got {}".format(type(data))
|
408 |
+
)
|
409 |
+
|
410 |
+
# 1. prepare answer placeholder
|
411 |
+
answer_placeholders = []
|
412 |
+
|
413 |
+
def _put_placeholder(data: Any, path: List[str] = []):
|
414 |
+
if isinstance(data, dict):
|
415 |
+
ret = {}
|
416 |
+
for k, v in data.items():
|
417 |
+
ret[k] = _put_placeholder(v, path + [k])
|
418 |
+
return ret
|
419 |
+
else:
|
420 |
+
answer_placeholders.append(path)
|
421 |
+
return "<ans_{}>".format(len(answer_placeholders))
|
422 |
+
|
423 |
+
data["<ans>"] = _put_placeholder(data["<ans>"])
|
424 |
+
|
425 |
+
(
|
426 |
+
input_ids,
|
427 |
+
input_id_subs,
|
428 |
+
context,
|
429 |
+
segment_ids,
|
430 |
+
segment_rel,
|
431 |
+
n_segments,
|
432 |
+
table_states,
|
433 |
+
image_bound
|
434 |
+
) = self.convert_data_to_id(data, shuffle_answer=False, max_depth=8)
|
435 |
+
|
436 |
+
# <ans> mapping from sub to id
|
437 |
+
sub_ans_map: Dict[int, int] = {}
|
438 |
+
for fake_id, token_sub in table_states["token_id_table"]["<ans>"].items():
|
439 |
+
token = table_states["ext_table"][fake_id]
|
440 |
+
if token.startswith("<ans_") and token.endswith(">"):
|
441 |
+
ans_id = int(token[5:-1])
|
442 |
+
sub_ans_map[token_sub] = ans_id
|
443 |
+
|
444 |
+
tmp_input_ids = []
|
445 |
+
tmp_input_sub = []
|
446 |
+
tmp_input_seg = []
|
447 |
+
|
448 |
+
# get predict segments
|
449 |
+
predict_segments: List[Tuple[int, int]] = []
|
450 |
+
for i in range(input_ids.shape[0]):
|
451 |
+
if context[i] == 0:
|
452 |
+
if input_ids[i] == self.encoder["<ans>"]:
|
453 |
+
# is ans
|
454 |
+
# (segment_id, ans_id)
|
455 |
+
predict_segments.append((segment_ids[i], sub_ans_map[input_id_subs[i]]))
|
456 |
+
else:
|
457 |
+
tmp_input_ids.append(input_ids[i])
|
458 |
+
tmp_input_sub.append(input_id_subs[i])
|
459 |
+
tmp_input_seg.append(segment_ids[i])
|
460 |
+
|
461 |
+
if len(predict_segments) == 0:
|
462 |
+
raise ValueError("No answer to predict")
|
463 |
+
|
464 |
+
input_ids = np.array(tmp_input_ids, dtype=np.int32) # all context
|
465 |
+
input_id_subs = np.array(tmp_input_sub, dtype=np.int32) # [0, 0, 0, 0, 1, 0, 0, 2, 0, ...]
|
466 |
+
context = np.full_like(tmp_input_ids, 1, dtype=np.int8) # [1, 1, 1, ...]
|
467 |
+
segment_ids = np.array(tmp_input_seg, dtype=np.int32) # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, ...]
|
468 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, 0, ...]
|
469 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, ...]
|
470 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32) # [n_seg, n_seg, n_seg, ...]
|
471 |
+
input_pos = np.arange(input_ids.shape[0], dtype=np.int32) # [0, 1, 2, 3, 4, ...]
|
472 |
+
image_bound = np.array(image_bound)
|
473 |
+
|
474 |
+
return (
|
475 |
+
self.prepare_for_model(
|
476 |
+
input_ids.tolist(),
|
477 |
+
input_id_subs=input_id_subs.tolist(),
|
478 |
+
input_pos=input_pos.tolist(),
|
479 |
+
context=context.tolist(),
|
480 |
+
segment_ids=segment_ids.tolist(),
|
481 |
+
segment_rel_offset=segment_rel_offset.tolist(),
|
482 |
+
segment_rel=segment_rel.tolist(),
|
483 |
+
sample_ids=sample_ids.tolist(),
|
484 |
+
num_segments=num_segments.tolist(),
|
485 |
+
image_bound=image_bound,
|
486 |
+
**kwargs,
|
487 |
+
),
|
488 |
+
predict_segments,
|
489 |
+
answer_placeholders,
|
490 |
+
table_states["ext_table"],
|
491 |
+
table_states["token_id_table"],
|
492 |
+
)
|
493 |
+
|
494 |
+
def _batch_tokenize_cpmbee(self, data_lst, *args, **kwargs):
|
495 |
+
"""
|
496 |
+
Batched version of `_tokenize_cpmbee`.
|
497 |
+
"""
|
498 |
+
device = kwargs.get("device", "cpu")
|
499 |
+
return_tensors = kwargs.get("return_tensors", None)
|
500 |
+
batch_outputs = {}
|
501 |
+
segment_rel_pack = []
|
502 |
+
other_info = []
|
503 |
+
|
504 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
505 |
+
batch_ext_table_ids: List[int] = []
|
506 |
+
batch_ext_table_sub: List[int] = []
|
507 |
+
|
508 |
+
for data in data_lst:
|
509 |
+
self.ext_table = {}
|
510 |
+
self.ext_table_rev = {}
|
511 |
+
self.token_id_table = {}
|
512 |
+
(outputs, predict_segments, answer_placeholders, ext_table, token_id_table) = self._tokenize_cpmbee(
|
513 |
+
data,
|
514 |
+
truncation=None,
|
515 |
+
padding=PaddingStrategy.DO_NOT_PAD.value,
|
516 |
+
max_length=None,
|
517 |
+
pad_to_multiple_of=None,
|
518 |
+
return_attention_mask=False,
|
519 |
+
return_tensors=None,
|
520 |
+
)
|
521 |
+
rev_ext_table = {}
|
522 |
+
for token, mp in token_id_table.items():
|
523 |
+
if token == "<ans>":
|
524 |
+
continue
|
525 |
+
token_id = self.encoder[token]
|
526 |
+
for fake_id, token_sub in mp.items():
|
527 |
+
if token_sub > 0:
|
528 |
+
if (token_id, token_sub) not in batch_ext_table_map:
|
529 |
+
batch_ext_table_map[(token_id, token_sub)] = len(batch_ext_table_ids) + self.vocab_size
|
530 |
+
batch_ext_table_ids.append(token_id)
|
531 |
+
batch_ext_table_sub.append(token_sub)
|
532 |
+
rev_ext_table[batch_ext_table_map[(token_id, token_sub)]] = ext_table[fake_id]
|
533 |
+
else:
|
534 |
+
rev_ext_table[token_id] = ext_table[fake_id]
|
535 |
+
|
536 |
+
segment_rel_pack.append(np.array(outputs.pop("segment_rel")))
|
537 |
+
other_info.append(
|
538 |
+
{
|
539 |
+
"predict_segments": predict_segments,
|
540 |
+
"answer_placeholders": answer_placeholders,
|
541 |
+
"ext_table": rev_ext_table,
|
542 |
+
}
|
543 |
+
)
|
544 |
+
|
545 |
+
for key, value in outputs.items():
|
546 |
+
if key not in batch_outputs:
|
547 |
+
batch_outputs[key] = []
|
548 |
+
batch_outputs[key].append(value)
|
549 |
+
|
550 |
+
max_length = max([len(item) for item in batch_outputs[self.model_input_names[0]]])
|
551 |
+
batch_size = len(batch_outputs[self.model_input_names[0]])
|
552 |
+
for i in range(batch_size):
|
553 |
+
inputs = {k: v[i] for k, v in batch_outputs.items()}
|
554 |
+
|
555 |
+
for k, v in inputs.items():
|
556 |
+
required_input = v
|
557 |
+
|
558 |
+
needs_to_be_padded = len(required_input) != max_length and k != 'image_bound'
|
559 |
+
|
560 |
+
if needs_to_be_padded:
|
561 |
+
difference = max_length - len(required_input)
|
562 |
+
batch_outputs[k][i] = [self.pad_token_id] * difference + required_input
|
563 |
+
|
564 |
+
max_num_rels = 0
|
565 |
+
for rel in segment_rel_pack:
|
566 |
+
max_num_rels = max(max_num_rels, rel.shape[0])
|
567 |
+
padded_rels = np.zeros((len(segment_rel_pack), max_num_rels), dtype=np.int32)
|
568 |
+
for i, rel in enumerate(segment_rel_pack):
|
569 |
+
padded_rels[i, : rel.shape[0]] = rel
|
570 |
+
batch_outputs["segment_rel"] = padded_rels
|
571 |
+
batch_outputs["batch_ext_table_ids"] = np.array(batch_ext_table_ids, dtype=np.int32)
|
572 |
+
batch_outputs["batch_ext_table_sub"] = np.array(batch_ext_table_sub, dtype=np.int32)
|
573 |
+
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
574 |
+
if return_tensors == "pt":
|
575 |
+
batch_outputs = batch_outputs.to(device=device)
|
576 |
+
batch_outputs["other_info"] = other_info
|
577 |
+
|
578 |
+
return batch_outputs
|
579 |
+
|
580 |
+
def convert_data_to_id(
|
581 |
+
self,
|
582 |
+
data: Any,
|
583 |
+
prev_ext_states: Optional[_PrevExtTableStates] = None,
|
584 |
+
shuffle_answer: bool = True,
|
585 |
+
max_depth: int = 8,
|
586 |
+
):
|
587 |
+
"""
|
588 |
+
Parse a dict to data ids. Exclusive for CPMBee. It will
|
589 |
+
1. parse the dict to segments and get segment_rel, which for calculating of position_bias.
|
590 |
+
2. tokenize every segment.
|
591 |
+
"""
|
592 |
+
root: _DictTree = {
|
593 |
+
"value": "<root>",
|
594 |
+
"children": [],
|
595 |
+
"depth": 0,
|
596 |
+
"segment_id": 0,
|
597 |
+
"need_predict": False,
|
598 |
+
"is_image": False
|
599 |
+
}
|
600 |
+
|
601 |
+
segments = [root]
|
602 |
+
|
603 |
+
def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
|
604 |
+
if isinstance(data, dict):
|
605 |
+
ret_list: List[_DictTree] = []
|
606 |
+
curr_items = list(data.items())
|
607 |
+
if need_predict and shuffle_answer:
|
608 |
+
access_idx = np.arange(len(curr_items))
|
609 |
+
np.random.shuffle(access_idx)
|
610 |
+
curr_items = [curr_items[idx] for idx in access_idx]
|
611 |
+
for k, v in curr_items:
|
612 |
+
child_info: _DictTree = {
|
613 |
+
"value": k,
|
614 |
+
"children": [],
|
615 |
+
"depth": depth,
|
616 |
+
"segment_id": len(segments),
|
617 |
+
"need_predict": False, # only leaves are contexts
|
618 |
+
"is_image": False,
|
619 |
+
}
|
620 |
+
segments.append(child_info)
|
621 |
+
child_info["children"] = _build_dict_tree(
|
622 |
+
v, depth + 1,
|
623 |
+
need_predict=need_predict or (depth == 1 and k == "<ans>"),
|
624 |
+
is_image=is_image or (depth == 1 and k == "image")
|
625 |
+
) # elements in <root>.<ans>
|
626 |
+
|
627 |
+
ret_list.append(child_info)
|
628 |
+
return ret_list
|
629 |
+
else:
|
630 |
+
assert isinstance(data, str), "Invalid data {}".format(data)
|
631 |
+
ret: _DictTree = {
|
632 |
+
"value": data,
|
633 |
+
"children": [],
|
634 |
+
"depth": depth,
|
635 |
+
"segment_id": len(segments),
|
636 |
+
"need_predict": need_predict,
|
637 |
+
"is_image": is_image,
|
638 |
+
}
|
639 |
+
segments.append(ret)
|
640 |
+
return [ret]
|
641 |
+
|
642 |
+
root["children"] = _build_dict_tree(data, 1, False, False)
|
643 |
+
|
644 |
+
num_segments = len(segments)
|
645 |
+
segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
|
646 |
+
|
647 |
+
def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
|
648 |
+
ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
|
649 |
+
for child in node["children"]:
|
650 |
+
sub = _build_segment_rel(child)
|
651 |
+
for seg_id_1, depth_1 in sub:
|
652 |
+
for seg_id_2, depth_2 in ret:
|
653 |
+
n_up = min(depth_1 - node["depth"], max_depth - 1)
|
654 |
+
n_down = min(depth_2 - node["depth"], max_depth - 1)
|
655 |
+
segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
|
656 |
+
n_up, n_down, max_depth=max_depth
|
657 |
+
)
|
658 |
+
segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
|
659 |
+
n_down, n_up, max_depth=max_depth
|
660 |
+
)
|
661 |
+
ret.extend(sub)
|
662 |
+
return ret
|
663 |
+
|
664 |
+
_build_segment_rel(root)
|
665 |
+
|
666 |
+
input_ids: List[int] = []
|
667 |
+
input_id_subs: List[int] = []
|
668 |
+
segment_bound: List[Tuple[int, int]] = []
|
669 |
+
image_bound: List[Tuple[int, int]] = []
|
670 |
+
|
671 |
+
|
672 |
+
if prev_ext_states is not None:
|
673 |
+
self.ext_table = prev_ext_states["ext_table"]
|
674 |
+
self.token_id_table = prev_ext_states["token_id_table"]
|
675 |
+
|
676 |
+
for seg in segments:
|
677 |
+
# tokenize
|
678 |
+
tokens = self.convert_tokens_to_ids(self.tokenize(seg["value"], for_cpmbee=True))
|
679 |
+
|
680 |
+
token_id_subs = []
|
681 |
+
reid_token_ids = []
|
682 |
+
for idx in tokens:
|
683 |
+
if idx in self.ext_table:
|
684 |
+
# unk or special token
|
685 |
+
token = self.ext_table[idx]
|
686 |
+
if token.startswith("<") and token.endswith(">"):
|
687 |
+
# special token
|
688 |
+
if "_" in token:
|
689 |
+
token_name = token[1:-1].split("_", maxsplit=1)[0]
|
690 |
+
else:
|
691 |
+
token_name = token[1:-1]
|
692 |
+
token_name = "<{}>".format(token_name)
|
693 |
+
else:
|
694 |
+
token_name = "<unk>"
|
695 |
+
|
696 |
+
if token_name not in self.token_id_table:
|
697 |
+
self.token_id_table[token_name] = {}
|
698 |
+
if idx not in self.token_id_table[token_name]:
|
699 |
+
self.token_id_table[token_name][idx] = len(self.token_id_table[token_name])
|
700 |
+
if token_name not in self.encoder:
|
701 |
+
raise ValueError("Invalid token {}".format(token))
|
702 |
+
reid_token_ids.append(self.encoder[token_name])
|
703 |
+
token_id_subs.append(self.token_id_table[token_name][idx])
|
704 |
+
else:
|
705 |
+
reid_token_ids.append(idx)
|
706 |
+
token_id_subs.append(0)
|
707 |
+
tokens = [self.bos_token_id] + reid_token_ids
|
708 |
+
token_id_subs = [0] + token_id_subs
|
709 |
+
# appending eos_id indicates the segment does not need prediction
|
710 |
+
if not seg["need_predict"]: # eos
|
711 |
+
tokens = tokens + [self.eos_token_id]
|
712 |
+
token_id_subs = token_id_subs + [0]
|
713 |
+
else:
|
714 |
+
# no eos
|
715 |
+
pass
|
716 |
+
begin = len(input_ids)
|
717 |
+
input_ids.extend(tokens)
|
718 |
+
input_id_subs.extend(token_id_subs)
|
719 |
+
end = len(input_ids)
|
720 |
+
segment_bound.append((begin, end))
|
721 |
+
|
722 |
+
ids = np.array(input_ids, dtype=np.int32)
|
723 |
+
id_subs = np.array(input_id_subs, dtype=np.int32)
|
724 |
+
segs = np.zeros((ids.shape[0],), dtype=np.int32)  # segment id of each position, filled from segment_bound
|
725 |
+
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
726 |
+
for i, (begin, end) in enumerate(segment_bound):
|
727 |
+
if not segments[i]["need_predict"]:
|
728 |
+
context[begin:end] = 1
|
729 |
+
if segments[i]["is_image"]:
|
730 |
+
image_bound.append((begin + 1, end - 1))
|
731 |
+
segs[begin:end] = i
|
732 |
+
|
733 |
+
curr_ext_table_states: _PrevExtTableStates = {
|
734 |
+
"ext_table": self.ext_table,
|
735 |
+
"token_id_table": self.token_id_table,
|
736 |
+
}
|
737 |
+
image_bound = np.array(image_bound, dtype=np.int32)
|
738 |
+
return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
|
739 |
+
|
740 |
+
def prepare_for_model(
|
741 |
+
self,
|
742 |
+
ids: List[int],
|
743 |
+
pair_ids: Optional[List[int]] = None,
|
744 |
+
add_special_tokens: bool = True,
|
745 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
746 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
747 |
+
max_length: Optional[int] = None,
|
748 |
+
stride: int = 0,
|
749 |
+
pad_to_multiple_of: Optional[int] = None,
|
750 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
751 |
+
return_token_type_ids: Optional[bool] = None,
|
752 |
+
return_attention_mask: Optional[bool] = None,
|
753 |
+
return_overflowing_tokens: bool = False,
|
754 |
+
return_special_tokens_mask: bool = False,
|
755 |
+
return_length: bool = False,
|
756 |
+
verbose: bool = True,
|
757 |
+
prepend_batch_axis: bool = False,
|
758 |
+
**kwargs,
|
759 |
+
) -> BatchEncoding:
|
760 |
+
"""
|
761 |
+
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
|
762 |
+
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
763 |
+
manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
|
764 |
+
different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
|
765 |
+
overflowing tokens. Such a combination of arguments will raise an error.
|
766 |
+
|
767 |
+
Args:
|
768 |
+
ids (`List[int]`):
|
769 |
+
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
|
770 |
+
`convert_tokens_to_ids` methods.
|
771 |
+
pair_ids (`List[int]`, *optional*):
|
772 |
+
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
|
773 |
+
and `convert_tokens_to_ids` methods.
|
774 |
+
"""
|
775 |
+
|
776 |
+
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
777 |
+
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
778 |
+
padding=padding,
|
779 |
+
truncation=truncation,
|
780 |
+
max_length=max_length,
|
781 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
782 |
+
verbose=verbose,
|
783 |
+
**kwargs,
|
784 |
+
)
|
785 |
+
|
786 |
+
pair = bool(pair_ids is not None)
|
787 |
+
len_ids = len(ids)
|
788 |
+
len_pair_ids = len(pair_ids) if pair else 0
|
789 |
+
|
790 |
+
if return_token_type_ids and not add_special_tokens:
|
791 |
+
raise ValueError(
|
792 |
+
"Asking to return token_type_ids while setting add_special_tokens to False "
|
793 |
+
"results in an undefined behavior. Please set add_special_tokens to True or "
|
794 |
+
"set return_token_type_ids to None."
|
795 |
+
)
|
796 |
+
|
797 |
+
if (
|
798 |
+
return_overflowing_tokens
|
799 |
+
and truncation_strategy == TruncationStrategy.LONGEST_FIRST
|
800 |
+
and pair_ids is not None
|
801 |
+
):
|
802 |
+
raise ValueError(
|
803 |
+
"Not possible to return overflowing tokens for pair of sequences with the "
|
804 |
+
"`longest_first`. Please select another truncation strategy than `longest_first`, "
|
805 |
+
"for instance `only_second` or `only_first`."
|
806 |
+
)
|
807 |
+
|
808 |
+
# Load from model defaults
|
809 |
+
if return_token_type_ids is None:
|
810 |
+
return_token_type_ids = "token_type_ids" in self.model_input_names
|
811 |
+
if return_attention_mask is None:
|
812 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
813 |
+
|
814 |
+
encoded_inputs = {}
|
815 |
+
|
816 |
+
# Compute the total size of the returned encodings
|
817 |
+
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
818 |
+
|
819 |
+
# Truncation: Handle max sequence length
|
820 |
+
overflowing_tokens = []
|
821 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
822 |
+
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
823 |
+
ids,
|
824 |
+
pair_ids=pair_ids,
|
825 |
+
num_tokens_to_remove=total_len - max_length,
|
826 |
+
truncation_strategy=truncation_strategy,
|
827 |
+
stride=stride,
|
828 |
+
)
|
829 |
+
|
830 |
+
if return_overflowing_tokens:
|
831 |
+
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
832 |
+
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
833 |
+
|
834 |
+
# Add special tokens
|
835 |
+
if add_special_tokens:
|
836 |
+
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
837 |
+
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
838 |
+
else:
|
839 |
+
sequence = ids + pair_ids if pair else ids
|
840 |
+
token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
|
841 |
+
|
842 |
+
# Build output dictionary
|
843 |
+
encoded_inputs["input_ids"] = sequence
|
844 |
+
if return_token_type_ids:
|
845 |
+
encoded_inputs["token_type_ids"] = token_type_ids
|
846 |
+
if return_special_tokens_mask:
|
847 |
+
if add_special_tokens:
|
848 |
+
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
849 |
+
else:
|
850 |
+
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
|
851 |
+
|
852 |
+
# Check lengths
|
853 |
+
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
854 |
+
|
855 |
+
# Padding
|
856 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
857 |
+
encoded_inputs = self.pad(
|
858 |
+
encoded_inputs,
|
859 |
+
max_length=max_length,
|
860 |
+
padding=padding_strategy.value,
|
861 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
862 |
+
return_attention_mask=return_attention_mask,
|
863 |
+
)
|
864 |
+
|
865 |
+
if return_length:
|
866 |
+
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
|
867 |
+
|
868 |
+
# for CPMBee, encode all the model arguments
|
869 |
+
for arg in self.ext_args_for_model:
|
870 |
+
v = kwargs.get(arg, None)
|
871 |
+
if v is not None:
|
872 |
+
encoded_inputs[arg] = v
|
873 |
+
|
874 |
+
batch_outputs = BatchEncoding(
|
875 |
+
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
|
876 |
+
)
|
877 |
+
|
878 |
+
return batch_outputs
|
879 |
+
|
880 |
+
def prepare_for_finetune(
|
881 |
+
self,
|
882 |
+
data_list: List[Dict],
|
883 |
+
max_length: int = 2048
|
884 |
+
):
|
885 |
+
_inputs: List[NDArray[np.int32]] = []
|
886 |
+
_inputs_sub: List[NDArray[np.int32]] = []
|
887 |
+
_context: List[NDArray[np.int8]] = []
|
888 |
+
_sample_ids: List[NDArray[np.int32]] = []
|
889 |
+
_segments: List[NDArray[np.int32]] = []
|
890 |
+
_num_segments: List[NDArray[np.int32]] = []
|
891 |
+
_segment_rel_offset: List[NDArray[np.int32]] = []
|
892 |
+
_segment_rel: List[NDArray[np.int32]] = []
|
893 |
+
_spans: List[List[int]] = []
|
894 |
+
_raw_data: List[List[Any]] = []
|
895 |
+
|
896 |
+
raw_data = {}
|
897 |
+
for data in data_list:
|
898 |
+
(
|
899 |
+
input_ids,
|
900 |
+
input_id_subs,
|
901 |
+
context,
|
902 |
+
segment_ids,
|
903 |
+
segment_rel,
|
904 |
+
n_segments,
|
905 |
+
_, _  # ext table states and image_bound (unused here)
|
906 |
+
) = self.convert_data_to_id(data)
|
907 |
+
|
908 |
+
input_ids = input_ids[: max_length]
|
909 |
+
context = context[: max_length]
|
910 |
+
segment_ids = segment_ids[: max_length]
|
911 |
+
raw_data["input"] = data
|
912 |
+
raw_data["samples"] = []
|
913 |
+
|
914 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
915 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
|
916 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
|
917 |
+
|
918 |
+
_inputs.append(input_ids)
|
919 |
+
_inputs_sub.append(input_id_subs)
|
920 |
+
_context.append(context)
|
921 |
+
_sample_ids.append(sample_ids)
|
922 |
+
_segments.append(segment_ids)
|
923 |
+
_num_segments.append(num_segments)
|
924 |
+
_segment_rel_offset.append(segment_rel_offset)
|
925 |
+
_segment_rel.append(segment_rel)
|
926 |
+
_spans.append([input_ids.shape[0]])
|
927 |
+
_raw_data.append([raw_data])
|
928 |
+
|
929 |
+
batch_size = len(_inputs)
|
930 |
+
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
931 |
+
inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
|
932 |
+
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
933 |
+
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
934 |
+
segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
935 |
+
num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
936 |
+
segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
|
937 |
+
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
938 |
+
|
939 |
+
max_rel = 0
|
940 |
+
for i in range(batch_size):
|
941 |
+
max_rel = max(max_rel, _segment_rel[i].shape[0])
|
942 |
+
segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
|
943 |
+
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
944 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
945 |
+
|
946 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
947 |
+
batch_ext_table_ids: List[int] = []
|
948 |
+
batch_ext_table_sub: List[int] = []
|
949 |
+
raw_data_list: List[Any] = []
|
950 |
+
|
951 |
+
for i in range(batch_size):
|
952 |
+
instance_length = _inputs[i].shape[0]
|
953 |
+
rel_size = _segment_rel[i].shape[0]
|
954 |
+
inputs[i, :instance_length] = _inputs[i]
|
955 |
+
inputs_sub[i, :instance_length] = _inputs_sub[i]
|
956 |
+
context[i, :instance_length] = _context[i]
|
957 |
+
sample_ids[i, :instance_length] = _sample_ids[i]
|
958 |
+
segments[i, :instance_length] = _segments[i]
|
959 |
+
num_segments[i, :instance_length] = _num_segments[i]
|
960 |
+
segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
|
961 |
+
segment_rel[i, :rel_size] = _segment_rel[i]
|
962 |
+
|
963 |
+
span_begin = 0
|
964 |
+
for span_id, span_end in enumerate(_spans[i]):
|
965 |
+
spans[i, span_begin:span_end] = span_id
|
966 |
+
span_begin = span_end
|
967 |
+
length[i] = instance_length
|
968 |
+
raw_data_list.extend(_raw_data[i])
|
969 |
+
|
970 |
+
for j in range(instance_length):
|
971 |
+
idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
|
972 |
+
tgt_idx = idx
|
973 |
+
if idx_sub > 0:
|
974 |
+
# need to be in ext table
|
975 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
976 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
977 |
+
batch_ext_table_ids.append(idx)
|
978 |
+
batch_ext_table_sub.append(idx_sub)
|
979 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
|
980 |
+
if j > 1 and context[i, j - 1] == 0:
|
981 |
+
if idx != self.bos_token_id:
|
982 |
+
tgt[i, j - 1] = tgt_idx
|
983 |
+
else:
|
984 |
+
tgt[i, j - 1] = self.eos_token_id
|
985 |
+
if context[i, instance_length - 1] == 0:
|
986 |
+
tgt[i, instance_length - 1] = self.eos_token_id
|
987 |
+
|
988 |
+
if len(batch_ext_table_map) == 0:
|
989 |
+
# placeholder
|
990 |
+
batch_ext_table_ids.append(0)
|
991 |
+
batch_ext_table_sub.append(1)
|
992 |
+
|
993 |
+
return BatchEncoding({
|
994 |
+
"input_ids": inputs,
|
995 |
+
"input_id_sub": inputs_sub,
|
996 |
+
"length": length,
|
997 |
+
"context": context > 0,
|
998 |
+
"sample_ids": sample_ids,
|
999 |
+
"num_segments": num_segments,
|
1000 |
+
"segment": segments,
|
1001 |
+
"segment_rel_offset": segment_rel_offset,
|
1002 |
+
"segment_rel": segment_rel,
|
1003 |
+
"span": spans,
|
1004 |
+
"labels": tgt,
|
1005 |
+
"ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
|
1006 |
+
"ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
|
1007 |
+
}, tensor_type="pt")
|
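A minimal sketch of calling the tokenizer on a CPM-Bee style dict (not part of the committed files). The "<ans>" key is the one `convert_data_to_id` treats as the answer slot; the "question" field name is purely illustrative, and a real multimodal sample would additionally carry an "image" field that the processor fills with image placeholder tokens.

```python
tokenizer = VisCpmChatBeeTokenizer("vocab.txt")  # vocab.txt from this repo

data = {
    "question": "What is in the picture?",  # illustrative field name
    "<ans>": "",                            # empty answer slot to be predicted
}

batch = tokenizer(data, return_tensors="pt")
print(batch["input_ids"].shape)
print(batch["other_info"][0]["predict_segments"])
```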
tokenizer_config.json
ADDED
@@ -0,0 +1,10 @@
{
  "name_or_path": "openbmb/viscpmchat-bee-10b",
  "tokenizer_class": "VisCpmChatBeeTokenizer",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_viscpmchatbee.VisCpmChatBeeTokenizer",
      null
    ]
  }
}
|
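Given the auto_map above, the tokenizer can be loaded through AutoTokenizer with remote code enabled. A rough sketch; the repo id is assumed to be the one this model card lives in, adjust it for a local clone.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openbmb/VisCPM-Chat", trust_remote_code=True)
```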
utils.py
ADDED
@@ -0,0 +1,730 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
5 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
6 |
+
from torchvision import transforms
|
7 |
+
import urllib
|
8 |
+
from tqdm import tqdm
|
9 |
+
from cpm_live.tokenizers import CPMBeeTokenizer
|
10 |
+
from torch.utils.data import default_collate
|
11 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
12 |
+
from typing_extensions import TypedDict
|
13 |
+
from numpy.typing import NDArray
|
14 |
+
import importlib.machinery
|
15 |
+
import importlib.util
|
16 |
+
import types
|
17 |
+
import random
|
18 |
+
|
19 |
+
|
20 |
+
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
|
21 |
+
|
22 |
+
|
23 |
+
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
|
24 |
+
items = []
|
25 |
+
if isinstance(orig_items[0][key], list):
|
26 |
+
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
27 |
+
for it in orig_items:
|
28 |
+
for tr in it[key]:
|
29 |
+
items.append({key: tr})
|
30 |
+
else:
|
31 |
+
assert isinstance(orig_items[0][key], torch.Tensor)
|
32 |
+
items = orig_items
|
33 |
+
|
34 |
+
batch_size = len(items)
|
35 |
+
shape = items[0][key].shape
|
36 |
+
dim = len(shape)
|
37 |
+
assert dim <= 3
|
38 |
+
if max_length is None:
|
39 |
+
max_length = 0
|
40 |
+
max_length = max(max_length, max(item[key].shape[-1] for item in items))
|
41 |
+
min_length = min(item[key].shape[-1] for item in items)
|
42 |
+
dtype = items[0][key].dtype
|
43 |
+
|
44 |
+
if dim == 1:
|
45 |
+
return torch.cat([item[key] for item in items], dim=0)
|
46 |
+
elif dim == 2:
|
47 |
+
if max_length == min_length:
|
48 |
+
return torch.cat([item[key] for item in items], dim=0)
|
49 |
+
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
50 |
+
else:
|
51 |
+
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
52 |
+
|
53 |
+
for i, item in enumerate(items):
|
54 |
+
if dim == 2:
|
55 |
+
if padding_side == "left":
|
56 |
+
tensor[i, -len(item[key][0]):] = item[key][0].clone()
|
57 |
+
else:
|
58 |
+
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
59 |
+
elif dim == 3:
|
60 |
+
if padding_side == "left":
|
61 |
+
tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
|
62 |
+
else:
|
63 |
+
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
64 |
+
|
65 |
+
return tensor
|
66 |
+
|
67 |
+
|
68 |
+
class CPMBeeCollater:
|
69 |
+
"""
|
70 |
+
Collate for CPM-Bee style inputs, the counterpart of cpm-live's _MixedDatasetBatchPacker.
The native torch DataLoader is not well suited to batched in-context learning,
and the best_fit packing step of the original implementation (used to maximize the ratio of effective tokens) is not supported here either.
todo: @wangchongyi rework the Dataloader or BatchPacker
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, tokenizer: CPMBeeTokenizer, max_len):
|
77 |
+
self.tokenizer = tokenizer
|
78 |
+
self._max_length = max_len
|
79 |
+
self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
|
80 |
+
'segment_rel', 'sample_ids', 'num_segments']
|
81 |
+
|
82 |
+
def __call__(self, batch):
|
83 |
+
batch_size = len(batch)
|
84 |
+
|
85 |
+
tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
|
86 |
+
# no best_fit packing for now, so span is all zeros
|
87 |
+
span = np.zeros((batch_size, self._max_length), dtype=np.int32)
|
88 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
89 |
+
|
90 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
91 |
+
batch_ext_table_ids: List[int] = []
|
92 |
+
batch_ext_table_sub: List[int] = []
|
93 |
+
raw_data_list: List[Any] = []
|
94 |
+
|
95 |
+
for i in range(batch_size):
|
96 |
+
instance_length = batch[i]['input_ids'][0].shape[0]
|
97 |
+
length[i] = instance_length
|
98 |
+
raw_data_list.extend(batch[i]['raw_data'])
|
99 |
+
|
100 |
+
for j in range(instance_length):
|
101 |
+
idx, idx_sub = batch[i]['input_ids'][0, j], batch[i]['input_id_subs'][0, j]
|
102 |
+
tgt_idx = idx
|
103 |
+
if idx_sub > 0:
|
104 |
+
# need to be in ext table
|
105 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
106 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
107 |
+
batch_ext_table_ids.append(idx)
|
108 |
+
batch_ext_table_sub.append(idx_sub)
|
109 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
|
110 |
+
if j > 1 and batch[i]['context'][0, j - 1] == 0:
|
111 |
+
if idx != self.tokenizer.bos_id:
|
112 |
+
tgt[i, j - 1] = tgt_idx
|
113 |
+
else:
|
114 |
+
tgt[i, j - 1] = self.tokenizer.eos_id
|
115 |
+
if batch[i]['context'][0, instance_length - 1] == 0:
|
116 |
+
tgt[i, instance_length - 1] = self.tokenizer.eos_id
|
117 |
+
|
118 |
+
if len(batch_ext_table_map) == 0:
|
119 |
+
# placeholder
|
120 |
+
batch_ext_table_ids.append(0)
|
121 |
+
batch_ext_table_sub.append(1)
|
122 |
+
|
123 |
+
# image
|
124 |
+
if 'pixel_values' in batch[0]:
|
125 |
+
data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
|
126 |
+
else:
|
127 |
+
data = {}
|
128 |
+
|
129 |
+
# image_bound
|
130 |
+
if 'image_bound' in batch[0]:
|
131 |
+
data['image_bound'] = default_collate([i['image_bound'] for i in batch])
|
132 |
+
|
133 |
+
# bee inp
|
134 |
+
for key in self.pad_keys:
|
135 |
+
data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')
|
136 |
+
|
137 |
+
data['context'] = data['context'] > 0
|
138 |
+
data['length'] = torch.from_numpy(length)
|
139 |
+
data['span'] = torch.from_numpy(span)
|
140 |
+
data['target'] = torch.from_numpy(tgt)
|
141 |
+
data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids))
|
142 |
+
data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub))
|
143 |
+
data['raw_data'] = raw_data_list
|
144 |
+
|
145 |
+
return data
|
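A rough sketch of plugging CPMBeeCollater into a torch DataLoader (not part of the committed files). It assumes a dataset that already yields per-sample dicts containing the tensors listed in `pad_keys` plus 'raw_data' (and 'pixel_values'/'image_bound' for multimodal samples), and that CPMBeeTokenizer can be constructed without arguments.

```python
from torch.utils.data import DataLoader

tokenizer = CPMBeeTokenizer()                      # assumption: default construction
collater = CPMBeeCollater(tokenizer, max_len=2048)

loader = DataLoader(my_dataset, batch_size=4, collate_fn=collater)  # my_dataset is hypothetical
for batch in loader:
    print(batch["input_ids"].shape, batch["target"].shape)
    break
```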
146 |
+
|
147 |
+
|
148 |
+
class _DictTree(TypedDict):
|
149 |
+
value: str
|
150 |
+
children: List["_DictTree"]
|
151 |
+
depth: int
|
152 |
+
segment_id: int
|
153 |
+
need_predict: bool
|
154 |
+
is_image: bool
|
155 |
+
|
156 |
+
|
157 |
+
class _PrevExtTableStates(TypedDict):
|
158 |
+
ext_table: Dict[int, str]
|
159 |
+
token_id_table: Dict[str, Dict[int, int]]
|
160 |
+
|
161 |
+
|
162 |
+
class _TransformFuncDict(TypedDict):
|
163 |
+
loader: importlib.machinery.SourceFileLoader
|
164 |
+
module: types.ModuleType
|
165 |
+
last_m: float
|
166 |
+
|
167 |
+
|
168 |
+
_TransformFunction = Callable[[CPMBeeInputType, int, random.Random], CPMBeeInputType]
|
169 |
+
|
170 |
+
|
171 |
+
class CPMBeeBatch(TypedDict):
|
172 |
+
inputs: NDArray[np.int32]
|
173 |
+
inputs_sub: NDArray[np.int32]
|
174 |
+
length: NDArray[np.int32]
|
175 |
+
context: NDArray[np.bool_]
|
176 |
+
sample_ids: NDArray[np.int32]
|
177 |
+
num_segments: NDArray[np.int32]
|
178 |
+
segment_ids: NDArray[np.int32]
|
179 |
+
segment_rel_offset: NDArray[np.int32]
|
180 |
+
segment_rel: NDArray[np.int32]
|
181 |
+
spans: NDArray[np.int32]
|
182 |
+
target: NDArray[np.int32]
|
183 |
+
ext_ids: NDArray[np.int32]
|
184 |
+
ext_sub: NDArray[np.int32]
|
185 |
+
task_ids: NDArray[np.int32]
|
186 |
+
task_names: List[str]
|
187 |
+
raw_data: List[Any]
|
188 |
+
|
189 |
+
|
190 |
+
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
|
191 |
+
ret = n_up * max_depth + n_down
|
192 |
+
if ret == 0:
|
193 |
+
return ret
|
194 |
+
else:
|
195 |
+
# bucket 1 is reserved for incontext samples
|
196 |
+
return ret + 1
|

def convert_data_to_id(
    tokenizer: CPMBeeTokenizer,
    data: Any,
    prev_ext_states: Optional[_PrevExtTableStates] = None,
    shuffle_answer: bool = True,
    max_depth: int = 8
):
    root: _DictTree = {
        "value": "<root>",
        "children": [],
        "depth": 0,
        "segment_id": 0,
        "need_predict": False,
        "is_image": False
    }

    segments = [root]

    def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
        if isinstance(data, dict):
            ret_list: List[_DictTree] = []
            curr_items = list(data.items())
            if need_predict and shuffle_answer:
                access_idx = np.arange(len(curr_items))
                np.random.shuffle(access_idx)
                curr_items = [curr_items[idx] for idx in access_idx]
            for k, v in curr_items:
                child_info: _DictTree = {
                    "value": k,
                    "children": [],
                    "depth": depth,
                    "segment_id": len(segments),
                    "need_predict": False,  # only leaves are contexts
                    "is_image": False,
                }
                segments.append(child_info)
                child_info["children"] = _build_dict_tree(
                    v, depth + 1,
                    need_predict=need_predict or (depth == 1 and k == "<ans>"),
                    is_image=is_image or (depth == 1 and k == "image")
                )  # elements in <root>.<ans>

                ret_list.append(child_info)
            return ret_list
        else:
            assert isinstance(data, str), "Invalid data {}".format(data)
            ret: _DictTree = {
                "value": data,
                "children": [],
                "depth": depth,
                "segment_id": len(segments),
                "need_predict": need_predict,
                "is_image": is_image,
            }
            segments.append(ret)
            return [ret]

    root["children"] = _build_dict_tree(data, 1, False, False)

    num_segments = len(segments)
    segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)

    def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
        ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
        for child in node["children"]:
            sub = _build_segment_rel(child)
            for seg_id_1, depth_1 in sub:
                for seg_id_2, depth_2 in ret:
                    n_up = min(depth_1 - node["depth"], max_depth - 1)
                    n_down = min(depth_2 - node["depth"], max_depth - 1)
                    segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
                        n_up, n_down, max_depth=max_depth
                    )
                    segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
                        n_down, n_up, max_depth=max_depth
                    )
            ret.extend(sub)
        return ret

    _build_segment_rel(root)

    input_ids: List[int] = []
    input_id_subs: List[int] = []
    segment_bound: List[Tuple[int, int]] = []
    image_bound: List[Tuple[int, int]] = []

    ext_table: Dict[int, str] = {}
    token_id_table: Dict[str, Dict[int, int]] = {}

    if prev_ext_states is not None:
        ext_table = prev_ext_states["ext_table"]
        token_id_table = prev_ext_states["token_id_table"]

    for seg in segments:
        tokens, ext_table = tokenizer.encode(seg["value"], ext_table)

        token_id_subs = []
        reid_token_ids = []
        for idx in tokens:
            if idx in ext_table:
                # unk or special token
                token = ext_table[idx]
                if token.startswith("<") and token.endswith(">"):
                    # special token
                    if "_" in token:
                        token_name = token[1:-1].split("_", maxsplit=1)[0]
                    else:
                        token_name = token[1:-1]
                    token_name = "<{}>".format(token_name)
                else:
                    token_name = "<unk>"

                if token_name not in token_id_table:
                    token_id_table[token_name] = {}
                if idx not in token_id_table[token_name]:
                    token_id_table[token_name][idx] = len(token_id_table[token_name])
                if token_name not in tokenizer.encoder:
                    raise ValueError("Invalid token {}".format(token))
                reid_token_ids.append(tokenizer.encoder[token_name])
                token_id_subs.append(token_id_table[token_name][idx])
            else:
                reid_token_ids.append(idx)
                token_id_subs.append(0)
        tokens = [tokenizer.bos_id] + reid_token_ids
        token_id_subs = [0] + token_id_subs
        if not seg["need_predict"]:
            tokens = tokens + [tokenizer.eos_id]
            token_id_subs = token_id_subs + [0]
        else:
            # no eos
            pass
        begin = len(input_ids)
        input_ids.extend(tokens)
        input_id_subs.extend(token_id_subs)
        end = len(input_ids)
        segment_bound.append((begin, end))

    ids = np.array(input_ids, dtype=np.int32)
    id_subs = np.array(input_id_subs, dtype=np.int32)
    segs = np.zeros((ids.shape[0],), dtype=np.int32)
    context = np.zeros((ids.shape[0],), dtype=np.int8)
    for i, (begin, end) in enumerate(segment_bound):
        if not segments[i]["need_predict"]:
            context[begin:end] = 1
        if segments[i]["is_image"]:
            image_bound.append((begin + 1, end - 1))
        segs[begin:end] = i

    curr_ext_table_states: _PrevExtTableStates = {
        "ext_table": ext_table,
        "token_id_table": token_id_table,
    }
    image_bound = np.array(image_bound, dtype=np.int32)
    return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
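
# Usage sketch (illustrative only; the "question" field is an example, not a
# fixed schema): a CPM-Bee sample is a nested dict, e.g.
#   data = {"image": "<image>", "question": "What is in the picture?", "<ans>": ""}
#   (ids, id_subs, context, segs, segment_rel, num_segments,
#    ext_states, image_bound) = convert_data_to_id(tokenizer, data)
# where `tokenizer` is a CPMBeeTokenizer. `context` marks tokens that are given
# rather than predicted (everything outside "<ans>"), and `image_bound` records
# the span reserved for image features when a top-level "image" field exists.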

# aug functions
def identity_func(img):
    return img


def autocontrast_func(img, cutoff=0):
    '''
    same output as PIL.ImageOps.autocontrast
    '''
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            table = np.arange(n_bins) * scale - low * scale
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    '''
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    like PIL, rotate by degree, not radians
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out

def solarize_func(img, thresh=128):
    '''
    same output as PIL.ImageOps.solarize
    '''
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    '''
    same output as PIL.ImageEnhance.Color
    '''
    # implementation according to PIL definition, quite slow
    # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    # out = blend(degenerate, img, factor)
    # M = (
    #     np.eye(3) * factor
    #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    # )[np.newaxis, np.newaxis, :]
    M = (
        np.float32([
            [0.886, -0.114, -0.114],
            [-0.587, 0.413, -0.587],
            [-0.299, -0.299, 0.701]]) * factor
        + np.float32([[0.114], [0.587], [0.299]])
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([
        (el - mean) * factor + mean
        for el in range(256)
    ]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    '''
    same output as PIL.ImageEnhance.Brightness
    '''
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    '''
    The differences between this result and PIL are all on the 4 boundaries; the
    center areas are the same.
    '''
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out

def shear_x_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
    same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
    same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    '''
    same output as PIL.ImageOps.posterize
    '''
    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    replace = np.array(replace, dtype=np.uint8)
    H, W = img.shape[0], img.shape[1]
    rh, rw = np.random.random(2)
    pad_size = pad_size // 2
    ch, cw = int(rh * H), int(rw * W)
    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
    out = img.copy()
    out[x1:x2, y1:y2, :] = replace
    return out

# level to args
def enhance_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level, )
    return level_to_args


def none_level_to_args(level):
    return ()


def posterize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level, )
    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
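
# Example (for illustration): each *_level_to_args helper turns an integer
# magnitude into the positional arguments of the matching op in func_dict,
# e.g. enhance_level_to_args(MAX_LEVEL)(7) returns roughly (1.36,), which
# brightness_func or contrast_func then receives as its `factor`.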

class RandomAugment(object):

    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
+
|
688 |
+
|
689 |
+
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
|
690 |
+
if is_train:
|
691 |
+
t = [
|
692 |
+
RandomResizedCropAndInterpolation(
|
693 |
+
input_size, scale=(0.5, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
|
694 |
+
transforms.RandomHorizontalFlip(),
|
695 |
+
]
|
696 |
+
if randaug:
|
697 |
+
t.append(
|
698 |
+
RandomAugment(
|
699 |
+
2, 7, isPIL=True,
|
700 |
+
augs=[
|
701 |
+
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
|
702 |
+
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
|
703 |
+
]))
|
704 |
+
t += [
|
705 |
+
transforms.ToTensor(),
|
706 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
707 |
+
]
|
708 |
+
t = transforms.Compose(t)
|
709 |
+
else:
|
710 |
+
t = transforms.Compose([
|
711 |
+
transforms.Resize((input_size, input_size),
|
712 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
713 |
+
transforms.ToTensor(),
|
714 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
|
715 |
+
])
|
716 |
+
|
717 |
+
return t
|
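
# Usage sketch (illustrative): the non-training branch yields the deterministic
# preprocessing pipeline, e.g.
#   transform = build_transform(is_train=False, input_size=224)
#   pixel_values = transform(pil_image)  # `pil_image` is a hypothetical PIL.Image
# i.e. bicubic resize to input_size, ToTensor, and Inception mean/std normalization.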

def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # response.read returns bytes, so the iter sentinel must be b""
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    if not chunk:
                        break
                    pbar.update(chunk_size)
                    fh.write(chunk)
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff