Swain commited on
Commit
987a0dc
·
unverified ·
2 Parent(s): 3314d4b 09b80dc

Merge branch 'main' into main

Browse files
README.md CHANGED
@@ -29,13 +29,9 @@ QUESTION_LANG=cn QUESTION_LLM='chatgpt' QUESTION_LLM_KEY=<your API key> python3
29
  ```shell
30
  QUESTION_LANG=en QUESTION_LLM='chatgpt' QUESTION_LLM_KEY=<your API key> python3 -u app.py
31
  ```
32
- ### LLaMA2-7b + 中文
33
  ```shell
34
- QUESTION_LANG=cn QUESTION_LLM='llama2-7b' python3 -u app.py
35
- ```
36
- ### LLaMA2-7b + 英文
37
- ```shell
38
- QUESTION_LANG=en QUESTION_LLM='llama2-7b' python3 -u app.py
39
  ```
40
  ## :technologist: 为什么制作这个游戏
41
 
@@ -57,9 +53,9 @@ QUESTION_LANG=en QUESTION_LLM='llama2-7b' python3 -u app.py
57
  - [x] 支持自定义关卡
58
  - [ ] 在线试玩链接
59
  - [ ] Hugging Face Space 链接
60
- - [ ] 支持LLaMA2-7B(英文)
61
- - [ ] 支持Mistral-7B(英文)
62
  - [ ] 支持Baichuan2-7B(中文)
 
63
  - [ ] LLM 推理速度优化
64
 
65
 
 
29
  ```shell
30
  QUESTION_LANG=en QUESTION_LLM='chatgpt' QUESTION_LLM_KEY=<your API key> python3 -u app.py
31
  ```
32
+ ### Mistral-7B-Instruct-v0.1 + 英文
33
  ```shell
34
+ QUESTION_LANG=en QUESTION_LLM='mistral-7b' python3 -u app.py
 
 
 
 
35
  ```
36
  ## :technologist: 为什么制作这个游戏
37
 
 
53
  - [x] 支持自定义关卡
54
  - [ ] 在线试玩链接
55
  - [ ] Hugging Face Space 链接
56
+ - [x] 支持Mistral-7B-Instruct-v0.1(英文)
 
57
  - [ ] 支持Baichuan2-7B(中文)
58
+ - [ ] 支持LLaMA2-7B(英文)
59
  - [ ] LLM 推理速度优化
60
 
61
 
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import uuid
 
3
 
4
  import gradio as gr
5
 
@@ -7,14 +8,20 @@ from llmriddles.questions import QuestionExecutor
7
  from llmriddles.questions import list_ordered_questions
8
 
9
  _QUESTION_IDS = {}
 
10
  _QUESTIONS = list_ordered_questions()
11
  _LANG = os.environ.get('QUESTION_LANG', 'cn')
12
  assert _LANG in ['cn', 'en'], _LANG
13
  _LLM = os.environ.get('QUESTION_LLM', 'chatgpt')
14
- assert _LLM in ['chatgpt', 'llama2-7b'], _LLM
15
  _LLM_KEY = os.environ.get('QUESTION_LLM_KEY', None)
 
16
 
17
  if _LANG == "cn":
 
 
 
 
18
  title = "完蛋!我被 LLM 拿捏了"
