llm-multi-demo / chatllm.py
John6666's picture
Upload 7 files
f788018 verified
from huggingface_hub import InferenceClient
import json
from bs4 import BeautifulSoup
import requests
import gradio as gr
from model import llm_models, llm_serverless_models
from prompt import llm_system_prompt
# Inference clients that are ready to use, keyed by model repo id.
llm_clients = {}
# Client and repo id of the model currently selected for chat.
client_main = None
current_model = None
# Display name -> language code for the selectable output languages.
language_codes = dict(English="en", Japanese="ja", Chinese="zh")
# Default output language; it is also the first entry offered in llm_languages.
llm_output_language = "language same as user input"
llm_languages = [llm_output_language, *language_codes]
# Key into llm_system_prompt selecting the active prompt template.
llm_sysprompt_mode = "Default"
# Timeout handed to the InferenceClient built in load_from_model.
server_timeout = 300
def get_llm_sysprompt():
    """Return the active system prompt with its language placeholder filled in.

    The template is looked up in ``llm_system_prompt`` under the current
    ``llm_sysprompt_mode`` (an empty string is used when the mode is unknown)
    and every ``<LANGUAGE>`` marker is replaced by ``llm_output_language``.
    """
    template = llm_system_prompt.get(llm_sysprompt_mode, "")
    # Literal substitution: unlike re.sub, str.replace never interprets
    # backslashes or group references inside the replacement text.
    return template.replace("<LANGUAGE>", llm_output_language)
def get_llm_sysprompt_mode():
    """Return the names of all available system-prompt modes as a list."""
    return [*llm_system_prompt.keys()]
def set_llm_sysprompt_mode(key: str):
    """Select the system-prompt mode and return a Gradio update with the new prompt.

    An unknown ``key`` selects the ``"Default"`` mode instead.
    """
    global llm_sysprompt_mode
    llm_sysprompt_mode = key if key in llm_system_prompt.keys() else "Default"
    refreshed_prompt = get_llm_sysprompt()
    return gr.update(value=refreshed_prompt)
def get_llm_language():
    """Return the list of selectable output-language names.

    The module-level ``llm_languages`` list itself is returned, not a copy.
    """
    return llm_languages
def set_llm_language(lang: str):
    """Store ``lang`` as the output language and return a Gradio update.

    The update carries the system prompt re-rendered with the new language.
    """
    global llm_output_language
    llm_output_language = lang
    refreshed_prompt = get_llm_sysprompt()
    return gr.update(value=refreshed_prompt)
def get_llm_model_info(model_name):
    """Return a Markdown link to the Hugging Face repo page of ``model_name``."""
    repo_url = "https://huggingface.co/{}".format(model_name)
    return "Repo: [{}]({})".format(model_name, repo_url)
# Function to extract text from a webpage
def get_text_from_html(html_content):
    """Return the text of an HTML document without boilerplate elements.

    ``script``, ``style``, ``header`` and ``footer`` elements are removed
    before the text is collected.

    Args:
        html_content: HTML markup to parse.

    Returns:
        The remaining text, with each text fragment stripped and the
        fragments joined by single spaces.
    """
    soup = BeautifulSoup(html_content, 'html.parser')
    for tag in soup(["script", "style", "header", "footer"]):
        tag.extract()
    # Without a separator, text from adjacent elements is concatenated
    # directly, gluing neighbouring words together.
    return soup.get_text(separator=" ", strip=True)
# Function to pick the language code used for a web search
def get_language_code(s):
    """Return the language code to use when searching for ``s``.

    When the output language is "language same as user input", the code is
    detected from ``s``; otherwise it is looked up in ``language_codes``.
    ``"en"`` is returned when the language is not listed or detection fails.

    Args:
        s: Text whose language should be detected when detection is needed.

    Returns:
        A language code string.
    """
    if llm_output_language != "language same as user input":
        return language_codes.get(llm_output_language, "en")
    # Imported lazily so langdetect is only loaded when detection is needed.
    from langdetect import LangDetectException, detect
    try:
        return detect(s)
    except (LangDetectException, TypeError):
        # Empty, feature-less or non-string input cannot be classified.
        return "en"
def perform_search(query):
    """Search Google for ``query`` and fetch the text of each result page.

    Args:
        query: Search terms; also used to choose the search language.

    Returns:
        A list of ``{"link": ..., "text": ...}`` dicts, one per result block
        that contains a hyperlink. ``text`` is the page text truncated to
        8000 characters, or ``None`` when fetching that page failed.

    Raises:
        requests.exceptions.RequestException: If the search request itself
            fails; failures of individual result pages are not raised.
    """
    import urllib3
    # NOTE(review): every request below is sent with verify=False, i.e. TLS
    # certificates are not validated, and the matching warning is silenced
    # here. Kept as in the original; confirm this is really intended.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    lang = get_language_code(query)
    max_chars_per_page = 8000
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.0.0"}
    all_results = []
    with requests.Session() as session:
        response = session.get(
            url="https://www.google.com/search",
            headers=headers,
            params={"q": query, "num": 3, "udm": 14, "hl": f"{lang}", "lr": f"lang_{lang}", "safe": "off", "pws": 0},
            timeout=5,
            verify=False,
        )
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        for result in soup.find_all("div", attrs={"class": "g"}):
            anchor = result.find("a", href=True)
            if anchor is None:
                # A result block without a hyperlink has nothing to fetch;
                # skip it instead of failing the whole search.
                continue
            link = anchor["href"]
            try:
                webpage_response = session.get(link, headers=headers, timeout=5, verify=False)
                webpage_response.raise_for_status()
                visible_text = get_text_from_html(webpage_response.text)[:max_chars_per_page]
                all_results.append({"link": link, "text": visible_text})
            except requests.exceptions.RequestException:
                all_results.append({"link": link, "text": None})
    return all_results
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
def load_from_model(model_name: str, hf_token: str = None):
    """Build a Gradio interface that forwards to a model's chat completion.

    The model's inference API URL is probed with a GET request first; any
    status other than 200 is reported as ``ModelNotFoundError``.

    Args:
        model_name: Hugging Face repo id of the model.
        hf_token: Optional access token, sent as a bearer token.

    Returns:
        A ``gr.Interface`` whose ``fn`` calls the client's ``chat_completion``.

    Raises:
        ModelNotFoundError: If the probe request does not return status 200.
    """
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError

    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    print(f"Fetching model from: {model_url}")

    headers = {}
    if hf_token is not None:
        headers["Authorization"] = f"Bearer {hf_token}"

    probe = httpx.request("GET", api_url, headers=headers)
    if probe.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )

    # Added only after the probe, so it is sent with inference calls alone.
    headers["X-Wait-For-Model"] = "true"
    client = huggingface_hub.InferenceClient(
        model=model_name,
        headers=headers,
        token=hf_token,
        timeout=server_timeout,
    )
    text_in = gr.components.Textbox(render=False)
    state_in = gr.components.State(render=False)
    chat_out = gr.components.Chatbot(render=False)
    state_out = gr.components.State(render=False)
    chat_completion = client.chat_completion

    def query_huggingface_inference_endpoints(*data, **kwargs):
        return chat_completion(*data, **kwargs)

    return gr.Interface(
        fn=query_huggingface_inference_endpoints,
        inputs=[text_in, state_in],
        outputs=[chat_out, state_out],
        title=model_name,
    )
def get_status(model_name: str):
    """Return the model status reported for ``model_name``.

    A fresh ``InferenceClient`` with a timeout of 10 is used for the query.
    """
    return InferenceClient(timeout=10).get_model_status(model_name)
