sc_ma commited on
Commit
1012e47
1 Parent(s): 49d990e
Files changed (1) hide show
  1. utils/gpt_interaction.py +74 -103
utils/gpt_interaction.py CHANGED
@@ -1,18 +1,70 @@
1
- import os
2
- import time
3
-
4
  import openai
5
- import logging
6
- import requests
7
  import json
8
-
9
  log = logging.getLogger(__name__)
10
 
11
- def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  conversation_history = [
13
- {"role": "system", "content": systems},
14
- {"role": "user", "content": prompts}
15
  ]
 
16
  response = openai.ChatCompletion.create(
17
  model=model,
18
  messages=conversation_history,
@@ -25,98 +77,17 @@ def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
25
  return assistant_message, usage
26
 
27
 
28
- class GPTModel_API2D_SUPPORT:
29
- def __init__(self, model="gpt-4", temperature=0, presence_penalty=0,
30
- frequency_penalty=0, url=None, key=None, max_attempts=1, delay=20):
31
- if url is None:
32
- url = "https://api.openai.com/v1/chat/completions"
33
- if key is None:
34
- key = os.getenv("OPENAI_API_KEY")
35
-
36
- self.model = model
37
- self.temperature = temperature
38
- self.url = url
39
- self.key = key
40
- self.presence_penalty = presence_penalty
41
- self.frequency_penalty = frequency_penalty
42
- self.max_attempts = max_attempts
43
- self.delay = delay
44
-
45
- def __call__(self, systems, prompts, return_json=False):
46
- headers = {
47
- "Content-Type": "application/json",
48
- "Authorization": f"Bearer {self.key}",
49
- }
50
-
51
- data = {
52
- "model": f"{self.model}",
53
- "messages": [
54
- {"role": "system", "content": systems},
55
- {"role": "user", "content": prompts}],
56
- "temperature": self.temperature,
57
- "n": 1,
58
- "stream": False,
59
- "presence_penalty": self.presence_penalty,
60
- "frequency_penalty": self.frequency_penalty
61
- }
62
- for _ in range(self.max_attempts):
63
- try:
64
- # todo: in some cases, UnicodeEncodeError is raised:
65
- # 'gbk' codec can't encode character '\xdf' in position 1898: illegal multibyte sequence
66
- response = requests.post(self.url, headers=headers, data=json.dumps(data))
67
- response = response.json()
68
- assistant_message = response['choices'][0]["message"]["content"]
69
- usage = response['usage']
70
- log.info(assistant_message)
71
- if return_json:
72
- assistant_message = json.loads(assistant_message)
73
- return assistant_message, usage
74
- except Exception as e:
75
- print(f"Failed to get response. Error: {e}")
76
- time.sleep(self.delay)
77
- raise RuntimeError("Failed to get response from OpenAI.")
78
-
79
-
80
- class GPTModel:
81
- def __init__(self, model="gpt-4", temperature=0.9, presence_penalty=0,
82
- frequency_penalty=0, max_attempts=1, delay=20):
83
- self.model = model
84
- self.temperature = temperature
85
- self.presence_penalty = presence_penalty
86
- self.frequency_penalty = frequency_penalty
87
- self.max_attempts = max_attempts
88
- self.delay = delay
89
-
90
- def __call__(self, systems, prompts, return_json=False):
91
- conversation_history = [
92
- {"role": "system", "content": systems},
93
- {"role": "user", "content": prompts}
94
- ]
95
- for _ in range(self.max_attempts):
96
- try:
97
- response = openai.ChatCompletion.create(
98
- model=self.model,
99
- messages=conversation_history,
100
- n=1,
101
- temperature=self.temperature,
102
- presence_penalty=self.presence_penalty,
103
- frequency_penalty=self.frequency_penalty,
104
- stream=False
105
- )
106
- assistant_message = response['choices'][0]["message"]["content"]
107
- usage = response['usage']
108
- log.info(assistant_message)
109
- if return_json:
110
- assistant_message = json.loads(assistant_message)
111
- return assistant_message, usage
112
- except Exception as e:
113
- print(f"Failed to get response. Error: {e}")
114
- time.sleep(self.delay)
115
- raise RuntimeError("Failed to get response from OpenAI.")
116
-
117
-
118
-
119
  if __name__ == "__main__":
120
- bot = GPTModel()
121
- r = bot("You are an assistant.", "Hello.")
122
- print(r)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import openai
2
+ import re
3
+ import os
4
  import json
5
+ import logging
6
  log = logging.getLogger(__name__)
7
 
8
+ # todo: 将api_key通过函数传入; 需要改很多地方
9
+ # openai.api_key = os.environ['OPENAI_API_KEY']
10
+
11
+ def extract_responses(assistant_message):
12
+ # pattern = re.compile(r"f\.write\(r'{1,3}(.*?)'{0,3}\){0,1}$", re.DOTALL)
13
+ pattern = re.compile(r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL)
14
+ match = re.search(pattern, assistant_message)
15
+ if match:
16
+ return match.group(1)
17
+ else:
18
+ log.info("Responses are not put in Python codes. Directly return assistant_message.\n")
19
+ log.info(f"assistant_message: {assistant_message}")
20
+ return assistant_message
21
+
22
+ def extract_keywords(assistant_message, default_keywords=None):
23
+ if default_keywords is None:
24
+ default_keywords = {"machine learning":5}
25
+
26
+ try:
27
+ keywords = json.loads(assistant_message)
28
+ except ValueError:
29
+ log.info("Responses are not in json format. Return the default dictionary.\n ")
30
+ log.info(f"assistant_message: {assistant_message}")
31
+ return default_keywords
32
+ return keywords
33
+
34
+ def extract_section_name(assistant_message, default_section_name=""):
35
+ try:
36
+ keywords = json.loads(assistant_message)
37
+ except ValueError:
38
+ log.info("Responses are not in json format. Return None.\n ")
39
+ log.info(f"assistant_message: {assistant_message}")
40
+ return default_section_name
41
+ return keywords
42
+
43
+
44
+ def extract_json(assistant_message, default_output=None):
45
+ if default_output is None:
46
+ default_keys = ["Method 1", "Method 2"]
47
+ else:
48
+ default_keys = default_output
49
+ try:
50
+ dict = json.loads(assistant_message)
51
+ except:
52
+ log.info("Responses are not in json format. Return the default keys.\n ")
53
+ log.info(f"assistant_message: {assistant_message}")
54
+ return default_keys
55
+ return dict.keys()
56
+
57
+
58
+ def get_responses(user_message, model="gpt-4", temperature=0.4, openai_key = None):
59
+ if openai.api_key is None and openai_key is None:
60
+ raise ValueError("OpenAI API key must be provided.")
61
+ if openai_key is not None:
62
+ openai.api_key = openai_key
63
+
64
  conversation_history = [
65
+ {"role": "system", "content": "You are an assistant in writing machine learning papers."}
 
66
  ]
67
+ conversation_history.append({"role": "user", "content": user_message})
68
  response = openai.ChatCompletion.create(
69
  model=model,
70
  messages=conversation_history,
 
77
  return assistant_message, usage
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  if __name__ == "__main__":
81
+ test_strings = [r"f.write(r'hello world')", r"f.write(r'''hello world''')", r"f.write(r'''hello world",
82
+ r"f.write(r'''hello world'", r'f.write(r"hello world")', r'f.write(r"""hello world""")',
83
+ r'f.write(r"""hello world"', r'f.write(r"""hello world']
84
+ for input_string in test_strings:
85
+ print("input_string: ", input_string)
86
+ pattern = re.compile(r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL)
87
+
88
+ match = re.search(pattern, input_string)
89
+ if match:
90
+ extracted_string = match.group(1)
91
+ print("Extracted string:", extracted_string)
92
+ else:
93
+ print("No match found")