mjbuehler commited on
Commit
d8d122f
1 Parent(s): ee215a7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +94 -0
README.md CHANGED
@@ -119,6 +119,100 @@ The image below shows reproductions of two representative pages of the scientifi
119
 
120
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/qHURSBRWEDgHy4o56escN.png)
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  ## Citation
123
 
124
  Please cite as:
 
119
 
120
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/qHURSBRWEDgHy4o56escN.png)
121
 
122
+ ## Fine-tuning
123
+
124
+
125
+ Load base model
126
+
127
+ ```python
128
+ model_id = "microsoft/Phi-3-vision-128k-instruct"
129
+
130
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto")
131
+
132
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
133
+ ```
134
+
135
+ Define FT_repo_id to push on HF hub/save model:
136
+ ```
137
+ FT_repo_id='xxxxx/' #<repo_ID>
138
+ ```
139
+
140
+ ```
141
+ from datasets import load_dataset
142
+
143
+ train_dataset = load_dataset("lamm-mit/Cephalo-Wikipedia-Materials", split="train")
144
+ ```
145
+
146
+ ```python
147
+ import random
148
+
149
+ class MyDataCollator:
150
+ def __init__(self, processor):
151
+ self.processor = processor
152
+
153
+ def __call__(self, examples):
154
+ texts = []
155
+ images = []
156
+ for example in examples:
157
+ image = example["image"]
158
+ question = example["query"]
159
+ answer = example["answer"]
160
+ messages = [ {
161
+ "role": "user", "content": '<|image_1|>\n'+question},
162
+ {"role": "assistant", "content": f"{answer}"}, ]
163
+
164
+ text = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
165
+
166
+ images.append(image)
167
+
168
+ batch = processor(text=text, images=[image], return_tensors="pt", padding=True
169
+
170
+ labels = batch["input_ids"].clone()
171
+ labels[labels <0] = -100
172
+
173
+ batch["labels"] = labels
174
+
175
+ return batch
176
+
177
+ data_collator = MyDataCollator(processor)
178
+ ```
179
+ Then set up trainer, and train:
180
+ ```python
181
+ from transformers import TrainingArguments, Trainer
182
+
183
+ optim = "paged_adamw_8bit"
184
+
185
+ training_args = TrainingArguments(
186
+ num_train_epochs=2,
187
+ per_device_train_batch_size=1,
188
+ #per_device_eval_batch_size=4,
189
+ gradient_accumulation_steps=4,
190
+ warmup_steps=250,
191
+ learning_rate=1e-5,
192
+ weight_decay=0.01,
193
+ logging_steps=25,
194
+ output_dir="output_training",
195
+ optim=optim,
196
+ save_strategy="steps",
197
+ save_steps=1000,
198
+ save_total_limit=16,
199
+ #fp16=True,
200
+ bf16=True,
201
+ push_to_hub_model_id=FT_repo_id,
202
+ remove_unused_columns=False,
203
+ report_to="none",
204
+ )
205
+
206
+ trainer = Trainer(
207
+ model=model,
208
+ args=training_args,
209
+ data_collator=data_collator,
210
+ train_dataset=train_dataset,
211
+ )
212
+
213
+ trainer.train()
214
+ ```
215
+
216
  ## Citation
217
 
218
  Please cite as: