Update README.md
Browse files
README.md
CHANGED
@@ -7,7 +7,7 @@ tags:
|
|
7 |
- lora
|
8 |
- generated_from_trainer
|
9 |
model-index:
|
10 |
-
- name:
|
11 |
results: []
|
12 |
---
|
13 |
|
@@ -18,17 +18,97 @@ should probably proofread and complete it, then remove this comment. -->
|
|
18 |
|
19 |
This model is a fine-tuned version of [google/gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) on the prm_dpo dataset.
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
## Training procedure
|
34 |
|
@@ -48,9 +128,6 @@ The following hyperparameters were used during training:
|
|
48 |
- lr_scheduler_type: linear
|
49 |
- num_epochs: 1.0
|
50 |
|
51 |
-
### Training results
|
52 |
-
|
53 |
-
|
54 |
|
55 |
### Framework versions
|
56 |
|
|
|
7 |
- lora
|
8 |
- generated_from_trainer
|
9 |
model-index:
|
10 |
+
- name: PPRM-gemma-2-2b-it
|
11 |
results: []
|
12 |
---
|
13 |
|
|
|
18 |
|
19 |
This model is a fine-tuned version of [google/gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) on the prm_dpo dataset.
|
20 |
|
21 |
+
# Citation
|
22 |
+
```
|
23 |
+
@article{zhang2024llama,
|
24 |
+
title={LLaMA-Berry: Pairwise Optimization for O1-like Olympiad-Level Mathematical Reasoning},
|
25 |
+
author={Zhang, Di and Wu, Jianbo and Lei, Jingdi and Che, Tong and Li, Jiatong and Xie, Tong and Huang, Xiaoshui and Zhang, Shufei and Pavone, Marco and Li, Yuqiang and others},
|
26 |
+
journal={arXiv preprint arXiv:2410.02884},
|
27 |
+
year={2024}
|
28 |
+
}
|
29 |
+
|
30 |
+
@article{zhang2024accessing,
|
31 |
+
title={Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B},
|
32 |
+
author={Zhang, Di and Li, Jiatong and Huang, Xiaoshui and Zhou, Dongzhan and Li, Yuqiang and Ouyang, Wanli},
|
33 |
+
journal={arXiv preprint arXiv:2406.07394},
|
34 |
+
year={2024}
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
```
|
39 |
+
|
40 |
+
## Model usage
|
41 |
+
|
42 |
+
`server.py`
|
43 |
+
```
|
44 |
+
import json
|
45 |
+
from fastapi import FastAPI, HTTPException
|
46 |
+
from pydantic import BaseModel
|
47 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
48 |
+
from peft import PeftModel
|
49 |
+
import torch
|
50 |
+
|
51 |
+
# Initialize FastAPI
|
52 |
+
app = FastAPI()
|
53 |
+
|
54 |
+
# Device configuration
|
55 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
+
|
57 |
+
# Model and tokenizer loading (as you provided)
|
58 |
+
model_name = "google/gemma-2-2b-it"
|
59 |
+
|
60 |
+
lora_checkpoint_path = "qq8933/PPRM-gemma-2-2b-it"
|
61 |
+
|
62 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
63 |
+
base_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map='cuda')
|
64 |
+
model = PeftModel.from_pretrained(base_model, lora_checkpoint_path, device_map='cuda')
|
65 |
+
|
66 |
+
yes_token_id = tokenizer.convert_tokens_to_ids("yes")
|
67 |
+
no_token_id = tokenizer.convert_tokens_to_ids("no")
|
68 |
+
|
69 |
+
# Request model
|
70 |
+
class InputRequest(BaseModel):
|
71 |
+
text: str
|
72 |
+
|
73 |
+
# Predict function
|
74 |
+
def predict(qeustion,answer_1,answer_2):
|
75 |
+
prompt_template = """Problem:\n\n{}\n\nFirst Answer:\n\n{}\n\nSecond Answer:\n\n{}\n\nIs First Answer better than Second Answer?\n\n"""
|
76 |
+
input_text = prompt_template.format(qeustion,answer_1,answer_2)
|
77 |
+
input_text = tokenizer.apply_chat_template(
|
78 |
+
[{'role': 'user', 'content': input_text}], tokenize=False, add_generation_prompt=True
|
79 |
+
)
|
80 |
+
inputs = tokenizer(input_text, return_tensors="pt").to(device)
|
81 |
+
with torch.no_grad():
|
82 |
+
generated_outputs = model.generate(
|
83 |
+
**inputs, max_new_tokens=2, output_scores=True, return_dict_in_generate=True
|
84 |
+
)
|
85 |
+
scores = generated_outputs.scores
|
86 |
+
first_token_logits = scores[0]
|
87 |
+
yes_logit = first_token_logits[0, yes_token_id].item()
|
88 |
+
no_logit = first_token_logits[0, no_token_id].item()
|
89 |
+
|
90 |
+
return {
|
91 |
+
"yes_logit": yes_logit,
|
92 |
+
"no_logit": no_logit,
|
93 |
+
"logit_difference": yes_logit - no_logit
|
94 |
+
}
|
95 |
+
|
96 |
+
# Define API endpoint
|
97 |
+
@app.post("/predict")
|
98 |
+
async def get_prediction(input_request: InputRequest):
|
99 |
+
payload = json.loads(input_request.text)
|
100 |
+
qeustion,answer_1,answer_2 = payload['qeustion'],payload['answer_1'],payload['answer_2']
|
101 |
+
try:
|
102 |
+
result = predict(qeustion,answer_1,answer_2)
|
103 |
+
return result
|
104 |
+
except Exception as e:
|
105 |
+
raise HTTPException(status_code=500, detail=str(e))
|
106 |
+
|
107 |
+
```
|
108 |
+
|
109 |
+
```
|
110 |
+
uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
|
111 |
+
```
|
112 |
|
113 |
## Training procedure
|
114 |
|
|
|
128 |
- lr_scheduler_type: linear
|
129 |
- num_epochs: 1.0
|
130 |
|
|
|
|
|
|
|
131 |
|
132 |
### Framework versions
|
133 |
|