Spaces:
Runtime error
Runtime error
File size: 2,518 Bytes
54b3256 5117e0a 54b3256 5117e0a 54b3256 5117e0a cc9a95f 54b3256 5117e0a 54b3256 cc9a95f 54b3256 5117e0a 54b3256 5117e0a 54b3256 1215037 54b3256 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import json
import os
from typing import List, Optional

import requests

from src.common import config_dir, hf_api_token
class HFLlamaChatModel:
    """Chat wrapper around a model served by the Hugging Face Inference API.

    Instances are normally obtained from a class-level registry that is lazily
    populated from ``models.json`` in ``config_dir`` (see ``load_configs``),
    and are then called like a function to send a single-turn prompt.
    """

    # Registry of configured models; ``None`` until load_configs() has run.
    models = None

    @classmethod
    def _ensure_loaded(cls):
        """Populate the registry from the config file on first use."""
        if cls.models is None:
            cls.load_configs()

    @classmethod
    def load_configs(cls):
        """(Re)load the model registry from ``<config_dir>/models.json``.

        The file must contain a top-level ``"models"`` list whose entries have
        ``name``, ``id`` and ``description`` keys. When several entries share
        a name, only the first one is kept.

        Raises:
            OSError: if the config file cannot be opened.
            ValueError: (json.JSONDecodeError) if the file is not valid JSON.
            KeyError: if a required key is missing.
        """
        config_file = os.path.join(config_dir, "models.json")
        with open(config_file, "r", encoding="utf-8") as f:
            configs = json.load(f)['models']
        # Build into a local list and publish only at the end. A malformed
        # entry then cannot leave a half-populated registry behind, which
        # would block any retry because lookups only reload when models is None.
        loaded = []
        seen_names = set()
        for cfg in configs:
            if cfg['name'] in seen_names:
                continue
            seen_names.add(cfg['name'])
            loaded.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
        cls.models = loaded

    @classmethod
    def for_name(cls, name: str):
        """Return the registered model whose ``name`` equals *name*, or ``None``."""
        cls._ensure_loaded()
        return next((m for m in cls.models if m.name == name), None)

    @classmethod
    def for_model(cls, model: str):
        """Return the registered model whose ``id`` equals *model*, or ``None``."""
        cls._ensure_loaded()
        return next((m for m in cls.models if m.id == model), None)

    @classmethod
    def available_models(cls) -> List[str]:
        """Return the names of all registered models, in config-file order."""
        cls._ensure_loaded()
        return [m.name for m in cls.models]

    def __init__(self, name: str, id: str, description: str):
        # ``id`` shadows the builtin, but it is kept because callers may pass it by keyword.
        self.name = name
        self.id = id  # model path appended to the Inference API URL
        self.description = description

    @staticmethod
    def _build_prompt(system_prompt: str, query: str) -> str:
        """Wrap *query* and *system_prompt* in the Llama 2 chat template."""
        # The system block must be closed with <</SYS>>; reopening it with a
        # second <<SYS>> leaves the model without a proper system/user split.
        return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{query} [/INST]"

    def __call__(self,
                 query: str,
                 auth_token: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 max_new_tokens: int = 256,
                 temperature: float = 1.0,
                 timeout: Optional[float] = None) -> str:
        """Send *query* to this model and return the generated text.

        Args:
            query: The user message.
            auth_token: Hugging Face API token; looked up via
                ``hf_api_token()`` when omitted.
            system_prompt: System instruction; defaults to
                ``"You are a helpful assistant."``.
            max_new_tokens: Forwarded as the ``max_new_tokens`` request parameter.
            temperature: Forwarded as the ``temperature`` request parameter.
            timeout: Seconds to wait for the HTTP response. ``None`` (the
                default) waits indefinitely, matching ``requests``' default.

        Returns:
            The ``generated_text`` of the first result, whitespace-stripped.
            NOTE(review): the Inference API may echo the prompt inside
            ``generated_text`` unless ``return_full_text`` is false. Confirm
            what callers expect.

        Raises:
            ValueError: on a non-200 response or an unexpected response body.
        """
        if auth_token is None:
            auth_token = hf_api_token()  # Attempt look up if not provided
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."
        headers = {"Authorization": f"Bearer {auth_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.id}"
        query_payload = {
            "inputs": self._build_prompt(system_prompt, query),
            "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
        }
        response = requests.post(api_url, headers=headers, json=query_payload, timeout=timeout)
        if response.status_code != 200:
            error_detail = f"Error from hugging face code: {response.status_code}: {response.reason} ({response.content})"
            raise ValueError(error_detail)
        resp_json = response.json()
        try:
            generated = resp_json[0]['generated_text']
        except (IndexError, KeyError, TypeError) as err:
            # Surface malformed bodies the same way as HTTP errors.
            raise ValueError(f"Unexpected response from hugging face: {resp_json!r}") from err
        return generated.strip()
|