File size: 3,153 Bytes
2006c2b
 
 
ed9dba7
fee6f56
 
ed9dba7
2006c2b
 
 
 
 
 
 
 
 
 
 
fee6f56
 
 
 
 
 
 
 
 
2006c2b
fee6f56
 
 
 
 
 
 
 
 
2006c2b
fee6f56
 
 
 
 
 
 
 
 
 
 
 
 
 
2e2d510
fee6f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e2d510
fee6f56
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
import random
import json
from openai import OpenAI

# Load environment variables from a local .env file (if present), then read
# the Hugging Face API token consumed by chat_huggingface below.
# NOTE(review): API_TOKEN is None when HF_TOKEN is unset — presumably the
# InferenceClient then falls back to anonymous access; confirm.
load_dotenv()
API_TOKEN = os.getenv('HF_TOKEN')

def format_prompt(message, history):
  """Render chat history plus the new message as a Mistral-style [INST] prompt.

  history is an iterable of (user_prompt, bot_response) pairs; the new
  message is appended as a final unanswered [INST] block.
  """
  segments = ["<s>"]
  for past_user, past_bot in history:
    segments.append(f"[INST] {past_user} [/INST] {past_bot}</s> ")
  segments.append(f"[INST] {message} [/INST]")
  return "".join(segments)

def format_prompt_openai(system_prompt, message, history):
  """Build an OpenAI chat-completions message list.

  Emits an optional system message (only when system_prompt is not the
  empty string), then alternating user/assistant turns from history, and
  finally the new user message.
  """
  convo = []
  if system_prompt != '':
    convo.append({"role": "system", "content": system_prompt})
  for past_user, past_bot in history:
    convo.extend((
        {"role": "user", "content": past_user},
        {"role": "assistant", "content": past_bot},
    ))
  convo.append({"role": "user", "content": message})
  return convo

def chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty): 
  """Stream a text-generation response from a Hugging Face Inference endpoint.

  chat_client names the model/endpoint; the formatted prompt (history plus
  the new message) is echoed to stdout between marker lines before the
  request is sent. Returns the streaming generator from text_generation.
  """
  hf_client = InferenceClient(chat_client, token=API_TOKEN)
  # Keep temperature strictly positive: clamp anything below 1e-2 up to 1e-2.
  temperature = max(float(temperature), 1e-2)
  generation_options = {
      "temperature": temperature,
      "max_new_tokens": max_new_tokens,
      "top_p": float(top_p),
      "repetition_penalty": repetition_penalty,
      "do_sample": True,
      # Fresh random seed per call so repeated identical prompts can differ.
      "seed": random.randint(0, 10**7),
  }
  full_prompt = format_prompt(prompt, history)
  print('***************************************************')
  print(full_prompt)
  print('***************************************************')
  return hf_client.text_generation(
      full_prompt,
      **generation_options,
      stream=True,
      details=True,
      return_full_text=False,
  )

def chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai): 
  """Stream a chat completion from the OpenAI API.

  `prompt` may be a JSON payload carrying "messages" (instructions + user
  turn) and "input" (system content + style) sections, which are folded
  into a system prompt; any prompt that fails to parse that way is treated
  as a plain user message. NOTE: top_p and repetition_penalty are accepted
  for signature parity with chat_huggingface but are not forwarded to the
  OpenAI API (pre-existing behavior, left unchanged).
  """
  try: 
    # Parse a newline-stripped copy so multi-line pastes of the JSON payload
    # still load; the original `prompt` is kept intact for the fallback path
    # (previously plain-text prompts lost their newlines when parsing failed).
    cleaned = prompt.replace('\n', '')
    json_data = json.loads(cleaned)
    user_prompt = json_data["messages"][1]["content"]
    system_prompt = json_data["input"]["content"]
    system_style = json_data["input"]["style"]
    instructions = json_data["messages"][0]["content"]
    if instructions != '':
        system_prompt += '\n' + instructions
    if system_style != '':
        system_prompt += '\n' + system_style
  # Narrowed from a bare `except:` (which also swallowed SystemExit and
  # KeyboardInterrupt). These cover every failure a malformed payload can
  # raise above: bad JSON (ValueError/json.JSONDecodeError), missing keys,
  # short lists, non-dict/non-str shapes.
  except (ValueError, KeyError, IndexError, TypeError, AttributeError): 
    user_prompt = prompt
    system_prompt = ''
  messages = format_prompt_openai(system_prompt, user_prompt, history)
  print('***************************************************')
  print(messages)
  print('***************************************************')
  stream = client_openai.chat.completions.create(
      model=chat_client,
      stream=True,
      messages=messages, 
      temperature=temperature,
      max_tokens=max_new_tokens,
  )  
  return stream

def chat(prompt, history, chat_client,temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0, client_openai = None):
  """Dispatch a chat request: 'gpt*' model names go to OpenAI, all others to Hugging Face."""
  if chat_client.startswith('gpt'):
    return chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai)
  return chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty)