File size: 1,809 Bytes
1086ffd
 
c6ad667
1086ffd
 
c6ad667
1086ffd
 
 
 
fe413ad
 
 
1086ffd
 
 
 
 
 
 
 
 
 
c6ad667
 
 
 
 
 
 
 
1086ffd
 
c6ad667
9078d0a
 
c6ad667
1086ffd
fe413ad
 
 
 
 
 
 
1086ffd
 
fe413ad
1086ffd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import openai
import sys
import re
sys.path.append('.')
from local_config import openai_key
from utils.format.txt_2_list import txt_2_list

# Configure the openai module with the key imported from local_config above,
# so later API calls in this file are authenticated.
openai.api_key = openai_key

def _match_types(content, type_arr):
    """Return the labels from *type_arr* that occur in *content*.

    Labels are tested longest-first and every matched label is removed from
    the text before the next test, so a label that is a substring of another
    label (e.g. '差评' inside '不是差评') cannot claim the longer label's text.
    Empty labels are ignored because the empty string is contained in every
    string. The result keeps the order of *type_arr*.
    """
    labels = list(type_arr)
    # sorted() is stable, so among equally long (or duplicate) labels the one
    # that appears first in type_arr is still tried first.
    longest_first = sorted(
        range(len(labels)), key=lambda i: len(labels[i]), reverse=True
    )
    hit_positions = set()
    for pos in longest_first:
        label = labels[pos]
        if label and label in content:
            hit_positions.add(pos)
            # Drop the matched text so shorter or duplicate labels cannot
            # match the same characters again.
            content = content.replace(label, '')
    return [labels[pos] for pos in sorted(hit_positions)]


def text_classification(src_txt, type_arr, history=None):
    """Classify *src_txt* into labels from *type_arr* using a chat model.

    Builds a single user prompt containing the label list, the optional
    few-shot examples and the text to classify, sends it to the
    ``gpt-3.5-turbo`` chat completion endpoint, and scans the reply for
    occurrences of the candidate labels.

    Args:
        src_txt: The text to classify.
        type_arr: Candidate labels. Interpolated into the prompt as-is and
            used to scan the model's reply.
        history: Optional iterable of ``(input, output)`` pairs rendered as
            example lines ahead of the new input. Defaults to no examples.

    Returns:
        A list of the labels from *type_arr* found in the model's reply, in
        *type_arr* order. Empty when no label appears in the reply.

    Errors raised by the OpenAI client are not caught here.
    """
    # Avoid a shared mutable default; None means "no examples".
    history = [] if history is None else history
    history_txt = ''.join([f'输入|```{q}```输出|{a}\n' for q, a in history])
    user = f"你是一个聪明而且有百年经验的文本分类器. 你的任务是从一段文本里面提取出相应的分类结果签。你的回答必须用统一的格式。文本用```符号分割。分类类型保存在一个数组里{type_arr}\n{history_txt}输入|```{src_txt}```输出|"
    # Call the OpenAI API with the whole prompt as one user message.
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": user},
        ],
    )

    # The reply is free text; keep only the known labels that appear in it.
    content = completion.choices[0].message.content
    return _match_types(content, type_arr)

if __name__ == '__main__':
    # Manual smoke test: classify a few sample sentences against a fixed
    # label set and print each sentence next to the labels found for it.
    # Other label sets that were tried: ['好评', '差评'] and "是差评、不是差评".
    type_arr = txt_2_list("天气查询、股票查询、其他")

    # One worked example passed to the classifier as few-shot history.
    history = [['这个商品真不错', ['其他']]]

    # A weather question ('今天天气怎么样') is another possible sample input.
    samples = (
        '这个商品真不错',
        '用着不行',
        '没用过这么好的东西',
    )
    for txt in samples:
        print(txt, text_classification(txt, type_arr, history))