File size: 2,627 Bytes
312035b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9225b05
 
312035b
 
 
9d4ba25
312035b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9225b05
312035b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os


def _require(section, key, label):
    """Return ``section[key]``, raising the standard config error if absent.

    Parameters:
        section: one sub-dict of the parsed config (e.g. ``config["client_llm"]``).
        key: required key inside that sub-dict (``"model_id"``, ``"url"``, ...).
        label: prefix used in the error message ("client", "expert",
            or "classifier"), matching the original hand-written messages.

    Raises:
        ValueError: when *key* is not present in *section*.
    """
    if key not in section:
        raise ValueError(f"config.yaml is missing {label} {key}.")
    return section[key]


def process_config(config):
    """Turn a parsed config.yaml mapping into LLM constructor kwargs.

    Validates that the ``HF_TOKEN`` environment variable is set and that
    the config contains ``client_llm``, ``expert_llm`` and
    ``classifier_llm`` sections with their required keys, then builds the
    keyword-argument dicts for the client LLM and the guide (expert +
    classifier) LLMs.

    Parameters:
        config: mapping with "client_llm", "expert_llm" and
            "classifier_llm" sub-dicts. Each needs "model_id" and "url";
            the classifier additionally needs "batch_size". The client
            section may optionally carry "temperature" and "max_tokens".

    Returns:
        tuple: ``(client_kwargs, guide_kwargs)`` dicts.

    Raises:
        ValueError: if HF_TOKEN is unset, or any required section or key
            is missing from the config.
    """
    if "HF_TOKEN" not in os.environ:
        raise ValueError("Please set the HF_TOKEN environment variable.")
    # Fetch once; the check above guarantees the variable exists, so a
    # direct lookup is safe and avoids repeated os.getenv() calls.
    api_key = os.environ["HF_TOKEN"]

    if "client_llm" not in config:
        raise ValueError("config.yaml is missing client_llm settings.")
    client = config["client_llm"]
    client_kwargs = {
        "model_id": _require(client, "model_id", "client"),
        "inference_server_url": _require(client, "url", "client"),
        "api_key": api_key,
        "llm_backend": "HFChat",
        # Fallbacks mirror the original hard-coded defaults.
        "temperature": client.get("temperature", 0.6),
        "max_tokens": client.get("max_tokens", 800),
    }

    if "expert_llm" not in config:
        raise ValueError("config.yaml is missing expert_llm settings.")
    expert = config["expert_llm"]
    guide_kwargs = {
        "expert_model": _require(expert, "model_id", "expert"),
        "inference_server_url": _require(expert, "url", "expert"),
        "api_key": api_key,
        "llm_backend": "HFChat",
    }

    if "classifier_llm" not in config:
        raise ValueError("config.yaml is missing classifier_llm settings.")
    classifier = config["classifier_llm"]
    guide_kwargs["classifier_kwargs"] = {
        "model_id": _require(classifier, "model_id", "classifier"),
        "inference_server_url": _require(classifier, "url", "classifier"),
        "batch_size": int(_require(classifier, "batch_size", "classifier")),
        "api_key": api_key,  # classifier api key
    }

    return client_kwargs, guide_kwargs