# guir-chat / backend/config.py
# Author: Gregor Betz — commit 9225b05 ("config bugfix", unverified)
# (file-viewer metadata: raw | history | blame | 2.63 kB)
import os
from collections.abc import Mapping
def _require_section(config, section):
    """Return the settings mapping stored under ``config[section]``.

    Raises:
        ValueError: if the section is absent or null, or is not a mapping.
    """
    settings = config.get(section)
    if settings is None:
        # An absent key and a key with an empty/null value both mean the
        # section was not configured; report both as "missing".
        raise ValueError(f"config.yaml is missing {section} settings.")
    if not isinstance(settings, Mapping):
        # Without this check, ``"url" in "some string"`` would silently do a
        # substring test instead of a key lookup.
        raise ValueError(f"config.yaml {section} settings must be a mapping.")
    return settings


def _require_key(settings, key, label):
    """Return ``settings[key]``; raise ValueError naming *label* if absent."""
    if key not in settings:
        raise ValueError(f"config.yaml is missing {label} {key}.")
    return settings[key]


def process_config(config):
    """Validate the app config and build keyword arguments for the LLM clients.

    The config must contain three sections: ``client_llm``, ``expert_llm``
    and ``classifier_llm``. Each section needs ``model_id`` and ``url``, and
    ``classifier_llm`` also needs ``batch_size``. Every resulting kwargs dict
    gets the value of the ``HF_TOKEN`` environment variable as ``api_key``.

    Args:
        config: Mapping of settings (the error messages refer to it as
            ``config.yaml``). ``None`` is treated as an empty config.

    Returns:
        A ``(client_kwargs, guide_kwargs)`` tuple.

        ``client_kwargs`` has these keys:

        - ``model_id`` and ``inference_server_url``
        - ``api_key``
        - ``llm_backend="HFChat"``
        - ``temperature`` (default 0.6)
        - ``max_tokens`` (default 800)

        ``guide_kwargs`` has these keys:

        - ``expert_model`` and ``inference_server_url``
        - ``api_key``
        - ``llm_backend="HFChat"``
        - ``classifier_kwargs``, a nested dict with ``model_id``,
          ``inference_server_url``, an int ``batch_size`` and ``api_key``.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset, if a section or required key is
            missing, if a section is not a mapping, or if ``batch_size`` is
            not convertible to ``int``.
    """
    if "HF_TOKEN" not in os.environ:
        raise ValueError("Please set the HF_TOKEN environment variable.")
    hf_token = os.environ["HF_TOKEN"]

    if config is None:
        # Treat a null config as empty so the first missing section is
        # reported as a ValueError rather than crashing with a TypeError.
        config = {}
    elif not isinstance(config, Mapping):
        raise ValueError("config.yaml must contain a mapping of settings.")

    # Client model. The dict literal is evaluated left to right, so model_id
    # is validated before url, as in the original checks.
    client_llm = _require_section(config, "client_llm")
    client_kwargs = {
        "model_id": _require_key(client_llm, "model_id", "client"),
        "inference_server_url": _require_key(client_llm, "url", "client"),
        "api_key": hf_token,
        "llm_backend": "HFChat",
        "temperature": client_llm.get("temperature", 0.6),
        "max_tokens": client_llm.get("max_tokens", 800),
    }

    # Expert model. The classifier_kwargs dict is created up front so it stays
    # the first key of guide_kwargs; it is filled in after the expert checks.
    expert_llm = _require_section(config, "expert_llm")
    classifier_kwargs = {}
    guide_kwargs = {
        "classifier_kwargs": classifier_kwargs,
        "expert_model": _require_key(expert_llm, "model_id", "expert"),
        "inference_server_url": _require_key(expert_llm, "url", "expert"),
        "api_key": hf_token,
        "llm_backend": "HFChat",
    }

    # Classifier model.
    classifier_llm = _require_section(config, "classifier_llm")
    classifier_kwargs["model_id"] = _require_key(
        classifier_llm, "model_id", "classifier"
    )
    classifier_kwargs["inference_server_url"] = _require_key(
        classifier_llm, "url", "classifier"
    )
    batch_size = _require_key(classifier_llm, "batch_size", "classifier")
    try:
        classifier_kwargs["batch_size"] = int(batch_size)
    except (TypeError, ValueError) as err:
        # Name the offending setting instead of surfacing a bare int() error.
        raise ValueError(
            f"config.yaml has an invalid classifier batch_size: {batch_size!r}."
        ) from err
    classifier_kwargs["api_key"] = hf_token

    return client_kwargs, guide_kwargs