19
  requirement_ph = """
20
  欢迎来到 LLM Riddles!
@@ -122,7 +129,7 @@ if __name__ == '__main__':
122
  gr_question = gr.TextArea(placeholder=question_ph, label=question_label)
123
  gr_api_key = gr.Text(placeholder=api_ph, label=api_label, type='password', visible=_need_api_key())
124
  with gr.Row():
125
- gr_submit = gr.Button(submit_label, interactive=True)
126
  gr_next = gr.Button(next_label)
127
 
128
  with gr.Column():
@@ -134,8 +141,11 @@ if __name__ == '__main__':
134
 
135
 
136
  def _next_question(uuid_):
 
137
  if not uuid_:
138
  uuid_ = str(uuid.uuid4())
 
 
139
  global _QUESTION_IDS
140
  _qid = _QUESTION_IDS.get(uuid_, -1)
141
  _qid += 1
@@ -143,8 +153,9 @@ if __name__ == '__main__':
143
 
144
  if _qid >= len(_QUESTIONS):
145
  del _QUESTION_IDS[uuid_]
 
146
  return game_cleared_label, '', '', {}, '', \
147
- gr.Button(submit_label, interactive=True), \
148
  gr.Button(try_again_label, interactive=True), \
149
  ''
150
  else:
 
1
  import os
2
  import uuid
3
+ import logging
4
 
5
  import gradio as gr
6
 
 
8
  from llmriddles.questions import list_ordered_questions
9
 
10
  _QUESTION_IDS = {}
11
+ count = 0
12
  _QUESTIONS = list_ordered_questions()
13
  _LANG = os.environ.get('QUESTION_LANG', 'cn')
14
  assert _LANG in ['cn', 'en'], _LANG
15
  _LLM = os.environ.get('QUESTION_LLM', 'chatgpt')
16
+ assert _LLM in ['chatgpt', 'mistral-7b'], _LLM
17
  _LLM_KEY = os.environ.get('QUESTION_LLM_KEY', None)
18
+ _DEBUG = os.environ.get('DEBUG', 'false').lower() == 'true'
19
 
20
  if _LANG == "cn":
21
+ if _DEBUG:
22
+ logging.getLogger().setLevel(logging.INFO)
23
+ else:
24
+ logging.getLogger().setLevel(logging.WARNING)
25
  title = "完蛋!我被 LLM 拿捏了"
26
  requirement_ph = """
27
  欢迎来到 LLM Riddles!
 
129
  gr_question = gr.TextArea(placeholder=question_ph, label=question_label)
130
  gr_api_key = gr.Text(placeholder=api_ph, label=api_label, type='password', visible=_need_api_key())
131
  with gr.Row():
132
+ gr_submit = gr.Button(submit_label, interactive=False)
133
  gr_next = gr.Button(next_label)
134
 
135
  with gr.Column():
 
141
 
142
 
143
  def _next_question(uuid_):
144
+ global count
145
  if not uuid_:
146
  uuid_ = str(uuid.uuid4())
147
+ count += 1
148
+ logging.info(f'Player {count} starts the game now')
149
  global _QUESTION_IDS
150
  _qid = _QUESTION_IDS.get(uuid_, -1)
151
  _qid += 1
 
153
 
154
  if _qid >= len(_QUESTIONS):
155
  del _QUESTION_IDS[uuid_]
156
+ logging.info(f'Player {count} has passed the game now')
157
  return game_cleared_label, '', '', {}, '', \
158
+ gr.Button(submit_label, interactive=False), \
159
  gr.Button(try_again_label, interactive=True), \
160
  ''
161
  else:
llmriddles/llms/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
- from .chatgpt import ask_chatgpt
2
  from .base import register_llm, get_llm_fn
 
 
 
 
1
  from .base import register_llm, get_llm_fn
