File size: 4,040 Bytes
fb8c051
 
 
 
a0d1776
fb8c051
 
 
 
 
 
 
 
 
 
 
 
 
 
a0d1776
fb8c051
 
a0d1776
fb8c051
 
 
 
 
 
 
 
 
a0d1776
fb8c051
 
 
 
 
 
 
 
 
 
238735e
 
 
 
 
fb8c051
 
 
238735e
 
 
fb8c051
 
 
a0d1776
1b82d4c
 
 
 
 
fb8c051
 
 
 
 
 
 
 
c9efba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb8c051
 
 
 
 
 
 
 
 
 
a0d1776
 
fb8c051
 
 
 
 
 
 
 
 
a0d1776
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import openai
import re
import json
import logging

log = logging.getLogger(__name__)


def extract_responses(assistant_message):
    """Pull the raw-string payload out of an ``f.write(r'...')`` call.

    The pattern accepts single or double quote characters (one to three of
    them) after ``f.write(r`` and tolerates a missing closing quote and/or
    closing parenthesis, as long as the call runs to the end of the message.

    Returns the captured payload, or ``assistant_message`` unchanged when no
    such call is found.
    """
    write_call = re.compile(r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL)
    found = write_call.search(assistant_message)
    if found is None:
        # Nothing to unwrap: hand the message back as-is.
        log.info("Responses are not put in Python codes. Directly return assistant_message.\n")
        log.info(f"assistant_message: {assistant_message}")
        return assistant_message
    return found.group(1)


def extract_keywords(assistant_message, default_keywords=None):
    """Decode the model's reply as JSON and return the decoded value.

    When the reply is not valid JSON, ``default_keywords`` is returned
    instead; if the caller did not supply one, the fallback is
    ``{"machine learning": 5}``.
    """
    fallback = {"machine learning": 5} if default_keywords is None else default_keywords

    try:
        return json.loads(assistant_message)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so malformed JSON lands here.
        log.info("Responses are not in json format. Return the default dictionary.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return fallback


def extract_section_name(assistant_message, default_section_name=""):
    """Parse the model's reply as JSON and return the decoded value.

    Args:
        assistant_message: Text expected to contain a JSON document.
        default_section_name: Value returned when the text is not valid JSON.

    Returns:
        The decoded JSON value, or ``default_section_name`` when decoding fails.
    """
    try:
        section_name = json.loads(assistant_message)
    except ValueError:
        # The fallback is default_section_name (an empty string unless the
        # caller overrides it), so the log message must not claim None.
        log.info("Responses are not in json format. Return the default section name.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_section_name
    return section_name


def extract_json(assistant_message, default_output=None):
    """Return the top-level keys of a JSON object found in the model's reply.

    Args:
        assistant_message: Text expected to contain a JSON object.
        default_output: Keys to return when the reply cannot be used. When
            omitted, ``["Method 1", "Method 2"]`` is used.

    Returns:
        The ``keys()`` view of the decoded object on success; otherwise the
        default keys.
    """
    default_keys = ["Method 1", "Method 2"] if default_output is None else default_output

    try:
        parsed = json.loads(assistant_message)
    except (ValueError, TypeError):
        # ValueError covers malformed JSON; TypeError covers non-string input
        # such as None. Anything else (e.g. KeyboardInterrupt) must propagate.
        log.info("Responses are not in json format. Return the default keys.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_keys

    if not isinstance(parsed, dict):
        # Valid JSON that is not an object (list, string, number, ...) has no
        # keys; fall back instead of failing on the missing .keys() attribute.
        log.info("Responses are not a json object. Return the default keys.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_keys
    return parsed.keys()


def get_responses(user_message, model="gpt-4", temperature=0.4, openai_key=None):
    """Send one user message to the chat-completion endpoint and return the reply.

    The request carries a fixed system prompt followed by ``user_message``.
    If ``openai_key`` is given it is stored in ``openai.api_key`` before the
    call; otherwise a key must already be set there.

    Returns:
        A ``(assistant_message, usage)`` tuple: the content of the first
        choice and the ``usage`` entry of the API response.

    Raises:
        ValueError: If neither ``openai_key`` nor ``openai.api_key`` is set.
    """
    if openai_key is not None:
        openai.api_key = openai_key
    elif openai.api_key is None:
        raise ValueError("OpenAI API key must be provided.")

    messages = [
        {"role": "system", "content": "You are an assistant in writing machine learning papers."},
        {"role": "user", "content": user_message},
    ]
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        n=1,  # request a single completion
        temperature=temperature,  # sampling temperature passed through from the caller
    )
    first_choice = response['choices'][0]
    assistant_message = first_choice["message"]["content"]
    log.info(assistant_message)
    return assistant_message, response['usage']

def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
    """Run one chat completion with a caller-supplied system and user prompt.

    ``systems`` becomes the system message and ``prompts`` the user message.

    Returns:
        A ``(assistant_message, usage)`` tuple: the content of the first
        choice and the ``usage`` entry of the API response.
    """
    request = {
        "model": model,
        "messages": [
            {"role": "system", "content": systems},
            {"role": "user", "content": prompts},
        ],
        "n": 1,  # request a single completion
        "temperature": temperature,  # sampling temperature passed through from the caller
    }
    response = openai.ChatCompletion.create(**request)
    reply = response['choices'][0]["message"]["content"]
    token_usage = response['usage']
    log.info(reply)
    return reply, token_usage


if __name__ == "__main__":
    # Manual check of the f.write(r...) pattern against well-formed and
    # truncated calls, in both single- and double-quoted flavours.
    samples = [
        r"f.write(r'hello world')",
        r"f.write(r'''hello world''')",
        r"f.write(r'''hello world",
        r"f.write(r'''hello world'",
        r'f.write(r"hello world")',
        r'f.write(r"""hello world""")',
        r'f.write(r"""hello world"',
        r'f.write(r"""hello world',
    ]
    write_call = re.compile(r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL)
    for input_string in samples:
        print("input_string: ", input_string)
        found = write_call.search(input_string)
        if found is None:
            print("No match found")
            continue
        print("Extracted string:", found.group(1))