Update README.md
---
license: apache-2.0
datasets:
- REILX/text-description-of-the-meme
language:
- zh
tags:
- llava
- Qwen2
- txtimage-to-txt
- lora
---

### llava-Qwen2-7B-Instruct-CLIP-ZH: enhanced Chinese text recognition and meme-connotation understanding

1. Model architecture<br>
llava-Qwen2-7B-Instruct-CLIP-ZH = Qwen/Qwen2-7B-Instruct + multi_modal_projector + openai/clip-vit-large-patch14-336

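This composition can be assembled directly in `transformers`. A minimal sketch, assuming only the two published component configs; the trained multi_modal_projector weights and the tokenizer's `<image>` token still have to be supplied separately (this is not the authors' build script):

```python
from transformers import AutoConfig, LlavaConfig, LlavaForConditionalGeneration

# Minimal assembly sketch: combine the language model's config with CLIP's
# vision config under a LlavaConfig.
text_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
vision_config = AutoConfig.from_pretrained("openai/clip-vit-large-patch14-336").vision_config

config = LlavaConfig(vision_config=vision_config, text_config=text_config)
model = LlavaForConditionalGeneration(config)  # multi_modal_projector starts randomly initialized
```
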
2. Fine-tuned modules
- LoRA training on the q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj projections of both the vision_tower and the language_model<br>
- Full training of the multi_modal_projector (mmp) layer<br>

3. Fine-tuning hyperparameters<br>
lora_r=32, lora_alpha=64, num_train_epochs=5, per_device_train_batch_size=1, gradient_accumulation_steps=8, high_lr=1e-3, low_lr=2e-5, model_max_length=2048

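The two learning rates imply separate optimizer parameter groups: the fully trained projector at high_lr, the LoRA adapters at low_lr. A minimal sketch, assuming trainable parameters can be split by name; the authors' actual training loop is not shown in this card:

```python
import torch

def build_optimizer(model, high_lr=1e-3, low_lr=2e-5):
    # Split trainable parameters: the fully trained multi_modal_projector
    # gets the high rate, the LoRA adapter weights get the low rate.
    projector_params, lora_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (projector_params if "multi_modal_projector" in name else lora_params).append(param)
    return torch.optim.AdamW([
        {"params": projector_params, "lr": high_lr},
        {"params": lora_params, "lr": low_lr},
    ])
```
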
4. Dataset<br>
The gemini-1.5-pro, gemini-1.5-flash, yi-vision, gpt4o, and claude-3.5-sonnet models were used to write text descriptions of the emo-visual-data and ChineseBQB datasets.<br>
The text descriptions can be downloaded from [text-description-of-the-meme](https://huggingface.co/datasets/REILX/text-description-of-the-meme).<br>
The images can be downloaded from [emo-visual-data](https://github.com/LLM-Red-Team/emo-visual-data) and [ChineseBQB](https://github.com/zhaoolee/ChineseBQB).<br>
The images total 1.8 GB, about 10,835 Chinese meme images; the text totals 42 MB, about 24,332 image-text description pairs.

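The text descriptions load directly from the Hub. A minimal sketch; the "train" split name and the record layout are assumptions, check the dataset card:

```python
from datasets import load_dataset

# Minimal sketch: fetch the meme description dataset from the Hugging Face Hub.
ds = load_dataset("REILX/text-description-of-the-meme", split="train")
print(len(ds), ds[0])
```
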
5. Results
![](./images/llava-qwen2-lora-01.JPG)
![](./images/llava-qwen2-lora-02.JPG)
![](./images/llava-qwen2-lora-03.JPG)

6. Code<br>
Model-merging code. After merging, copy added_tokens.json, merges.txt, preprocessor_config.json, special_tokens_map.json, tokenizer.json, and vocab.json into "/保存的完整模型路径" (see the copy sketch after the code block).
```python
import torch
from peft import PeftModel, LoraConfig
from transformers import LlavaForConditionalGeneration

model_name = "/替换为你的基础模型路径"

# These LoRA settings must match the ones used during fine-tuning.
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["multi_modal_projector"],  # the fully trained mmp layer
)

model = LlavaForConditionalGeneration.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, "/替换为你的lora模型路径", config=lora_config, adapter_name="lora")

model.cpu()
model.eval()
base_model = model.get_base_model()
base_model.eval()
model.merge_and_unload()  # fold the LoRA weights into the base model's weights
base_model.save_pretrained("/保存的完整模型路径")
```
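
A minimal sketch of the file copy described above, assuming the listed tokenizer and processor files sit in the base model directory (both paths are the same placeholders as in the code):

```python
import shutil
from pathlib import Path

# Hypothetical helper: copy the tokenizer/processor files the merged
# checkpoint needs; source and destination paths are placeholders.
src = Path("/替换为你的基础模型路径")
dst = Path("/保存的完整模型路径")
for fname in ["added_tokens.json", "merges.txt", "preprocessor_config.json",
              "special_tokens_map.json", "tokenizer.json", "vocab.json"]:
    if (src / fname).exists():
        shutil.copy2(src / fname, dst / fname)
```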

Inference code
```python
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor

raw_model_name_or_path = "/保存的完整模型路径"
model = LlavaForConditionalGeneration.from_pretrained(raw_model_name_or_path, device_map="cuda:0", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(raw_model_name_or_path)
model.eval()

def build_model_input(model, processor):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "<image>\n使用中文描述图片中的信息"},
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image = Image.open("01.PNG")
    inputs = processor(text=prompt, images=image, return_tensors="pt", return_token_type_ids=False)

    # Move every input tensor to the model's device before generating.
    for tk in inputs.keys():
        inputs[tk] = inputs[tk].to(model.device)
    generate_ids = model.generate(**inputs, max_new_tokens=200)

    # Drop the prompt tokens so only the newly generated text is decoded.
    generate_ids = [
        oid[len(iids):] for oid, iids in zip(generate_ids, inputs.input_ids)
    ]
    gen_text = processor.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
    return gen_text

print(build_model_input(model, processor))
```