File size: 3,839 Bytes
003d053
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
try:
    import openai
except ImportError:
    print("The 'openai' module is not installed. Please install it using 'pip install openai'.")
    exit(1)
import json
import re
import time
from tqdm import tqdm
from config import LLM_RETRIES, LLM_REQUEST_INTERVAL, LLM_RETRY_DELAY, LLM_MAX_TEXT_LENGTH, LLM_PROMPT


def send_request(client, prompt, text, model):
    """Send a single chat-completion request and return the reply text.

    ``text`` is first passed through ``remove_json_escape_characters`` and
    then appended to ``prompt`` (separated by a blank line) to form one
    ``user`` message.

    :param client: object exposing ``chat.completions.create`` (an OpenAI client)
    :param prompt: instruction text placed before the user text
    :param text: user text to send
    :param model: model name forwarded to the API
    :return: content of the first choice's message, or ``None`` when an
        ``openai.OpenAIError`` is raised (the error is printed)
    """
    sanitized = remove_json_escape_characters(text)
    user_message = {"role": "user", "content": f"{prompt}\n\n{sanitized}"}
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[user_message],
            max_tokens=4096,
        )
        # The raw response object is echoed to stdout before the text is extracted.
        print(completion)
        first_choice = completion.choices[0]
        return first_choice.message.content
    except openai.OpenAIError as e:
        print(f"OpenAI API error: {e}")
        return None


def clean_text(text):
    """Strip ASCII control characters from ``text`` when it is a string.

    :param text: value to clean; non-string values are returned unchanged
    :return: ``text`` without characters in the ranges 0-31 and 127
    """
    if not isinstance(text, str):
        return text
    # Remove ASCII control characters (code points 0-31 and 127).
    return re.sub(r'[\x00-\x1F\x7F]', '', text)


def extract_json(response_text):
    """Pull the JSON-looking portion out of an LLM reply.

    The full reply is first written to ``debug.txt`` in the working
    directory (overwriting any previous content).  The reply is then
    searched for an object that starts with ``{``, has at least three
    non-brace characters before a ``:``, and runs to the last ``}``.
    Surrounding ``[`` / ``]`` are included only when at least three
    characters separate ``[`` from the first ``{`` and at least one
    character separates the last ``}`` from ``]`` (as in pretty-printed
    output); a compact ``[{...}]`` yields just the object.

    :param response_text: raw text returned by the model
    :return: the matched substring, or ``None`` when nothing matches
    """
    with open("debug.txt", "w", encoding="utf8") as debug_file:
        debug_file.write(response_text)
    json_pattern = re.compile(r'((\[[^\}]{3,})?\{s*[^\}\{]{3,}?:.*\}([^\{]+\])?)', re.M | re.S)
    found = json_pattern.search(response_text)
    return found.group(0) if found else None


def clean_and_load_json(json_string):
    """Parse JSON text taken from an LLM reply, tolerating single-quoted output.

    Control characters are stripped with ``clean_text`` and the result is
    parsed as-is.  Only if that fails is every single quote replaced by a
    double quote and parsing retried.  Trying the unmodified text first
    keeps valid JSON whose string values contain apostrophes (``"it's"``)
    from being corrupted by the quote replacement.

    The text of each parse attempt is written to ``debug.json`` in the
    working directory, overwriting the previous content.

    :param json_string: JSON-like text to decode
    :return: the decoded object, or ``None`` when no attempt parses (the
        last decode error is printed)
    """
    stripped = clean_text(json_string)
    # Strict JSON first; the quote-swapped variant is only a fallback for
    # replies that used single quotes as string delimiters.
    candidates = (stripped, stripped.replace("'", '"'))
    decode_error = None
    for candidate in candidates:
        # Debug aid: keep the text that is about to be parsed on disk.
        with open("debug.json", "w", encoding="utf8") as f:
            f.write(candidate)
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            decode_error = e
    print(f"JSON decode error: {decode_error}")
    return None


