File size: 2,497 Bytes
d098bff
 
 
c5b0047
 
d098bff
c5b0047
4fe68bc
 
 
 
c5b0047
 
d098bff
 
 
 
 
 
 
 
 
 
 
c5b0047
 
 
 
 
 
 
d098bff
c5b0047
 
d098bff
 
c5b0047
 
 
 
d098bff
c5b0047
 
 
 
 
 
 
 
 
 
 
 
4fe68bc
 
 
 
91fbea0
4fe68bc
 
2e517b5
4fe68bc
 
 
 
 
 
 
 
 
5b981d0
4fe68bc
 
 
 
 
5b981d0
4fe68bc
 
 
 
 
 
 
 
 
 
4bb483c
 
 
 
 
 
 
 
 
 
4fe68bc
0af8d1d
2e517b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from dataclasses import dataclass
from enum import IntEnum

import yaml

from typing import Dict, Optional, List
from pydantic import BaseModel, ValidationError
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from openai import OpenAI


class OAuthProvider(IntEnum):
    """Identity provider a user signed in through.

    Being an ``IntEnum``, members compare equal to their integer values.
    """

    NONE = 0    # no OAuth provider
    GOOGLE = 1  # Google sign-in


@dataclass
class User:
    """A signed-in user.

    Field order defines the positional order of the generated ``__init__``:
    ``User(oauth, username, permissions_id)``.
    """

    # Provider the user authenticated through.
    oauth: OAuthProvider
    # The user's login/display name.
    username: str
    # NOTE(review): presumably a key into InferenceConfig.permissions — confirm.
    permissions_id: str


class PileConfig(BaseModel):
    """Schema of the ``pile`` section of the model's ``config.yaml``."""

    # NOTE(review): by name, maps a data file to its persona — confirm.
    file2persona: dict[str, str]
    # NOTE(review): by name, maps a data file to a text prefix — confirm.
    file2prefix: dict[str, str]
    # Persona name -> system prompt; Client.get_personas exposes this mapping.
    persona2system: dict[str, str]
    prompt: str


class InferenceConfig(BaseModel):
    """Schema of the ``inference`` section of the model's ``config.yaml``."""

    chat_template: str
    # Optional; pydantic copies the ``{}`` default per instance, so it is not shared.
    # NOTE(review): keys are presumably User.permissions_id values — confirm.
    permissions: dict[str, list] = {}


class RepoConfig(BaseModel):
    """Schema of the ``repo`` section of the model's ``config.yaml``."""

    # Repository name as recorded in the config.
    name: str


class ModelConfig(BaseModel):
    """Top-level model configuration; each field is a top-level YAML key."""

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Load and validate a configuration from a YAML file.

        Args:
            yaml_file: Path to the YAML file (relative paths resolve against
                the current working directory).

        Returns:
            A validated instance of ``cls``.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            TypeError: If the top-level YAML value is not a mapping.
            ValidationError: If the data does not match the schema (an empty
                file is treated as an empty mapping and fails here).
        """
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(yaml_file, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        # An empty document parses to None; let pydantic report missing fields.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"{yaml_file}: expected a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
        

class Client:
    """Wrapper around an OpenAI-compatible endpoint plus persona resolution.

    Construction connects to ``{api_url}/v1``, reads the first served model's
    id, and builds the persona table.

    Attributes set during initialisation:
        openai: The ``OpenAI`` client.
        vllm_model_name: Full id of the first served model.
        model_name: Part of that id before the first ``@``.
        revision: Part after the first ``@``, or ``None`` if there is no ``@``.
        config: ``ModelConfig`` downloaded for ``revision``, or ``None``.
        personas: Persona name -> system prompt (``"vanilla"`` maps to ``None``).
    """

    def __init__(self, api_url, api_key, personas=None):
        """Create the client and eagerly resolve model metadata and personas.

        Args:
            api_url: Base URL of the server (``/v1`` is appended).
            api_key: API key passed to the OpenAI client.
            personas: Optional caller-supplied persona name -> system prompt
                mapping. Personas from the repo config take precedence over
                these on name clashes. The mapping is not mutated.
        """
        self.api_url = api_url
        self.api_key = api_key
        # None sentinel instead of a shared mutable ``{}`` default.
        self.input_personas = {} if personas is None else personas

        self.init_all()

    def init_all(self):
        """(Re)build the client, model metadata and persona table, in that order."""
        self.init_client()
        self.get_metadata()
        self.get_personas()

    def init_client(self):
        """Create the OpenAI client pointed at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the first served model id and split it into name and revision.

        The id is split on ``@``: the first part is the model name and the
        second (if any) the revision; further ``@`` parts are ignored.

        Raises:
            IndexError: If the server reports no models.
        """
        models = self.openai.models.list()
        # Only the first served model is considered.
        vllm_model_name = models.data[0].id

        model_name, *suffix = vllm_model_name.split("@")

        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        self.revision = suffix[0] if suffix else None

    def get_personas(self):
        """Build ``self.personas`` from caller input and the repo config.

        When a revision is known, ``datasets/config.yaml`` is downloaded from
        the ``model_name`` repo at that revision. Its ``pile.persona2system``
        entries override caller-supplied personas of the same name. A missing
        config file is tolerated. ``"vanilla"`` is always present and maps
        to ``None``.
        """
        repo_personas = {}
        self.config = None
        if self.revision is not None:
            try:
                config_path = hf_hub_download(
                    self.model_name,
                    "config.yaml",
                    subfolder="datasets",
                    revision=self.revision,
                )
            except EntryNotFoundError:
                # Best effort: a revision without a config simply adds no personas.
                config_path = None
            if config_path is not None:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so adding "vanilla" below does not mutate the config.
                repo_personas = dict(self.config.pile.persona2system)

        repo_personas["vanilla"] = None
        self.personas = self.input_personas | repo_personas