NCTCMumbai commited on
Commit
32bfafa
·
verified ·
1 Parent(s): 91fad79

Update backend/query_llm.py

Browse files
Files changed (1) hide show
  1. backend/query_llm.py +161 -161
backend/query_llm.py CHANGED
@@ -1,161 +1,161 @@
1
-
2
-
3
- import openai
4
- import gradio as gr
5
-
6
- from os import getenv
7
- from typing import Any, Dict, Generator, List
8
-
9
- from huggingface_hub import InferenceClient
10
- from transformers import AutoTokenizer
11
-
12
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
13
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
14
- #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")
15
- temperature = 0.5
16
- top_p = 0.7
17
- repetition_penalty = 1.2
18
-
19
- OPENAI_KEY = getenv("OPENAI_API_KEY")
20
- HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
21
-
22
- # hf_client = InferenceClient(
23
- # "mistralai/Mistral-7B-Instruct-v0.1",
24
- # token=HF_TOKEN
25
- # )
26
-
27
-
28
- hf_client = InferenceClient(
29
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
30
- token=HF_TOKEN
31
- )
32
- def format_prompt(message: str, api_kind: str):
33
- """
34
- Formats the given message using a chat template.
35
-
36
- Args:
37
- message (str): The user message to be formatted.
38
-
39
- Returns:
40
- str: Formatted message after applying the chat template.
41
- """
42
-
43
- # Create a list of message dictionaries with role and content
44
- messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
45
-
46
- if api_kind == "openai":
47
- return messages
48
- elif api_kind == "hf":
49
- return tokenizer.apply_chat_template(messages, tokenize=False)
50
- elif api_kind:
51
- raise ValueError("API is not supported")
52
-
53
-
54
- def generate_hf(prompt: str, history: str, temperature: float = 0.5, max_new_tokens: int = 4000,
55
- top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
56
- """
57
- Generate a sequence of tokens based on a given prompt and history using Mistral client.
58
-
59
- Args:
60
- prompt (str): The initial prompt for the text generation.
61
- history (str): Context or history for the text generation.
62
- temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
63
- max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
64
- top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
65
- repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
66
-
67
- Returns:
68
- Generator[str, None, str]: A generator yielding chunks of generated text.
69
- Returns a final string if an error occurs.
70
- """
71
-
72
- temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
73
- top_p = float(top_p)
74
-
75
- generate_kwargs = {
76
- 'temperature': temperature,
77
- 'max_new_tokens': max_new_tokens,
78
- 'top_p': top_p,
79
- 'repetition_penalty': repetition_penalty,
80
- 'do_sample': True,
81
- 'seed': 42,
82
- }
83
-
84
- formatted_prompt = format_prompt(prompt, "hf")
85
-
86
- try:
87
- stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
88
- stream=True, details=True, return_full_text=False)
89
- output = ""
90
- for response in stream:
91
- output += response.token.text
92
- yield output
93
-
94
- except Exception as e:
95
- if "Too Many Requests" in str(e):
96
- print("ERROR: Too many requests on Mistral client")
97
- gr.Warning("Unfortunately Mistral is unable to process")
98
- return "Unfortunately, I am not able to process your request now."
99
- elif "Authorization header is invalid" in str(e):
100
- print("Authetification error:", str(e))
101
- gr.Warning("Authentication error: HF token was either not provided or incorrect")
102
- return "Authentication error"
103
- else:
104
- print("Unhandled Exception:", str(e))
105
- gr.Warning("Unfortunately Mistral is unable to process")
106
- return "I do not know what happened, but I couldn't understand you."
107
-
108
-
109
- def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
110
- top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
111
- """
112
- Generate a sequence of tokens based on a given prompt and history using Mistral client.
113
-
114
- Args:
115
- prompt (str): The initial prompt for the text generation.
116
- history (str): Context or history for the text generation.
117
- temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
118
- max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
119
- top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
120
- repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
121
-
122
- Returns:
123
- Generator[str, None, str]: A generator yielding chunks of generated text.
124
- Returns a final string if an error occurs.
125
- """
126
-
127
- temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
128
- top_p = float(top_p)
129
-
130
- generate_kwargs = {
131
- 'temperature': temperature,
132
- 'max_tokens': max_new_tokens,
133
- 'top_p': top_p,
134
- 'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
135
- }
136
-
137
- formatted_prompt = format_prompt(prompt, "openai")
138
-
139
- try:
140
- stream = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301",
141
- messages=formatted_prompt,
142
- **generate_kwargs,
143
- stream=True)
144
- output = ""
145
- for chunk in stream:
146
- output += chunk.choices[0].delta.get("content", "")
147
- yield output
148
-
149
- except Exception as e:
150
- if "Too Many Requests" in str(e):
151
- print("ERROR: Too many requests on OpenAI client")
152
- gr.Warning("Unfortunately OpenAI is unable to process")
153
- return "Unfortunately, I am not able to process your request now."
154
- elif "You didn't provide an API key" in str(e):
155
- print("Authetification error:", str(e))
156
- gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
157
- return "Authentication error"
158
- else:
159
- print("Unhandled Exception:", str(e))
160
- gr.Warning("Unfortunately OpenAI is unable to process")
161
- return "I do not know what happened, but I couldn't understand you."
 