2
+ from .chatgpt import ask_chatgpt
3
+ from .mistral import ask_mistral_7b_instruct
llmriddles/llms/llm_client.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import requests
3
+ import logging
4
+ import argparse
5
+
6
+
7
+ class LLMFlaskClient:
8
+ def __init__(self, ip: str, port: int, max_retry: int = 3):
9
+ self.ip = ip
10
+ self.port = port
11
+
12
+ self.url_prefix_format = 'http://{}:{}/'
13
+ self.url = self.url_prefix_format.format(self.ip, self.port)
14
+ self.max_retry = max_retry
15
+
16
+ self.logger = logging.getLogger()
17
+ self.logger.addHandler(logging.StreamHandler())
18
+ self.logger.handlers[0].setFormatter(logging.Formatter("%(message)s"))
19
+
20
+ def _request(self, name: str, data: dict):
21
+ for _ in range(self.max_retry):
22
+ try:
23
+ self.logger.info(f'{name}\ndata: {data}')
24
+ response = requests.post(self.url + name, json=data).json()
25
+ except Exception as e:
26
+ self.logger.warning('error: ', repr(e))
27
+ time.sleep(1)
28
+ continue
29
+ if response['code'] == 0:
30
+ return response['output']
31
+ else:
32
+ raise Exception(response['error_msg'])
33
+ raise Exception("Web service failed. Please retry or contact with manager")
34
+
35
+ def run(self, message: str) -> str:
36
+ try:
37
+ return self._request('ask_llm_for_answer', {'user_text': message})
38
+ except Exception as e:
39
+ return f"Error: {repr(e)}"
40
+
41
+
42
+ if __name__ == "__main__":
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument('--ip', required=True)
45
+ parser.add_argument('-p', '--port', required=True)
46
+ parser.add_argument('--debug', action='store_true')
47
+ args = parser.parse_args()
48
+ if args.debug:
49
+ logging.getLogger().setLevel(logging.INFO)
50
+ else:
51
+ logging.getLogger().setLevel(logging.WARNING)
52
+
53
+ client = LLMFlaskClient(args.ip, args.port)
54
+ print(client.run('Please concatenate string "1+" and "1=3". Only give me the result without "".'))
llmriddles/llms/llm_server.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from flask import Flask, request
3
+ import argparse
4
+ import logging
5
+
6
+
7
+ class LLMInstance:
8
+
9
+ def __init__(self, model_path: str, device: str = "cuda"):
10
+
11
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
12
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
13
+ self.model.to(device)
14
+ self.device = device
15
+
16
+ def query(self, message):
17
+ try:
18
+ messages = [
19
+ {"role": "user", "content": message},
20
+ ]
21
+ encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
22
+ model_inputs = encodeds.to(self.device)
23
+
24
+ generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
25
+ decoded = self.tokenizer.batch_decode(generated_ids)
26
+
27
+ # output is the string decoded[0] after "[/INST]". There may exist "</s>", delete it.
28
+ output = decoded[0].split("[/INST]")[1].split("</s>")[0]
29
+ return {
30
+ 'code': 0,
31
+ 'ret': True,
32
+ 'error_msg': None,
33
+ 'output': output
34
+ }
35
+ except Exception as e:
36
+ return {
37
+ 'code': 1,
38
+ 'ret': False,
39
+ 'error_msg': str(e),
40
+ 'output': None
41
+ }
42
+
43
+
44
+ def create_app(core):
45
+ app = Flask(__name__)
46
+
47
+ @app.route('/ask_llm_for_answer', methods=['POST'])
48
+ def ask_llm_for_answer():
49
+ user_text = request.json['user_text']
50
+ return core.query(user_text)
51
+
52
+ return app
53
+
54
+
55
+ if __name__ == "__main__":
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument('-m', '--model_path', required=True, default='Mistral-7B-Instruct-v0.1', help='the model path of reward model')
58
+ parser.add_argument('--ip', default='0.0.0.0')
59
+ parser.add_argument('-p', '--port', default=8001)
60
+ parser.add_argument('--debug', action='store_true')
61
+ args = parser.parse_args()
62
+
63
+ if args.debug:
64
+ logging.getLogger().setLevel(logging.DEBUG)
65
+ else:
66
+ logging.getLogger().setLevel(logging.INFO)
67
+ logging.getLogger().addHandler(logging.StreamHandler())
68
+ logging.getLogger().handlers[0].setFormatter(logging.Formatter("%(message)s"))
69
+
70
+ core = LLMInstance(args.model_path)
71
+ app = create_app(core)
72
+ app.run(host=args.ip, port=args.port)
llmriddles/llms/mistral.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+
3
+ from .base import register_llm
4
+ from .llm_client import LLMFlaskClient
5
+
6
+
7
+ @lru_cache()
8
+ def _get_mistral_7b_instruct_server(host: str, port: int):
9
+ from .llm_server import LLMInstance, create_app
10
+ core = LLMInstance('Mistral-7B-Instruct-v0.1')
11
+ app = create_app(core)
12
+ app.run(host=host, port=port)
13
+
14
+
15
+ def ask_mistral_7b_instruct(message: str, **kwargs):
16
+ host, port = '0.0.0.0', 8001
17
+ _get_mistral_7b_instruct_server(host, port)
18
+ client = LLMFlaskClient(host, port)
19
+ return client.run(message).strip()
20
+
21
+
22
+ register_llm('mistral-7b', ask_mistral_7b_instruct)
llmriddles/questions/__init__.py CHANGED
@@ -1,9 +1,11 @@
1
  from .executor import QuestionExecutor
2
  from .level1 import __file__ as _level1_file_
3
  from .level3 import __file__ as _level3_file_
 
4
  from .level5 import __file__ as _level5_file_
5
  from .question import Question, register_question, list_ordered_questions
6
 
7
  _ = _level1_file_
8
  _ = _level3_file_
 
9
  _ = _level5_file_
 
1
  from .executor import QuestionExecutor
2
  from .level1 import __file__ as _level1_file_
3
  from .level3 import __file__ as _level3_file_
4
+ from .level4 import __file__ as _level4_file_
5
  from .level5 import __file__ as _level5_file_
