mgbam committed
Commit 96b76f8 · verified · 1 Parent(s): 6a1db5c

Update hf_client.py

Files changed (1)
  1. hf_client.py +12 -45
hf_client.py CHANGED
@@ -1,61 +1,28 @@
-import os
-from huggingface_hub import InferenceClient
-from tavily import TavilyClient
-
-# === API Keys ===
-HF_TOKEN = os.getenv('HF_TOKEN')
-GROQ_API_KEY = os.getenv('GROQ_API_KEY')
-TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
 
-if not HF_TOKEN:
-    raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
 
-# === GROQ-Compatible Wrapper ===
-class GroqChatClient:
-    def __init__(self, api_key: str):
-        import openai
-        openai.api_key = api_key
-        openai.api_base = "https://api.groq.com/openai/v1"
-        self.client = openai
-        self.chat = self.Chat(openai)
 
-    class Chat:
-        def __init__(self, openai_client):
-            self.completions = self.Completions(openai_client)
+### hf_client.py
 
-        class Completions:
-            def __init__(self, openai_client):
-                self.client = openai_client
+from huggingface_hub import InferenceClient, HfApi
+from tavily import TavilyClient
 
-            def create(self, model, messages, temperature=0.7, max_tokens=1024, **kwargs):
-                return self.client.ChatCompletion.create(
-                    model=model,
-                    messages=messages,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    **kwargs
-                )
+# HF Inference Client
+HF_TOKEN = os.getenv('HF_TOKEN')
+if not HF_TOKEN:
+    raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
 
-# === Inference Client Selector ===
-def get_inference_client(model_id: str, provider: str = "auto"):
-    """
-    Returns a unified interface:
-    - For 'moonshotai/Kimi-K2-Instruct', uses Groq with OpenAI-compatible API
-    - For others, uses Hugging Face InferenceClient
-    """
+def get_inference_client(model_id, provider="auto"):
+    """Return an InferenceClient with provider based on model_id and user selection."""
     if model_id == "moonshotai/Kimi-K2-Instruct":
-        if not GROQ_API_KEY:
-            raise RuntimeError("GROQ_API_KEY is required for Groq-hosted models.")
-        return GroqChatClient(api_key=GROQ_API_KEY)
-
+        provider = "groq"
     return InferenceClient(
-        model=model_id,
         provider=provider,
         api_key=HF_TOKEN,
         bill_to="huggingface"
     )
 
-# === Tavily Search Client ===
+# Tavily Search Client
+TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
 tavily_client = None
 if TAVILY_API_KEY:
     try:
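
Minimal usage sketch (not part of the commit) of the refactored get_inference_client. It assumes hf_client.py is importable, HF_TOKEN is set, and the module also imports os (the module-level os.getenv calls require it, though no import os appears on the new side of this hunk). The prompt and token limit are illustrative.

from hf_client import get_inference_client

# "moonshotai/Kimi-K2-Instruct" is the one model the selector pins to provider="groq";
# any other model_id falls through with the caller-supplied provider (default "auto").
model_id = "moonshotai/Kimi-K2-Instruct"
client = get_inference_client(model_id)

# The commit drops model=model_id from the InferenceClient constructor,
# so the model is now passed per request instead.
response = client.chat_completion(
    messages=[{"role": "user", "content": "Say hello in one sentence."}],  # illustrative prompt
    model=model_id,
    max_tokens=64,  # illustrative limit
)
print(response.choices[0].message.content)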