1
+
2
+
3
+ import openai
4
+ import gradio as gr
5
+
6
+ from os import getenv
7
+ from typing import Any, Dict, Generator, List
8
+
9
+ from huggingface_hub import InferenceClient
10
+ from transformers import AutoTokenizer
11
+
12
+ #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
13
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
14
+ #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")
15
+ temperature = 0.5
16
+ top_p = 0.7
17
+ repetition_penalty = 1.2
18
+
19
+ OPENAI_KEY = getenv("OPENAI_API_KEY")
20
+ HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
21
+
22
+ # hf_client = InferenceClient(
23
+ # "mistralai/Mistral-7B-Instruct-v0.1",
24
+ # token=HF_TOKEN
25
+ # )
26
+
27
+
28
+ hf_client = InferenceClient(
29
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
30
+ token=HF_TOKEN
31
+ )
32
+ def format_prompt(message: str, api_kind: str):
33
+ """
34
+ Formats the given message using a chat template.
35
+
36
+ Args:
37
+ message (str): The user message to be formatted.
38
+
39
+ Returns:
40
+ str: Formatted message after applying the chat template.
41
+ """
42
+
43
+ # Create a list of message dictionaries with role and content
44
+ messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
45
+
46
+ if api_kind == "openai":
47
+ return messages
48
+ elif api_kind == "hf":
49
+ return tokenizer.apply_chat_template(messages, tokenize=False)
50
+ elif api_kind:
51
+ raise ValueError("API is not supported")
52
+
53
+
54
+ def generate_hf(prompt: str, history: str, temperature: float = 0.5, max_new_tokens: int = 4000,
55
+ top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
56
+ """
57
+ Generate a sequence of tokens based on a given prompt and history using Mistral client.
58
+
59
+ Args:
60
+ prompt (str): The initial prompt for the text generation.
61
+ history (str): Context or history for the text generation.
62
+ temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
63
+ max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
64
+ top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
65
+ repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
66
+
67
+ Returns:
68
+ Generator[str, None, str]: A generator yielding chunks of generated text.
69
+ Returns a final string if an error occurs.
70
+ """
71
+
72
+ temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
73
+ top_p = float(top_p)
74
+
75
+ generate_kwargs = {
76
+ 'temperature': temperature,
77
+ 'max_new_tokens': max_new_tokens,
78
+ 'top_p': top_p,
79
+ 'repetition_penalty': repetition_penalty,
80
+ 'do_sample': True,
81
+ 'seed': 42,
82
+ }
83
+
84
+ formatted_prompt = format_prompt(prompt, "hf")
85
+
86
+ try:
87
+ stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
88
+ stream=True, details=True, return_full_text=False)
89
+ output = ""
90
+ for response in stream:
91
+ output += response.token.text
92
+ yield output
93
+
94
+ except Exception as e:
95
+ if "Too Many Requests" in str(e):
96
+ print("ERROR: Too many requests on Mistral client")
97
+ gr.Warning("Unfortunately Mistral is unable to process")
98
+ return "Unfortunately, I am not able to process your request now."
99
+ elif "Authorization header is invalid" in str(e):
100
+ print("Authetification error:", str(e))
101
+ gr.Warning("Authentication error: HF token was either not provided or incorrect")
102
+ return "Authentication error"
103
+ else:
104
+ print("Unhandled Exception:", str(e))
105
+ gr.Warning("Unfortunately Mistral is unable to process")
106
+ return "I do not know what happened, but I couldn't understand you."
107
+
108
+
109
+ def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
110
+ top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
111
+ """
112
+ Generate a sequence of tokens based on a given prompt and history using Mistral client.
113
+
114
+ Args:
115
+ prompt (str): The initial prompt for the text generation.
116
+ history (str): Context or history for the text generation.
117
+ temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
118
+ max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
119
+ top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
120
+ repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
121
+
122
+ Returns:
123
+ Generator[str, None, str]: A generator yielding chunks of generated text.
124
+ Returns a final string if an error occurs.
125
+ """
126
+
127
+ temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
128
+ top_p = float(top_p)
129
+
130
+ generate_kwargs = {
131
+ 'temperature': temperature,
132
+ 'max_tokens': max_new_tokens,
133
+ 'top_p': top_p,
134
+ 'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
135
+ }
136
+
137
+ formatted_prompt = format_prompt(prompt, "openai")
138
+
139
+ try:
140
+ stream = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301",
141
+ messages=formatted_prompt,
142
+ **generate_kwargs,
143
+ stream=True)
144
+ output = ""
145
+ for chunk in stream:
146
+ output += chunk.choices[0].delta.get("content", "")
147
+ yield output
148
+
149
+ except Exception as e:
150
+ if "Too Many Requests" in str(e):
151
+ print("ERROR: Too many requests on OpenAI client")
152
+ gr.Warning("Unfortunately OpenAI is unable to process")
153
+ return "Unfortunately, I am not able to process your request now."
154
+ elif "You didn't provide an API key" in str(e):
155
+ print("Authetification error:", str(e))
156
+ gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
157
+ return "Authentication error"
158
+ else:
159
+ print("Unhandled Exception:", str(e))
160
+ gr.Warning("Unfortunately OpenAI is unable to process")
161
+ return "I do not know what happened, but I couldn't understand you."