mtyrrell committed on
Commit
7fca207
·
1 Parent(s): bea4d82

added novita for HF inference provider

Browse files
Files changed (2) hide show
  1. params.cfg +2 -2
  2. utils/generator.py +10 -2
params.cfg CHANGED
@@ -1,8 +1,8 @@
1
  [generator]
2
  PROVIDER = huggingface
3
  MODEL = meta-llama/Meta-Llama-3-8B-Instruct
4
- MAX_TOKENS = 512
5
- TEMPERATURE = 0.2
6
  INFERENCE_PROVIDER = novita
7
  ORGANIZATION = GIZ
8
 
 
1
  [generator]
2
  PROVIDER = huggingface
3
  MODEL = meta-llama/Meta-Llama-3-8B-Instruct
4
+ MAX_TOKENS = 768
5
+ TEMPERATURE = 0
6
  INFERENCE_PROVIDER = novita
7
  ORGANIZATION = GIZ
8
 
utils/generator.py CHANGED
@@ -32,6 +32,8 @@ PROVIDER = config.get("generator", "PROVIDER")
32
  MODEL = config.get("generator", "MODEL")
33
  MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
34
  TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
 
 
35
 
36
  # Set up authentication for the selected provider
37
  auth_config = get_auth(PROVIDER)
@@ -45,8 +47,14 @@ def _get_chat_model():
45
  "anthropic": lambda: ChatAnthropic(model=MODEL, anthropic_api_key=auth_config["api_key"], streaming=True, **common_params),
46
  "cohere": lambda: ChatCohere(model=MODEL, cohere_api_key=auth_config["api_key"], streaming=True, **common_params),
47
  "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
48
- repo_id=MODEL, huggingfacehub_api_token=auth_config["api_key"],
49
- task="text-generation", temperature=TEMPERATURE, max_new_tokens=MAX_TOKENS, streaming=True
 
 
 
 
 
 
50
  ))
51
  }
52
 
 
32
  MODEL = config.get("generator", "MODEL")
33
  MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
34
  TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
35
+ INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
36
+ ORGANIZATION = config.get("generator", "ORGANIZATION")
37
 
38
  # Set up authentication for the selected provider
39
  auth_config = get_auth(PROVIDER)
 
47
  "anthropic": lambda: ChatAnthropic(model=MODEL, anthropic_api_key=auth_config["api_key"], streaming=True, **common_params),
48
  "cohere": lambda: ChatCohere(model=MODEL, cohere_api_key=auth_config["api_key"], streaming=True, **common_params),
49
  "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
50
+ repo_id=MODEL,
51
+ huggingfacehub_api_token=auth_config["api_key"],
52
+ task="text-generation",
53
+ provider=INFERENCE_PROVIDER,
54
+ server_kwargs={"bill_to": ORGANIZATION},
55
+ temperature=TEMPERATURE,
56
+ max_new_tokens=MAX_TOKENS,
57
+ streaming=True
58
  ))
59
  }
60