MatteoScript committed
Commit 3d4b515
Parent(s): ed5749a

Update chat_client.py

Files changed (1): chat_client.py +66 -32
chat_client.py CHANGED
@@ -2,12 +2,12 @@ from huggingface_hub import InferenceClient
 import os
 from dotenv import load_dotenv
 import random
+import json
+from openai import OpenAI
 
 load_dotenv()
-
 API_TOKEN = os.getenv('HF_TOKEN')
 
-
 def format_prompt(message, history):
     prompt = "<s>"
     for user_prompt, bot_response in history:
@@ -16,37 +16,71 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def chat(
-    prompt, history, chat_client = "mistralai/Mistral-7B-Instruct-v0.1",temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0,
-):
-    client = InferenceClient(
-        chat_client,
-        token=API_TOKEN
-    )
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=random.randint(0, 10**7),
-    )
-
-    formatted_prompt = format_prompt(prompt, history)
+def format_prompt_openai(system_prompt, message, history):
+    messages = []
+    if system_prompt != '':
+        messages.append({"role": "system", "content": system_prompt})
+    for user_prompt, bot_response in history:
+        messages.append({"role": "user", "content": user_prompt})
+        messages.append({"role": "assistant", "content": bot_response})
+    messages.append({"role": "user", "content": message})
+    return messages
 
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
+def chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty):
+    client = InferenceClient(
+        chat_client,
+        token=API_TOKEN
+    )
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
 
-    # for response in stream:
-    #     # print(response)
-    #     output += response.token["text"]
-    #     yield output
-    # return output
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=random.randint(0, 10**7),
+    )
+    formatted_prompt = format_prompt(prompt, history)
+    print('***************************************************')
+    print(formatted_prompt)
+    print('***************************************************')
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    return stream
 
+def chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai):
+    try:
+        prompt = prompt.replace('\n', '')
+        json_data = json.loads(prompt)
+        user_prompt = json_data["messages"][1]["content"]
+        system_prompt = json_data["input"]["content"]
+        system_style = json_data["input"]["style"]
+        instructions = json_data["messages"][0]["content"]
+        if instructions != '':
+            system_prompt += '\n' + instructions
+        if system_style != '':
+            system_prompt += '\n' + system_style
+    except:
+        user_prompt = prompt
+        system_prompt = ''
+    messages = format_prompt_openai(system_prompt, user_prompt, history)
+    print('***************************************************')
+    print(messages)
+    print('***************************************************')
+    stream = client_openai.chat.completions.create(
+        model=chat_client,
+        stream=True,
+        messages=messages,
+        temperature=temperature,
+        max_tokens=max_new_tokens,
+    )
+    return stream
 
-    return stream
+def chat(prompt, history, chat_client, temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0, client_openai=None):
+    if chat_client[:3] == 'gpt':
+        return chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai)
+    else:
+        return chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty)
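
A note on the new OpenAI path: chat_openai first tries to parse the incoming prompt as a JSON envelope and only falls back to plain text if parsing fails. The sketch below reconstructs the envelope shape implied by the json.loads branch; the field names come from the code above, while the example values are made up for illustration.

import json

# Hypothetical envelope, inferred from chat_openai's parsing:
#   messages[0].content -> extra instructions appended to the system prompt
#   messages[1].content -> the actual user prompt
#   input.content       -> base system prompt
#   input.style         -> style hint appended to the system prompt
envelope = {
    "messages": [
        {"content": "Answer in bullet points."},
        {"content": "Summarise this document."},
    ],
    "input": {
        "content": "You are a helpful assistant.",
        "style": "Reply in Italian.",
    },
}
# Caveat: chat_openai strips every newline from the raw string before
# json.loads, so newlines inside JSON string values are lost as well.
prompt = json.dumps(envelope)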
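
A minimal usage sketch for the updated chat entry point, assuming chat_client.py is importable, openai>=1.0 is installed, and OPENAI_API_KEY is set in the environment; the model ids are just examples, not part of this commit.

from openai import OpenAI
from chat_client import chat

history = []  # list of (user_prompt, bot_response) pairs

# Model ids starting with 'gpt' are routed to chat_openai and need a client.
client_openai = OpenAI()  # reads OPENAI_API_KEY from the environment
for chunk in chat("Hello!", history, "gpt-3.5-turbo", client_openai=client_openai):
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="")

# Any other id goes through InferenceClient.text_generation with details=True,
# so the stream yields token objects rather than raw strings.
for response in chat("Hello!", history, "mistralai/Mistral-7B-Instruct-v0.1"):
    print(response.token.text, end="")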