mgbam committed on
Commit
ac4a3a2
·
verified ·
1 Parent(s): 2541fb7

Update hf_client.py

Browse files
Files changed (1) hide show
  1. hf_client.py +23 -5
hf_client.py CHANGED
@@ -5,18 +5,36 @@ from tavily import TavilyClient
5
  import os
6
 
7
  # HF Inference Client
8
- HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 
 
9
  if not HF_TOKEN:
10
- raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
 
 
 
11
 
12
- def get_inference_client(model_id, provider="auto"):
13
- """Return an InferenceClient with provider based on model_id and user selection."""
 
 
 
 
 
 
 
14
  if model_id == "moonshotai/Kimi-K2-Instruct":
15
  provider = "groq"
 
 
 
 
16
  return InferenceClient(
17
  provider=provider,
18
  api_key=HF_TOKEN,
19
- bill_to="huggingface"
20
  )
21
 
22
  # Tavily Search Client
 
5
  import os
6
 
7
  # HF Inference Client
8
+
9
# Provider names that get_inference_client passes through unchanged as the
# billing target; any other provider is billed to "groq".
_VALID_BILL_TO = {"huggingface", "fairworksai", "groq"}

# Read the token once at import time and fail fast when it is missing or
# empty, so a misconfigured environment is caught before any client is built.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. Please set it to "
        "your Hugging Face API token."
    )
18
 
19
def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """Build an ``InferenceClient`` for the given model and provider.

    Args:
        model_id: Model identifier. It is only used to pick the provider:
            ``"moonshotai/Kimi-K2-Instruct"`` always runs on ``"groq"``.
        provider: Requested provider name; defaults to ``"auto"``.

    Returns:
        An ``InferenceClient`` created with the effective provider, the
        module-level ``HF_TOKEN`` as ``api_key``, and a ``bill_to`` value that
        equals the effective provider when it is listed in ``_VALID_BILL_TO``
        and ``"groq"`` otherwise.
    """
    # This model is pinned to groq regardless of what the caller requested.
    effective_provider = (
        "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else provider
    )

    # NOTE(review): the provider name doubles as the billing target here, so
    # e.g. provider="auto" is billed to "groq" — confirm this is intended.
    if effective_provider in _VALID_BILL_TO:
        billing_target = effective_provider
    else:
        billing_target = "groq"

    return InferenceClient(
        provider=effective_provider,
        api_key=HF_TOKEN,
        bill_to=billing_target,
    )
39
 
40
  # Tavily Search Client