File size: 2,518 Bytes
54b3256
 
 
 
 
5117e0a
54b3256
 
 
 
 
 
 
 
 
 
 
 
5117e0a
54b3256
 
 
5117e0a
cc9a95f
 
54b3256
5117e0a
54b3256
 
cc9a95f
 
 
 
 
 
 
 
54b3256
 
 
 
 
 
 
 
 
 
 
 
 
5117e0a
54b3256
 
 
5117e0a
 
54b3256
 
 
 
 
 
 
 
 
 
 
 
1215037
54b3256
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import os
import requests
from typing import List

from src.common import config_dir, hf_api_token


class HFLlamaChatModel:
    """Registry and thin client for Llama chat models served through the
    Hugging Face Inference API.

    Model definitions are lazily loaded from ``<config_dir>/models.json`` the
    first time any lookup classmethod is called, and cached on the class.
    """

    # Class-level cache of HFLlamaChatModel instances; None until load_configs() runs.
    models = None

    @classmethod
    def load_configs(cls):
        """Populate ``cls.models`` from the 'models' list in models.json.

        Each config entry must provide 'name', 'id' and 'description'.
        Entries whose name is already registered are skipped (de-duplicates
        repeated names within the config file).
        """
        config_file = os.path.join(config_dir, "models.json")
        with open(config_file, "r") as f:
            configs = json.load(f)['models']
        cls.models = []
        for cfg in configs:
            # for_name() sees the partially-built list, so duplicate names
            # appearing later in the file are dropped.
            if cls.for_name(cfg['name']) is None:
                cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))

    @classmethod
    def for_name(cls, name: str):
        """Return the registered model with the given display name, or None."""
        if cls.models is None:
            cls.load_configs()
        for m in cls.models:
            if m.name == name:
                return m

    @classmethod
    def for_model(cls, model: str):
        """Return the registered model with the given HF model id, or None."""
        if cls.models is None:
            cls.load_configs()
        for m in cls.models:
            if m.id == model:
                return m

    @classmethod
    def available_models(cls) -> List[str]:
        """Return the display names of all configured models."""
        if cls.models is None:
            cls.load_configs()
        return [m.name for m in cls.models]

    def __init__(self, name: str, id: str, description: str):
        """Store the display name, Hugging Face model id and description."""
        self.name = name
        self.id = id
        self.description = description

    def __call__(self,
                 query: str,
                 auth_token: str = None,
                 system_prompt: str = None,
                 max_new_tokens: int = 256,
                 temperature: float = 1.0,
                 timeout: float = 60.0):
        """Send a single-turn chat query to the HF Inference API.

        :param query: the user message.
        :param auth_token: HF API token; looked up via hf_api_token() if None.
        :param system_prompt: system instruction; a generic default if None.
        :param max_new_tokens: generation length cap forwarded to the API.
        :param temperature: sampling temperature forwarded to the API.
        :param timeout: seconds before the HTTP request is abandoned.
        :return: the generated text, stripped of surrounding whitespace.
        :raises ValueError: on any non-200 response from the API.
        """
        if auth_token is None:
            auth_token = hf_api_token()  # Attempt look up if not provided
        headers = {"Authorization": f"Bearer {auth_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.id}"
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."
        # Llama 2 chat template: the system block must be CLOSED with <</SYS>>.
        # The previous version used <<SYS>> twice, so the model never saw a
        # properly delimited system prompt.
        query_input = f"[INST] <<SYS>> {system_prompt} <</SYS>> {query} [/INST] "
        query_payload = {
            "inputs": query_input,
            "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature}
        }
        # timeout prevents an unresponsive endpoint from blocking the caller forever
        response = requests.post(api_url, headers=headers, json=query_payload, timeout=timeout)
        if response.status_code == 200:
            resp_json = response.json()
            llm_text = resp_json[0]['generated_text'].strip()
            return llm_text
        else:
            error_detail = f"Error from hugging face code: {response.status_code}: {response.reason} ({response.content})"
            raise ValueError(error_detail)