StreamlitChat / chat_client.py
MatteoScript's picture
Update chat_client.py
fee6f56 verified
raw
history blame
3.15 kB
# Streaming chat helpers for Hugging Face Inference API and OpenAI backends.
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
import random
import json
from openai import OpenAI  # NOTE(review): unused in this file's visible code; kept for callers/importers
# Load variables from a local .env file into the process environment.
load_dotenv()
# Hugging Face Inference API token; None if HF_TOKEN is not set.
API_TOKEN = os.getenv('HF_TOKEN')
def format_prompt(message, history):
    """Render a chat history into a single Mistral-style instruction prompt.

    Each past (user, assistant) turn becomes ``[INST] user [/INST] bot</s> ``
    and the new ``message`` is appended as a final open instruction.

    :param message: the new user message to ask the model.
    :param history: iterable of (user_prompt, bot_response) pairs.
    :return: the full prompt string, beginning with the ``<s>`` BOS marker.
    """
    past_turns = [
        f"[INST] {past_user} [/INST] {past_bot}</s> "
        for past_user, past_bot in history
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"
def format_prompt_openai(system_prompt, message, history):
    """Build an OpenAI chat-completions ``messages`` list from a history.

    :param system_prompt: system message text; skipped entirely when ``''``.
    :param message: the new user message, appended last.
    :param history: iterable of (user_prompt, bot_response) pairs.
    :return: list of ``{"role": ..., "content": ...}`` dicts.
    """
    messages = [] if system_prompt == '' else [
        {"role": "system", "content": system_prompt}
    ]
    for past_user, past_bot in history:
        messages.extend((
            {"role": "user", "content": past_user},
            {"role": "assistant", "content": past_bot},
        ))
    messages.append({"role": "user", "content": message})
    return messages
def chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty):
    """Stream a text generation from a Hugging Face Inference API model.

    :param prompt: new user message (folded into a Mistral-style prompt).
    :param history: iterable of (user_prompt, bot_response) pairs.
    :param chat_client: model id / endpoint passed to ``InferenceClient``.
    :param temperature: sampling temperature; clamped to a minimum of 1e-2.
    :param max_new_tokens: generation length cap.
    :param top_p: nucleus-sampling probability mass.
    :param repetition_penalty: penalty forwarded to the endpoint.
    :return: the streaming generator from ``InferenceClient.text_generation``.
    """
    # The endpoint rejects temperatures at or near zero, so clamp upward.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    client = InferenceClient(chat_client, token=API_TOKEN)
    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        # Fresh random seed per call so repeated identical prompts vary.
        "seed": random.randint(0, 10**7),
    }
    formatted_prompt = format_prompt(prompt, history)
    # Debug trace of the exact prompt sent to the model.
    print('***************************************************')
    print(formatted_prompt)
    print('***************************************************')
    return client.text_generation(
        formatted_prompt,
        stream=True,
        details=True,
        return_full_text=False,
        **sampling_options,
    )
def chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai):
    """Stream a chat completion from the OpenAI API.

    ``prompt`` may be either plain text or a JSON payload of the shape
    ``{"messages": [{"content": instructions}, {"content": user_prompt}],
    "input": {"content": system_prompt, "style": style}}``; instructions and
    style, when non-empty, are appended to the system prompt.  If the payload
    cannot be parsed, the raw prompt is used verbatim as the user message.

    :param prompt: plain text or the JSON payload described above.
    :param history: iterable of (user_prompt, bot_response) pairs.
    :param chat_client: OpenAI model name (e.g. ``gpt-4``).
    :param temperature: sampling temperature forwarded to the API.
    :param max_new_tokens: forwarded as ``max_tokens``.
    :param top_p: nucleus-sampling mass forwarded to the API.
    :param repetition_penalty: NOT forwarded — OpenAI's frequency/presence
        penalties are not semantically equivalent; accepted for signature
        parity with ``chat_huggingface``.
    :param client_openai: an ``OpenAI`` client instance.
    :return: the streaming response from ``chat.completions.create``.
    """
    try:
        # Newlines are stripped only for parsing: upstream callers may embed
        # raw line breaks inside the JSON string, which json.loads rejects.
        json_data = json.loads(prompt.replace('\n', ''))
        user_prompt = json_data["messages"][1]["content"]
        system_prompt = json_data["input"]["content"]
        system_style = json_data["input"]["style"]
        instructions = json_data["messages"][0]["content"]
        if instructions != '':
            system_prompt += '\n' + instructions
        if system_style != '':
            system_prompt += '\n' + system_style
    except (AttributeError, TypeError, KeyError, IndexError, json.JSONDecodeError):
        # Not the structured payload: fall back to the ORIGINAL prompt so
        # multi-line plain-text messages keep their newlines (the previous
        # bare `except:` also fell back to the newline-stripped copy).
        user_prompt = prompt
        system_prompt = ''
    messages = format_prompt_openai(system_prompt, user_prompt, history)
    # Debug trace of the exact message list sent to the model.
    print('***************************************************')
    print(messages)
    print('***************************************************')
    stream = client_openai.chat.completions.create(
        model=chat_client,
        stream=True,
        messages=messages,
        temperature=temperature,
        max_tokens=max_new_tokens,
        top_p=top_p,  # was accepted but silently dropped before
    )
    return stream
def chat(prompt, history, chat_client, temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0, client_openai=None):
    """Dispatch a streaming chat request to the right backend.

    Model names beginning with ``gpt`` go to the OpenAI client; everything
    else goes through the Hugging Face Inference API.

    :return: a streaming response object from the selected backend.
    """
    if chat_client.startswith('gpt'):
        return chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai)
    return chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty)