REILX commited on
Commit
5ea6514
1 Parent(s): 21ecb5d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +100 -3
README.md CHANGED
@@ -1,3 +1,100 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - REILX/text-description-of-the-meme
5
+ language:
6
+ - zh
7
+ tags:
8
+ - llava
9
+ - Qwen2
10
+ - txtimage-to-txt
11
+ - lora
12
+ ---
13
+
14
+ ### 模型 llava-Qwen2-7B-Instruct-CLIP-ZH 增强中文文字识别能力和表情包内涵识别能力
15
+ 1. 模型结构:</br>
16
+ llava-Qwen2-7B-Instruct-CLIP-ZH = Qwen/Qwen2-7B-Instruct + multi_modal_projector + openai/clip-vit-large-patch14-336</br>
17
+
18
+ 2. 微调模块
19
+ - vision_tower和language_model的q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj部分进行lora训练</br>
20
+ - mmp层全量训练</br>
21
+
22
+ 3. 微调参数</br>
23
+ lora_r=32, lora_alpha=64,num_train_epochs=5,per_device_train_batch_size=1,gradient_accumulation_steps=8,high_lr=1e-3,low_lr=2e-5,model_max_length=2048
24
+
25
+ 4. 数据集</br>
26
+ 使用gemini-1.5-pro, gemini-1.5-flash, yi-vision, gpt4o,claude-3.5-sonnet模型描述emo-visual-data和ChineseBQB数据集。</br>
27
+ 文本描述信息通过[text-description-of-the-meme](https://huggingface.co/datasets/REILX/text-description-of-the-meme) 下载</br>
28
+ 图像可通过[emo-visual-data](https://github.com/LLM-Red-Team/emo-visual-data), [ChineseBQB](https://github.com/zhaoolee/ChineseBQB)下载</br>
29
+ 图片数据总量1.8G,约10835张中文表情包图片。文字总量42Mb,约24332个图像文本对描述信息。
30
+
31
+ 5. 效果展示
32
+ ![](./images/llava-qwen2-lora-01.JPG)
33
+ ![](./images/llava-qwen2-lora-02.JPG)
34
+ ![](./images/llava-qwen2-lora-03.JPG)
35
+
36
+
37
+ 5. 代码</br>
38
+ 合并模型代码,合并模型之后将add_tokens.json,merge.txt,preprocessor_config.json,specital_token_map.json,tokenizer.json,vocab.json文件复制到"/保存的完整模型路径"。
39
+ ```python
40
+ import torch
41
+ from peft import PeftModel, LoraConfig
42
+ from transformers import LlavaForConditionalGeneration
43
+ model_name = "/替换为你的基础模型路径"
44
+ LORA_R = 32
45
+ LORA_ALPHA = 64
46
+ LORA_DROPOUT = 0.05
47
+ TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
48
+ lora_config = LoraConfig(
49
+ r=LORA_R,
50
+ lora_alpha=LORA_ALPHA,
51
+ target_modules=TARGET_MODULES,
52
+ lora_dropout=LORA_DROPOUT,
53
+ bias="none",
54
+ task_type="CAUSAL_LM",
55
+ modules_to_save=["multi_modal_projector"],
56
+ )
57
+ model = LlavaForConditionalGeneration.from_pretrained(model_name)
58
+ model = PeftModel.from_pretrained(model, "/替换为你的lora模型路径", config=lora_config, adapter_name='lora')
59
+
60
+ model.cpu()
61
+ model.eval()
62
+ base_model = model.get_base_model()
63
+ base_model.eval()
64
+ model.merge_and_unload()
65
+ base_model.save_pretrained("/保存的完整模型路径")
66
+ ```
67
+
68
+ 推理代码
69
+ ```python
70
+ from transformers import LlavaForConditionalGeneration, AutoProcessor
71
+ import torch
72
+ from PIL import Image
73
+
74
+ raw_model_name_or_path = "/保存的完整模型路径"
75
+ model = LlavaForConditionalGeneration.from_pretrained(raw_model_name_or_path, device_map="cuda:0", torch_dtype=torch.bfloat16)
76
+ processor = AutoProcessor.from_pretrained(raw_model_name_or_path)
77
+ model.eval()
78
+
79
+ def build_model_input(model, processor):
80
+ messages = [
81
+ {"role": "system", "content": "You are a helpful assistant."},
82
+ {"role": "user", "content": "<image>\n 使用中文描述图片中的信息"}
83
+ ]
84
+ prompt = processor.tokenizer.apply_chat_template(
85
+ messages, tokenize=False, add_generation_prompt=True
86
+ )
87
+ image = Image.open("01.PNG")
88
+ inputs = processor(text=prompt, images=image, return_tensors="pt", return_token_type_ids=False)
89
+
90
+ for tk in inputs.keys():
91
+ inputs[tk] = inputs[tk].to(model.device)
92
+ generate_ids = model.generate(**inputs, max_new_tokens=200)
93
+
94
+ generate_ids = [
95
+ oid[len(iids):] for oid, iids in zip(generate_ids, inputs.input_ids)
96
+ ]
97
+ gen_text = processor.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
98
+ return gen_text
99
+ build_model_input(model, processor)
100
+ ```