NeonLLM / shared.py
neondaniel's picture
Add OAuth Support (#1)
d098bff verified
raw
history blame contribute delete
No virus
2.51 kB
from dataclasses import dataclass
from enum import IntEnum
import yaml
from typing import Dict, Optional, List
from pydantic import BaseModel, ValidationError
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from openai import OpenAI
class OAuthProvider(IntEnum):
    """Integer-valued identifier of the OAuth provider behind a ``User``."""

    # No OAuth provider.
    NONE = 0
    # Google as the OAuth provider.
    GOOGLE = 1
@dataclass
class User:
    """Plain record describing a signed-in user."""

    # Which OAuth provider the user came through (NONE if not OAuth).
    oauth: OAuthProvider
    # The user's name as reported at sign-in.
    username: str
    # NOTE(review): presumably a key into InferenceConfig.permissions — verify.
    permissions_id: str
class PileConfig(BaseModel):
    """The ``pile`` section of ``ModelConfig``: persona and prompt data."""

    # Presumably data-file name -> persona name (inferred from the field name).
    file2persona: Dict[str, str]
    # Presumably data-file name -> prefix text (inferred from the field name).
    file2prefix: Dict[str, str]
    # Persona name -> system prompt; Client.get_personas uses this as its
    # persona table.
    persona2system: Dict[str, str]
    # Prompt string for the pile.
    prompt: str
class InferenceConfig(BaseModel):
    """The ``inference`` section of ``ModelConfig``."""

    # Chat template string for the model.
    chat_template: str
    # Optional permission table; empty when omitted. Pydantic copies
    # non-hashable defaults per instance, so this ``{}`` is not shared.
    # NOTE(review): presumably keyed by User.permissions_id — verify.
    permissions: Dict[str, list] = {}
class RepoConfig(BaseModel):
    """The ``repo`` section of ``ModelConfig``."""

    # Repository name (presumably the model repo id — verify).
    name: str
    # Repository tag (presumably a revision/tag label — verify).
    tag: str
class ModelConfig(BaseModel):
    """Top-level model configuration with ``pile``, ``inference`` and ``repo`` sections."""

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Load and validate a ``ModelConfig`` from a YAML file.

        :param yaml_file: path of the YAML file to read.
        :returns: the validated ``ModelConfig``.
        :raises OSError: if the file cannot be opened.
        :raises yaml.YAMLError: if the file is not valid YAML.
        :raises pydantic.ValidationError: if the data does not match the
            schema, including when the file is empty.
        """
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(yaml_file, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        # An empty document loads as None. Validating an empty mapping reports
        # the missing sections as a ValidationError instead of the TypeError
        # that ``cls(**None)`` would raise.
        if data is None:
            data = {}
        return cls(**data)
class Client:
    """Client for an OpenAI-compatible inference server plus its persona table.

    Construction connects to ``{api_url}/v1``, reads the served model id
    (``name`` or ``name@revision``) and builds ``self.personas``.
    """

    def __init__(self, api_url, api_key, personas=None):
        """
        :param api_url: base URL of the server; ``/v1`` is appended.
        :param api_key: API key handed to the OpenAI client.
        :param personas: optional mapping of persona name -> system prompt.
            On a key clash, personas loaded from the model repo take precedence.
        """
        self.api_url = api_url
        self.api_key = api_key
        # ``None`` default instead of ``{}`` so instances never share one dict.
        self.input_personas = {} if personas is None else personas
        self.init_all()

    def init_all(self):
        """(Re)create the API client, then refresh model metadata and personas."""
        self.init_client()
        self.get_metadata()
        self.get_personas()

    def init_client(self):
        """Create the OpenAI SDK client pointed at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the first served model id and split it into name and revision.

        Sets ``vllm_model_name`` to the full id and ``model_name`` to the part
        before the first ``@``. ``revision`` is set to the text between the
        first and second ``@``, or ``None`` when the id has no ``@``.
        """
        models = self.openai.models.list()
        # NOTE(review): indexing fails (IndexError) if the server lists no models.
        vllm_model_name = models.data[0].id
        model_name, *suffix = vllm_model_name.split("@")
        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        self.revision = suffix[0] if suffix else None

    def get_personas(self):
        """Build ``self.personas`` from the input personas and the repo config.

        When ``self.revision`` is set, downloads ``datasets/config.yaml`` from
        the ``self.model_name`` repo at that revision. On success it stores the
        parsed config in ``self.config`` and uses its ``pile.persona2system``
        mapping. A missing config file is tolerated. ``self.config`` is
        ``None`` whenever no config was loaded. A ``"vanilla"`` persona mapped
        to ``None`` is always present.
        """
        personas = {}
        self.config = None
        if self.revision is not None:
            try:
                config_path = hf_hub_download(self.model_name, "config.yaml",
                                              subfolder="datasets",
                                              revision=self.revision)
            except EntryNotFoundError:
                # This revision ships no config; fall back to input personas.
                pass
            else:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so that adding "vanilla" below does not mutate the
                # config model's own mapping.
                personas = dict(self.config.pile.persona2system)
        personas["vanilla"] = None
        self.personas = self.input_personas | personas