# scamgptj-eval / huggingface.py
# Uploaded by ksee ("Upload 2 files", commit 9c7155b).
from huggingface_hub import InferenceClient
import requests
from helpers import wrap_conversation
class Model:
    """Thin client for a Hugging Face text-generation HTTP endpoint.

    ``generate_text`` wraps the caller's text with
    ``helpers.wrap_conversation``, POSTs it as JSON to ``API_URL`` and returns
    only the first reply: the generated text after the echoed prompt, cut at
    the first ``#``.
    """

    # Class-level fallbacks so subclasses that bypass ``__init__`` still get
    # the previous behaviour (no auth header, no timeout).
    token = None
    timeout = None

    def __init__(self, model_name: str, API_URL: str, *,
                 token: "str | None" = None,
                 timeout: "float | None" = None):
        """Store the endpoint configuration.

        Args:
            model_name: Model identifier; stored on the instance only, it is
                not sent in the request.
            API_URL: Full URL the JSON payload is POSTed to.
            token: Optional Hugging Face access token. When given, it is sent
                as an ``Authorization: Bearer <token>`` header.
            timeout: Optional ``requests`` timeout in seconds. ``None`` (the
                default) waits indefinitely, matching the original behaviour.
        """
        self.model_name = model_name
        self.API_URL = API_URL
        self.token = token
        self.timeout = timeout

    def _query(self, payload: dict):
        """POST ``payload`` as JSON to ``self.API_URL``; return the decoded body.

        No ``raise_for_status()`` is done here on purpose. The Inference API
        reports failures as a JSON body such as ``{"error": "..."}``, which
        ``_extract_generated_text`` turns into an error carrying that message.
        """
        headers = {"Content-Type": "application/json"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        response = requests.post(
            self.API_URL, headers=headers, json=payload, timeout=self.timeout
        )
        return response.json()

    def generate_text(self, text_input: str, *, temperature: float = 0.9,
                      max_new_tokens: int = 50) -> str:
        """Generate the model's first reply to ``text_input``.

        Args:
            text_input: Raw user text; it is wrapped with
                ``wrap_conversation`` before being sent.
            temperature: Sampling temperature forwarded to the endpoint.
            max_new_tokens: Generation-length cap forwarded to the endpoint.

        Returns:
            The reply text extracted by ``_get_first_reply``.

        Raises:
            ValueError: If the endpoint response holds no ``generated_text``
                (for example an ``{"error": ...}`` payload).
        """
        payload = {
            "inputs": wrap_conversation(text_input),
            "parameters": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            },
        }
        output = self._query(payload)
        generated = self._extract_generated_text(output)
        return self._get_first_reply(text_input, generated)

    @staticmethod
    def _extract_generated_text(output) -> str:
        """Pull ``generated_text`` out of a text-generation response.

        Accepts both the list form ``[{"generated_text": ...}]`` and a bare
        ``{"generated_text": ...}`` object. Only the first list element is
        used.

        Raises:
            ValueError: If the list is empty, the payload carries an
                ``"error"`` key, or ``generated_text`` is otherwise missing.
        """
        if isinstance(output, list):
            if not output:
                raise ValueError("Inference API returned an empty result list")
            output = output[0]
        if isinstance(output, dict):
            if "generated_text" in output:
                return output["generated_text"]
            if "error" in output:
                raise ValueError(
                    f"Inference API returned an error: {output['error']}"
                )
        raise ValueError(f"Unexpected Inference API response: {output!r}")

    @staticmethod
    def _get_first_reply(original_text, output_text):
        """Return the first reply contained in ``output_text``.

        Drops as many leading characters as the wrapped prompt is long, then
        keeps everything before the first ``#``.
        """
        prompt = wrap_conversation(original_text)
        # NOTE(review): assumes the endpoint echoes the full prompt before the
        # continuation (i.e. return_full_text is on). Confirm for API_URL.
        continuation = output_text[len(prompt):]
        # Presumably wrap_conversation marks turns with '#', so the first '#'
        # ends the model's reply. Verify against helpers.wrap_conversation.
        return continuation.split("#", 1)[0]