def validate_json(json_obj, required_keys):
    """Return ``True`` when ``json_obj`` is a list.

    ``required_keys`` is kept in the signature for caller compatibility
    but is not consulted: the key check that used to follow the first
    ``return`` was unreachable and has been removed, so the effective
    behaviour (accept any list, reject everything else) is unchanged.

    :param json_obj: decoded JSON value (or ``None``)
    :param required_keys: unused
    :return: ``True`` if ``json_obj`` is a ``list``, otherwise ``False``
    """
    return isinstance(json_obj, list)


def _attempt_part(client, prompt, part, model, required_keys):
    """Run one request / extract / parse / validate cycle for a single chunk.

    :return: ``(items, None)`` on success, or ``(None, reason)`` where
        ``reason`` is a short description of the step that failed
    """
    response = send_request(client, prompt, part, model)
    if not response:
        return None, "API request failed."
    json_string = extract_json(response)
    if not json_string:
        return None, "No JSON found in response."
    json_obj = clean_and_load_json(json_string)
    if not validate_json(json_obj, required_keys):
        return None, "Invalid JSON structure."
    return json_obj, None


def process_text(client, prompt, text, model, required_keys):
    """Split ``text`` into chunks, query the model per chunk and merge the results.

    ``text`` is cut into slices of ``LLM_MAX_TEXT_LENGTH``.  Each slice is
    tried up to ``LLM_RETRIES + 1`` times; the items of the first attempt
    that validates are appended to the result.  ``LLM_RETRY_DELAY`` is
    slept only between attempts (not after the final one), and
    ``LLM_REQUEST_INTERVAL`` is slept after every chunk.  A chunk whose
    attempts are all exhausted is reported and contributes nothing.

    :param client: OpenAI client passed through to ``send_request``
    :param prompt: instruction text sent with every chunk
    :param text: full input text
    :param model: model name
    :param required_keys: forwarded to ``validate_json``
    :return: list with the items collected from all chunks (empty for empty text)
    """
    parts = [text[i:i + LLM_MAX_TEXT_LENGTH] for i in range(0, len(text), LLM_MAX_TEXT_LENGTH)]
    results = []
    total_attempts = LLM_RETRIES + 1

    for part in tqdm(parts, desc="Processing text"):
        for attempt in range(total_attempts):
            items, failure = _attempt_part(client, prompt, part, model, required_keys)
            if failure is None:
                results.extend(items)
                break
            if attempt < LLM_RETRIES:
                # Another attempt remains: announce it and back off first.
                print(f"{failure} Retrying ({attempt + 1}/{LLM_RETRIES})...")
                time.sleep(LLM_RETRY_DELAY)
            else:
                # Out of attempts: say so instead of claiming a retry, and
                # skip the pointless back-off sleep.
                print(f"{failure} Giving up on this part after {total_attempts} attempts.")
        time.sleep(LLM_REQUEST_INTERVAL)

    return results


def llm_operation(api_base, api_key, model, prompt, text, required_keys):
    """Create an OpenAI client for the given endpoint and process ``text`` with it.

    :param api_base: base URL of the API endpoint
    :param api_key: API key used by the client
    :param model: model name
    :param prompt: instruction text sent with every chunk
    :param text: full input text
    :param required_keys: forwarded to ``process_text``
    :return: whatever ``process_text`` returns
    """
    llm_client = openai.OpenAI(base_url=api_base, api_key=api_key)
    return process_text(
        client=llm_client,
        prompt=prompt,
        text=text,
        model=model,
        required_keys=required_keys,
    )


def remove_json_escape_characters(s):
    """
    Strip characters from user-submitted text that tend to make the LLM's
    output fail JSON validation.

    :param s: text to sanitize
    :return: ``s`` with every double quote, backslash, forward slash,
        backspace, form feed, newline, carriage return and tab deleted
    """
    # Characters to delete; each one is removed outright, not replaced.
    unwanted = '"\\/\b\f\n\r\t'
    matcher = re.compile('[' + re.escape(unwanted) + ']')
    return matcher.sub('', s)