6
  from .question import Question, register_question, list_ordered_questions
7
 
8
  _ = _level1_file_
9
  _ = _level3_file_
10
+ _ = _level4_file_
11
  _ = _level5_file_
llmriddles/questions/level4.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .question import register_question
2
+ import re
3
+
4
+
5
+ def check_if_is_number(text: str):
6
+ try:
7
+ int(text)
8
+ return True
9
+ except ValueError:
10
+ return False
11
+
12
+
13
+ def get_all_numbers_in_a_sentence(text: str):
14
+ return [int(i) for i in re.findall(r'\d+', text)]
15
+
16
+
17
+ # CN_TEXT_1 = """
18
+ # 第四章第一题,请构造一个问题A,它的回复是不同于A的B,然后以B提问能再次得到A。
19
+
20
+ # 请在下面的输入框内填写你构造并点击按钮提交。
21
+ # """
22
+
23
+ # EN_TEXT_1 = """
24
+ # For the first question in chapter 4, please make a question A whose answer is B that is different from A, and then ask B to get A again.
25
+
26
+ # Please enter your query below and click the submit button
27
+ # """
28
+
29
+
30
+ # def _checker_1(question_text: str, user_text: str, answer_text: str, lang: str):
31
+ # _ = question_text, lang
32
+ # answer_text = answer_text.strip()
33
+ # user_text = user_text.strip()
34
+ # pass
35
+
36
+ # register_question({
37
+ # 'cn': CN_TEXT_1,
38
+ # 'en': EN_TEXT_1,
39
+ # }, _checker_1, level=4)
40
+
41
+
42
+ # CN_TEXT_2 = """
43
+ # 第四章第二题,
44
+
45
+ # 请在下面的输入框内填写你构造并点击按钮提交。
46
+ # """
47
+
48
+ # EN_TEXT_2 = """
49
+ # For the second question in chapter 4,
50
+
51
+ # Please enter your query below and click the submit button
52
+ # """
53
+
54
+
55
+ # def _checker_2(question_text: str, user_text: str, answer_text: str, lang: str):
56
+ # _ = question_text, lang
57
+ # answer_text = answer_text.strip()
58
+ # user_text = user_text.strip()
59
+ # pass
60
+
61
+ # register_question({
62
+ # 'cn': CN_TEXT_2,
63
+ # 'en': EN_TEXT_2,
64
+ # }, _checker_2, level=4)
65
+
66
+
67
+ CN_TEXT_3 = """
68
+ 第四章第三题(自然之密),请输入一个大于一的正整数作为问题,使回答里包含和它刚好相差1的数。
69
+
70
+ 请在下面的输入框内填写你构造并点击按钮提交。
71
+ """
72
+
73
+ EN_TEXT_3 = """
74
+ For the third question in chapter 4, please enter a positive integer greater than one as the question so that the answer contains a number that is exactly 1 different from it.
75
+
76
+ Please enter your query below and click the submit button
77
+ """
78
+
79
+
80
+ def _checker_3(question_text: str, user_text: str, answer_text: str, lang: str):
81
+ _ = question_text, lang
82
+ answer_text = answer_text.strip()
83
+ user_text = user_text.strip()
84
+ if not check_if_is_number(question_text):
85
+ return False, "问题应该是一个正整数" if lang == 'cn' else 'Question should be a positive integer.'
86
+ elif int(question_text) == 1:
87
+ return False, "问题应该是一个大于1的正整数" if lang == 'cn' else 'Question should be a positive integer greater than 1.'
88
+ elif int(question_text)-1 not in get_all_numbers_in_a_sentence(answer_text) or int(question_text)+1 not in get_all_numbers_in_a_sentence(answer_text):
89
+ return False, "回答中应该包含一个与问题相差1的数字" if lang == 'cn' else 'Answer should contain a number that is exactly 1 different from the question.'
90
+ else:
91
+ return True, None
92
+
93
+
94
+ register_question({
95
+ 'cn': CN_TEXT_3,
96
+ 'en': EN_TEXT_3,
97
+ }, _checker_3, level=4)
requirements.txt CHANGED
@@ -2,4 +2,6 @@ hbutils>=0.9.1
2
  tqdm
3
  requests>=2.20
4
  gradio==4.1.1
5
- openai>=1
 
 
 
2
  tqdm
3
  requests>=2.20
4
  gradio==4.1.1
5
+ openai>=1
6
+ flask
7
+ transformers