# demo-multiturn/helpers/inference.py
# (uploaded by ashwath-vaithina-ibm, commit aefc33c, verified)
from helpers import get_credentials
import requests
def hf_inference(prompt, model_id, temperature, max_new_tokens):
    """Run a chat completion via the Hugging Face router (Together backend).

    Args:
        prompt: User prompt text, sent as a single chat message.
        model_id: Model identifier accepted by the router.
        temperature: Sampling temperature forwarded to the API.
        max_new_tokens: Generation cap; forwarded as the API's ``max_tokens``.

    Returns:
        The ``message`` dict (role/content) of the first completion choice.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    hf_token, _ = get_credentials.get_hf_credentials()
    API_URL = "https://router.huggingface.co/together/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {hf_token}",
    }
    response = requests.post(
        API_URL,
        headers=headers,
        json={
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                    ]
                }
            ],
            "model": model_id,
            "temperature": temperature,
            # OpenAI-compatible chat endpoints expect "max_tokens";
            # "max_new_tokens" was silently ignored, so the cap never applied.
            "max_tokens": max_new_tokens,
        },
        timeout=60,  # avoid hanging the caller forever on a stalled request
    )
    # Fail loudly on HTTP errors instead of raising KeyError on "choices" below.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]
def replicate_inference(prompt, model_id, temperature, max_new_tokens):
    """Run a synchronous prediction against the Replicate models API.

    The ``Prefer: wait`` header asks Replicate to block until the prediction
    finishes, so the output is available in the immediate response.

    Args:
        prompt: Prompt text passed to the model's ``prompt`` input.
        model_id: Replicate model path, e.g. ``owner/name``.
        temperature: Sampling temperature forwarded to the model.
        max_new_tokens: Generation cap; forwarded as the model's ``max_tokens``.

    Returns:
        A dict with a single ``"content"`` key holding the joined output text
        (matches the ``message`` shape returned by ``hf_inference``).

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    repl_token = get_credentials.get_replicate_credentials()
    API_URL = f"https://api.replicate.com/v1/models/{model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {repl_token}",
        "Content-Type": "application/json",
        "Prefer": "wait"
    }
    response = requests.post(
        API_URL,
        headers=headers,
        json={
            "input": {
                "prompt": prompt,
                "temperature": temperature,
                "max_tokens": max_new_tokens,
            }
        },
        timeout=120,  # "Prefer: wait" blocks until generation completes
    )
    # Fail loudly on HTTP errors instead of raising KeyError on 'output' below.
    response.raise_for_status()
    # Replicate streams text models' output as a list of string chunks.
    return {
        "content": "".join(response.json()['output'])
    }
# Dispatch table mapping a provider key to its inference function; both
# callables share the signature (prompt, model_id, temperature, max_new_tokens).
INFERENCE_HANDLER = dict(
    huggingface=hf_inference,
    replicate=replicate_inference,
)