File size: 2,356 Bytes
4962437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import requests
import os

class Anthropic:
    """Anthropic large language models."""

    def __init__(
            self, 
            model="claude-2", 
            max_tokens_to_sample=256, 
            temperature=None, 
            top_k=None, 
            top_p=None, 
            streaming=False, 
            default_request_timeout=None
        ):
        self.model = model
        self.max_tokens_to_sample = max_tokens_to_sample
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.streaming = streaming
        self.default_request_timeout = default_request_timeout or 600
        self.anthropic_api_url = os.getenv("ANTHROPIC_API_URL", "https://api.anthropic.com")
        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")

    def _default_params(self):
        """Get the default parameters for calling Anthropic API."""
        d = {
            "max_tokens_to_sample": self.max_tokens_to_sample,
            "model": self.model,
        }
        if self.temperature is not None:
            d["temperature"] = self.temperature
        if self.top_k is not None:
            d["top_k"] = self.top_k
        if self.top_p is not None:
            d["top_p"] = self.top_p
        return d

    def generate(self, prompt, stop=None):
        """Call out to Anthropic's completion endpoint."""
        stop = stop or []
        params = self._default_params()
        headers = {"Authorization": f"Bearer {self.anthropic_api_key}"}
        data = {
            "prompt": prompt,
            "stop_sequences": stop,
            **params
        }
        response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout)
        return response.json().get("completion")
    
    def __call__(self, prompt, stop=None):
        """Call out to Anthropic's completion endpoint."""
        stop = stop or []
        params = self._default_params()
        headers = {"Authorization": f"Bearer {self.anthropic_api_key}"}
        data = {
            "prompt": prompt,
            "stop_sequences": stop,
            **params
        }
        response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout)
        return response.json().get("completion")