def load_clients():
    """Populate ``llm_clients`` with every model that can currently be served.

    Models in ``llm_serverless_models`` get a plain ``InferenceClient``;
    models in ``llm_models`` that are not registered yet are loaded through
    ``load_from_model``. A model is skipped, with a message printed, when its
    status cannot be fetched, when its state is neither "Loadable" nor
    "Loaded", or when building its client fails.
    """
    ready_states = ("Loadable", "Loaded")
    for model in llm_serverless_models:
        state = _fetch_model_state(model)
        if state not in ready_states:
            print(f"Failed to load by serverless inference API: {model}. Model state is {state}")
            continue
        try:
            print(f"Fetching model by serverless inference API: {model}")
            llm_clients[model] = InferenceClient(model)
        except Exception as e:
            print(e)
            print(f"Failed to load by serverless inference API: {model}")
            continue
        print(f"Loaded by serverless inference API: {model}")
    for model in llm_models:
        if model in llm_clients:
            continue
        state = _fetch_model_state(model)
        if state not in ready_states:
            print(f"Failed to load: {model}. Model state is {state}")
            continue
        try:
            llm_clients[model] = load_from_model(model)
        except Exception as e:
            print(e)
            print(f"Failed to load: {model}")
            continue
        print(f"Loaded: {model}")


def _fetch_model_state(model_name: str):
    """Return the state reported for ``model_name``, or ``None`` if unavailable."""
    try:
        status = get_status(model_name)
    except Exception as e:
        # One unreachable model must not abort loading of the others.
        print(e)
        return None
    return None if status is None else status.state
def add_client(model_name: str):
    """Create and register an inference client for ``model_name``.

    Args:
        model_name: Hugging Face repo id of the model to add.

    Returns:
        The new client on success. On failure, the client registered for
        ``llm_serverless_models[0]``, or ``None`` if that one is missing too.
    """
    new_client = None
    try:
        status = get_status(model_name)
        # A missing status is reported as state None instead of raising.
        state = None if status is None else status.state
        if state in ("Loadable", "Loaded"):
            new_client = InferenceClient(model_name)
        else:
            print(f"Failed to load: {model_name}. Model state is {state}")
    except Exception as e:
        print(e)
        new_client = None
    if new_client:
        print(f"Loaded by serverless inference API: {model_name}")
        llm_clients[model_name] = new_client
        return new_client
    print(f"Failed to load: {model_name}")
    return llm_clients.get(llm_serverless_models[0], None)
def set_llm_model(model_name: str = llm_serverless_models[0]):
    """Make ``model_name`` the active chat model, loading it if necessary.

    Args:
        model_name: Hugging Face repo id of the model to select.

    Returns:
        ``(model, info)``, where ``model`` is the repo id that is actually
        active and ``info`` is its Markdown repo link. ``model`` differs from
        ``model_name`` when the requested model could not be loaded and the
        default model's client was used instead. ``(None, "None")`` is
        returned when no client is available at all.
    """
    global client_main
    global current_model
    if model_name in llm_clients:
        client_main = llm_clients[model_name]
    else:
        client_main = add_client(model_name)
        if client_main is not None and model_name not in llm_clients:
            # add_client registers the model on success; if it is still not
            # registered, the client it returned belongs to the default model.
            model_name = llm_serverless_models[0]
    if client_main is None:
        return None, "None"
    current_model = model_name
    print(f"Model selected: {model_name}")
    try:
        print(f"HF model status: {get_status(model_name)}")
    except Exception as e:
        # The status is only logged; failing to fetch it must not undo the selection.
        print(e)
    return model_name, get_llm_model_info(model_name)
def get_llm_model():
    """Return the repo ids of all models that have a registered client."""
    return [*llm_clients]
