Blair Yang committed
Commit b11f272
1 Parent(s): dfba357
Sample.py CHANGED
@@ -4,6 +4,7 @@ import os
 import json
 from Config import *
 import pandas as pd
+from models import HFAPIModel
 
 def format_card_str(card):
     entries = []
@@ -76,7 +77,40 @@ def process_for_display(card_lst, qa):
 def select_entry(qa_entry, card_lst):
     # TODO: Automatically select the most relevant criterion.
     # PLACEHOLDER: RETURN THE WHOLE THING
-    return '\n'.join(card_lst[:2])
+
+    # if False:
+    #     return '\n'.join(card_lst[:2])
+
+    system_prompt = '''
+    Your task is to effectively condense the essential details from the student's evaluation card that are most relevant to predicting the correctness of their answer to a question.
+    Limit your paraphrase to 100-150 words, focusing on distilling the key observations and outcomes that are directly pertinent to the inquiry.
+    It's crucial to present an informative, unbiased summary that retains the integrity of the original card's information.
+    Your goal is to craft a paraphrase that enhances the user's ability to accurately gauge the student's response by emphasizing relevant insights and conclusions without altering the core facts.
+    '''
+
+    card_str = '\n'.join(card_lst)
+    prompt = f'''
+    ## Question:
+    {qa_entry}
+
+    ## Evaluation Card:
+    {card_str}
+
+    Again, your task is not to answer the question, but to summarize the student's ability to answer it!
+    '''
+
+    # Candidate models: mistralai/Mistral-7B-Instruct-v0.2
+    #                   mistralai/Mixtral-8x7B-Instruct-v0.1
+
+    model = HFAPIModel(system_prompt=system_prompt,
+                       model_name='mistralai/Mistral-7B-Instruct-v0.2')
+
+
+    response = model(prompt)
+    print(response)
+    del model
+    return response
+
 
 
 def sample_card(dataset='', topic='', model='', card_cnt=2):
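
The new select_entry path replaces the placeholder join with an LLM-generated paraphrase of the evaluation card. A minimal sketch of how it might be exercised (the card entries and question below are hypothetical, and a reachable Hugging Face Inference API endpoint is assumed):

    from Sample import select_entry

    # Hypothetical evaluation-card entries and question, for illustration only.
    card_lst = [
        'Criterion: algebraic manipulation -- isolates variables reliably.',
        'Criterion: unit analysis -- occasionally drops units in intermediate steps.',
    ]
    qa_entry = 'A ball is thrown upward at 12 m/s; how long until it returns?'

    summary = select_entry(qa_entry, card_lst)  # 100-150 word summary of the card
    print(summary)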
__pycache__/Sample.cpython-311.pyc CHANGED
Binary files a/__pycache__/Sample.cpython-311.pyc and b/__pycache__/Sample.cpython-311.pyc differ
 
__pycache__/models.cpython-311.pyc ADDED
Binary file (9.56 kB).
 
models.py ADDED
@@ -0,0 +1,220 @@
+import json
+import random
+import re
+from collections import defaultdict
+from typing import Any, List, Dict, Callable, Union, Optional
+# from vllm import LLM, SamplingParams
+
+import regex
+import numpy as np
+from huggingface_hub import InferenceClient
+from tqdm import tqdm
+
+# from config import *
+
+ROLE_SYSTEM = 'system'
+ROLE_USER = 'user'
+ROLE_ASSISTANT = 'assistant'
+
+
+CHAT_FORMATS = {
+    "mistralai": "<s>[INST] {prompt} [/INST]",
+    "openchat": "GPT4 User: {prompt}<|end_of_turn|>GPT4 Assistant:",
+    "meta-llama": """[INST] <<SYS>>
+You answer questions directly.
+<</SYS>>
+{prompt}[/INST]""",
+    "mosaicml": """<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>
+<|im_start|>user
+{prompt}<|im_end|>
+<|im_start|>assistant""",
+    "lmsys": "USER: {prompt}\nASSISTANT:",
+}
+
+
+LLAMA_TEMPLATE = """<s>[INST] <<SYS>>
+{system_prompt}
+<</SYS>>
+
+{user_message} [/INST]"""
+
+MISTRAL_TEMPLATE = """<s>[INST] <<SYS>>
+{system_prompt}
+<</SYS>> {user_message} [/INST]"""
+
+YI_34B_TEMPLATE = """<|im_start|>system
+{system_prompt}<|im_end|>
+<|im_start|>user
+{user_message}<|im_end|>
+<|im_start|>assistant
+"""
+
+
+def extract_json(text: str) -> Optional[Dict]:
+    # json_string_match = re.search(r"json\s+(.+?)\s+", text, re.DOTALL)
+    # Assume the payload is going to look like: "Guess": "A" or "Guess": "B";
+    # now it's either true or false.
+    text = text.replace('\\', '\\\\')  # escape backslashes so json.loads accepts them
+    try:
+        rslt = json.loads(text)
+    except Exception:
+        rslt = None
+    return rslt
+
+
+def mixtral_prompt_formatter(messages: List[Dict[str, str]]) -> str:
+    """
+    <s>[INST] <<SYS>>
+    {system_prompt}
+    <</SYS>>
+    {user_prompt} [/INST]
+    """
+    assert len(messages) >= 2  # must contain at least a system and a user message
+    r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n{messages[1]["content"]} [/INST]'
+    for msg in messages[2:]:
+        role, content = msg['role'], msg['content']
+        if role == ROLE_SYSTEM:
+            raise ValueError  # only the first message may be a system message
+        elif role == ROLE_USER:
+            if r.endswith('</s>'):
+                r += '<s>'
+            r += f'[INST] {content} [/INST]'
+        elif role == ROLE_ASSISTANT:
+            r += f'{content}</s>'
+        else:
+            raise ValueError
+    return r
+
+
+def llama_prompt_formatter(messages: List[Dict[str, str]]) -> str:
+    """
+    <s>[INST] <<SYS>>
+    {system_prompt}
+    <</SYS>>
+
+    {user_message} [/INST]
+    """
+    assert len(messages) >= 2  # must contain at least a system and a user message
+    r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]'
+    for msg in messages[2:]:
+        role, content = msg['role'], msg['content']
+        if role == ROLE_SYSTEM:
+            raise ValueError  # only the first message may be a system message
+        elif role == ROLE_USER:
+            if r.endswith('</s>'):
+                r += '<s>'
+            r += f'[INST] {content} [/INST]'
+        elif role == ROLE_ASSISTANT:
+            r += f'{content}</s>'
+        else:
+            raise ValueError
+    return r
+
+
+def yi_prompt_formatter(messages: List[Dict[str, str]]) -> str:
+    """
+    <|im_start|>system
+    {system_prompt}<|im_end|>
+    <|im_start|>user
+    {user_message}<|im_end|>
+    <|im_start|>assistant
+    """
+    assert len(messages) >= 2  # must contain at least a system and a user message
+    r = f'<|im_start|>system\n{messages[0]["content"]}<|im_end|>\n<|im_start|>user\n{messages[1]["content"]}<|im_end|>\n'
+    for i in range(2, len(messages)):
+        msg = messages[i]
+        role, content = msg['role'], msg['content']
+        if role == ROLE_SYSTEM:
+            raise ValueError  # only the first message may be a system message
+        elif role == ROLE_USER:
+            r += f'<|im_start|>user\n{content}<|im_end|>\n'
+            if i == len(messages) - 1:
+                r += '<|im_start|>assistant\n'
+        elif role == ROLE_ASSISTANT:
+            r += f'<|im_start|>assistant\n{content}<|im_end|>\n'
+        else:
+            raise ValueError
+    return r
+
+
+def find_first_valid_json(s) -> Optional[Dict]:
+    # Strip invalid escape sequences, then scan for the first decodable {...} span.
+    s = re.sub(r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', lambda m: m.group(0)[1:], s)
+    for i in range(len(s)):
+        if s[i] != '{':
+            continue
+        for j in range(i + 1, len(s) + 1):
+            if s[j - 1] != '}':
+                continue
+            try:
+                potential_json = s[i:j]
+                json_obj = json.loads(potential_json, strict=False)
+                return json_obj  # Return the first valid JSON object found
+            except json.JSONDecodeError:
+                pass  # Continue searching if JSON decoding fails
+    return None  # Return None if no valid JSON object is found
+
+
+class HFAPIModel:
+    model_name: str
+    messages: List[Dict[str, str]]
+    system_prompt: str
+    formatter: Callable[[List[Dict[str, str]]], str]
+
+    def __init__(self, system_prompt: str, model_name: str) -> None:
+        self.system_prompt = system_prompt
+        self.model_name = model_name
+        if 'llama' in model_name:
+            self.formatter = llama_prompt_formatter
+        elif 'mistral' in model_name:
+            self.formatter = mixtral_prompt_formatter
+        else:
+            raise NotImplementedError
+        self.messages = [
+            {'role': ROLE_SYSTEM, 'content': system_prompt}
+        ]
+
+    def __call__(self, user_prompt: str, use_json: bool = False,
+                 temperature: float = 0, timeout: Optional[float] = None,
+                 cache: bool = True) -> Union[str, Dict]:
+        self.add_message(ROLE_USER, user_prompt)
+        response = self.get_response(temperature, use_json, timeout, cache)
+        self.add_message(ROLE_ASSISTANT, response)
+        return response
+
+    def get_response(self, temperature: float, use_json: bool,
+                     timeout: Optional[float], cache: bool) -> Union[str, Dict]:
+        """
+        Returns the model's response.
+        If use_json = True, tries its best to return a JSON dict, but this is not guaranteed.
+        If the JSON cannot be parsed, the raw response string is returned instead.
+        """
+        client = InferenceClient(self.model_name, timeout=timeout)
+        if not cache:
+            client.headers["x-use-cache"] = "0"
+        # print(self.formatter(self.messages))  # debug
+        r = client.text_generation(self.formatter(self.messages),
+                                   do_sample=temperature > 0,
+                                   temperature=temperature if temperature > 0 else None,
+                                   max_new_tokens=512)
+        if use_json:
+            obj = find_first_valid_json(r)
+            if obj is not None:
+                return obj
+        return r
+
+    def add_message(self, role: str, message: str):
+        self.messages.append({'role': role, 'content': message})
+
+
+if __name__ == '__main__':
+    # model = GPTModel(system_prompt='You are an AI developed by OpenAI.', model_name=GPT_4_MODEL_NAME)
+    model = HFAPIModel(system_prompt='You are a helpful assistant.',
+                       model_name='mistralai/Mixtral-8x7B-Instruct-v0.1')
+    print(model('Who are you?'))
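
Because get_response only best-effort parses JSON, callers of HFAPIModel should be prepared for either a dict or a raw string. A minimal usage sketch (the prompt and the "Guess" key are illustrative, and a reachable Inference API endpoint is assumed):

    model = HFAPIModel(
        system_prompt='Reply only with JSON of the form {"Guess": true} or {"Guess": false}.',
        model_name='mistralai/Mistral-7B-Instruct-v0.2')
    result = model('Is the speed of light finite?', use_json=True)
    if isinstance(result, dict):   # find_first_valid_json succeeded
        print(result.get('Guess'))
    else:                          # fall back to the raw completion string
        print('Unparsed response:', result)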
responses/mmlu/high_school_physics/response.csv CHANGED
@@ -1 +1,3 @@
-index,model,reasoning,correctness,confidence
+index,model,reasoning,correctness,confidence
+213,Llama-2-70b-chat-hf,,False,0
+44,Mixtral-8x7B-Instruct-v0.1,,False,0
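
The appended rows keep the existing five-column schema. A minimal sketch for loading them with the pandas dependency already imported in Sample.py (path as in the diff):

    import pandas as pd

    df = pd.read_csv('responses/mmlu/high_school_physics/response.csv')
    # columns: index, model, reasoning, correctness, confidence
    print(df[['model', 'correctness', 'confidence']])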