File size: 4,340 Bytes
1ea2ba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from ..presets import *
from ..utils import *

from .base_model import BaseLLMModel


class Claude_Client(BaseLLMModel):
    def __init__(self, model_name, api_secret) -> None:
        super().__init__(model_name=model_name)
        self.api_secret = api_secret
        if None in [self.api_secret]:
            raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
        self.claude_client = Anthropic(api_key=self.api_secret)

    def _get_claude_style_history(self):
        history = []
        image_buffer = []
        image_count = 0
        for message in self.history:
            if message["role"] == "user":
                content = []
                if image_buffer:
                    if image_count == 1:
                        content.append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": f"image/{self.get_image_type(image_buffer[0])}",
                                    "data": self.get_base64_image(image_buffer[0]),
                                },
                            },
                        )
                    else:
                        image_buffer_length = len(image_buffer)
                        for idx, image in enumerate(image_buffer):
                            content.append(
                                {"type": "text", "text": f"Image {image_count - image_buffer_length + idx + 1}:"},
                            )
                            content.append(
                                {
                                    "type": "image",
                                    "source": {
                                        "type": "base64",
                                        "media_type": f"image/{self.get_image_type(image)}",
                                        "data": self.get_base64_image(image),
                                    },
                                },
                            )
                if content:
                    content.append({"type": "text", "text": message["content"]})
                    history.append(construct_user(content))
                    image_buffer = []
                else:
                    history.append(message)
            elif message["role"] == "assistant":
                history.append(message)
            elif message["role"] == "image":
                image_buffer.append(message["content"])
                image_count += 1
        # history with base64 data replaced with "#base64#"
        # history_for_display = history.copy()
        # for message in history_for_display:
        #     if message["role"] == "user":
        #         if type(message["content"]) == list:
        #             for content in message["content"]:
        #                 if content["type"] == "image":
        #                     content["source"]["data"] = "#base64#"
        # logging.info(f"History for Claude: {history_for_display}")
        return history

    def get_answer_stream_iter(self):
        system_prompt = self.system_prompt
        history = self._get_claude_style_history()

        try:
            with self.claude_client.messages.stream(
                model=self.model_name,
                max_tokens=self.max_generation_token,
                messages=history,
                system=system_prompt,
            ) as stream:
                partial_text = ""
                for text in stream.text_stream:
                    partial_text += text
                    yield partial_text
        except Exception as e:
            yield i18n(GENERAL_ERROR_MSG) + ": " + str(e)

    def get_answer_at_once(self):
        system_prompt = self.system_prompt
        history = self._get_claude_style_history()

        response = self.claude_client.messages.create(
            model=self.model_name,
            max_tokens=self.max_generation_token,
            messages=history,
            system=system_prompt,
        )
        if response is not None:
            return response.content[0].text, response.usage.output_tokens
        else:
            return i18n("获取资源错误"), 0