# Initialize inference clients
# Runs at import time: fill llm_clients first, then select the default model
# (set_llm_model looks the model up in llm_clients, so the order matters).
load_clients()
set_llm_model()
# Separate client used by get_web_summary to produce the web-search function call.
# NOTE(review): despite its name, this client targets a Mistral model, not Gemma.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
# https://huggingface.co/docs/huggingface_hub/v0.24.5/en/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion
def chat_body(message, history, query, tokens, temperature, top_p, fpenalty, web_summary):
    """Stream the active model's reply to ``message``.

    The conversation sent to the model is the system prompt, the previous
    turns from ``history``, the user message, and a trailing empty assistant
    message. When both ``query`` and ``web_summary`` are truthy, the web
    summary is appended to the user message.

    Args:
        message: The new user message.
        history: Iterable of previous turns, indexed as (user, assistant).
        query: Web-search query text; only its truthiness is used here.
        tokens: Passed to the model as ``max_tokens``.
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus-sampling value passed to the model.
        fpenalty: Passed to the model as ``frequency_penalty``.
        web_summary: Text collected from the web search, possibly empty.

    Yields:
        ``[(output, None)]`` after each received chunk, where ``output`` is
        the reply text accumulated so far.
    """
    if query and web_summary:
        user_content = f"{message}\nweb_result\n{web_summary}"
    else:
        user_content = message

    messages = [{"role": "system", "content": get_llm_sysprompt()}]
    for msg in history:
        messages.append({"role": "user", "content": str(msg[0])})
        messages.append({"role": "assistant", "content": str(msg[1])})
    messages.append({"role": "user", "content": user_content})
    messages.append({"role": "assistant", "content": ""})

    output = ""
    try:
        # Clients built by load_from_model are gr.Interface objects that
        # expose the chat completion call as ``fn``.
        if isinstance(client_main, gr.Interface):
            complete = client_main.fn
        else:
            complete = client_main.chat_completion
        # The frequency penalty is applied on every path, so the setting is
        # honoured regardless of client type or web-search use.
        stream = complete(messages=messages, max_tokens=tokens, temperature=temperature,
                          top_p=top_p, frequency_penalty=fpenalty, stream=True)
        for response in stream:
            if response and response.choices and response.choices[0].delta.content is not None:
                output += response.choices[0].delta.content
                yield [(output, None)]
    except Exception as e:
        # Covers a missing client, a failed request, and errors raised while
        # the stream is being read; chunks already yielded stay delivered.
        print(e)
def get_web_summary(history, query_message):
    """Ask the helper model for a web-search call, run it, and collect the text.

    The helper model is prompted to answer with a ``web_search`` function
    call. The JSON object found in its reply is parsed, the requested search
    is performed, and the fetched page texts are joined into one string.

    Args:
        history: Iterable of previous turns, indexed as (user, assistant).
        query_message: The user's search request; falsy disables the search.

    Returns:
        The combined "Link/Text" entries of all pages that yielded text, or
        an empty string when no search was requested, the reply contains no
        usable function call, or any step fails.
    """
    if not query_message:
        return ""
    functions_metadata = [
        {"type": "function", "function": {"name": "web_search", "description": "Search query on Google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Web search query"}}, "required": ["query"]}}},
    ]
    func_calls = []
    for msg in history:
        func_calls.append({"role": "user", "content": f"{str(msg[0])}"})
        func_calls.append({"role": "assistant", "content": f"{str(msg[1])}"})
    func_calls.append({"role": "user", "content": f'[SYSTEM] You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_1": "value_1", ... }} }} </functioncall> [USER] {query_message}'})
    try:
        # The textual form of the whole response object is scanned for the
        # outermost {...} span, which is taken to be the function call.
        response = str(client_gemma.chat_completion(func_calls, max_tokens=200))
        start = response.find("{")
        end = response.rfind("}")
        if start == -1 or end < start:
            return ""
        payload = response[start:end + 1]
        # Undo the escaping introduced by the textual form, then drop any
        # remaining backslashes before parsing.
        payload = payload.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
        json_data = json.loads(payload)
        if json_data["name"] != "web_search":
            return ""
        web_results = perform_search(json_data["arguments"]["query"])
        return ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
    except Exception as e:
        # Web search is best-effort: log the failure and continue without it.
        print(e)
        return ""
# Function to handle responses
def chat_response(message, history, query, tokens, temperature, top_p, fpenalty):
    """Yield the streamed chat reply, preceded by an optional web search.

    A ``history`` of ``None`` is treated as an empty conversation. The web
    summary for ``query`` is computed first and passed on to ``chat_body``,
    whose output is yielded unchanged.
    """
    turns = [] if history is None else history
    web_summary = get_web_summary(turns, query)
    yield from chat_body(message, turns, query, tokens, temperature, top_p, fpenalty, web_summary)