File size: 2,837 Bytes
fee0ada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# SparkGPT.py
from . import SparkApi
# The API credentials below are read from environment variables.
import os

# Spark API credentials, read once at import time. A missing variable
# raises KeyError (checked in the order APPID, APISecret, APIKey).
appid, api_secret, api_key = (
    os.environ[env_name] for env_name in ("APPID", "APISecret", "APIKey")
)

from .BaseLLM import BaseLLM

    


class SparkGPT(BaseLLM):
    """Chat wrapper for iFlytek Spark models, built on ``BaseLLM``.

    The conversation is kept in ``self.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts whose roles alternate
    ``"user"`` / ``"assistant"``, starting with ``"user"``. A message from
    the same side as the previous turn is merged into that turn (joined
    with a newline) instead of being appended, so alternation is preserved.

    Attributes:
        model_type: Model name given to the constructor (a key of
            ``_MODEL_CONFIG``).
        domain: Spark ``domain`` value for ``model_type``.
        Spark_url: Websocket endpoint for ``model_type``.
        messages: Conversation history sent by ``get_response``.
    """

    # Model name -> (domain, websocket endpoint) for each supported version.
    _MODEL_CONFIG = {
        "Spark1.5": ("general", "ws://spark-api.xf-yun.com/v1.1/chat"),
        "Spark2.0": ("generalv2", "ws://spark-api.xf-yun.com/v2.1/chat"),
        "Spark3.0": ("generalv3", "ws://spark-api.xf-yun.com/v3.1/chat"),
    }

    # "Please role-play according to my requirements:" -- the user turn used
    # to open the conversation when an assistant message arrives first.
    _ROLEPLAY_PROMPT = "请根据我的要求进行角色扮演:"

    def __init__(self, model="Spark3.0"):
        """Create a client for the Spark model named ``model``.

        Raises:
            ValueError: if ``model`` is not a key of ``_MODEL_CONFIG``.
        """
        super().__init__()
        self.model_type = model
        self.messages = []
        self._apply_model_config()

    def _apply_model_config(self):
        """Set ``domain`` and ``Spark_url`` from the current ``model_type``."""
        try:
            self.domain, self.Spark_url = self._MODEL_CONFIG[self.model_type]
        except (KeyError, TypeError):
            # ValueError subclasses Exception, so callers that caught the
            # previously raised bare Exception keep working.
            raise ValueError(
                f"Unknown Spark model: {self.model_type!r}"
            ) from None

    def initialize_message(self):
        """Clear the conversation history."""
        self.messages = []

    def ai_message(self, payload):
        """Record ``payload`` as assistant output.

        If the history is empty, a role-play user turn is inserted first so
        the conversation still opens with a user message; the payload is
        then recorded as the assistant reply (previously it was dropped).
        Consecutive assistant messages are merged into one turn.
        """
        if not self.messages:
            self.user_message(self._ROLEPLAY_PROMPT)
        if self.messages[-1]["role"] == "user":
            self.messages.append({"role": "assistant", "content": payload})
        else:
            self.messages[-1]["content"] += "\n" + payload

    def system_message(self, payload):
        """Record a system prompt.

        It is stored as a ``"user"`` turn, as before -- presumably because
        the targeted Spark API versions have no system role (TODO confirm).
        It goes through ``user_message`` so that a trailing user turn is
        extended rather than followed by a second consecutive user turn,
        which would break the user/assistant alternation.
        """
        self.user_message(payload)

    def user_message(self, payload):
        """Record ``payload`` as user input, merging consecutive user turns."""
        if self.messages and self.messages[-1]["role"] == "user":
            self.messages[-1]["content"] += "\n" + payload
        else:
            self.messages.append({"role": "user", "content": payload})

    def get_response(self):
        """Send ``self.messages`` to Spark and return the reply text.

        Raises:
            ValueError: if ``model_type`` is not a supported model.
        """
        # Re-resolve so a model_type changed after construction takes effect.
        self._apply_model_config()
        # The reply is read back from the module-level SparkApi.answer, so it
        # is cleared before the request; presumably SparkApi.main fills it in.
        # Because this is shared module state, concurrent calls would likely
        # interfere with each other.
        SparkApi.answer = ""
        SparkApi.main(
            appid,
            api_key,
            api_secret,
            self.Spark_url,
            self.domain,
            self.messages,
        )
        return SparkApi.answer

    def print_prompt(self):
        """Print each message in the history as ``role: content``."""
        for message in self.messages:
            print(f"{message['role']}: {message['content']}")