from huggingface_hub import InferenceClient
import json
from bs4 import BeautifulSoup
import requests
import gradio as gr
from model import llm_models, llm_serverless_models
from prompt import llm_system_prompt

llm_clients = {}
client_main = None
current_model = None
language_codes = {"English": "en", "Japanese": "ja", "Chinese": "zh"}
llm_languages = ["language same as user input"] + list(language_codes.keys())
llm_output_language = "language same as user input"
llm_sysprompt_mode = "Default"
server_timeout = 300
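# Function to build the system prompt for the current mode and output language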
def get_llm_sysprompt():
    import re
    prompt = re.sub('<LANGUAGE>', llm_output_language, llm_system_prompt.get(llm_sysprompt_mode, ""))
    return prompt
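# Function to list the available system prompt modes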
def get_llm_sysprompt_mode():
    return list(llm_system_prompt.keys())
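# Function to select a system prompt mode, falling back to "Default"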
def set_llm_sysprompt_mode(key: str):
    global llm_sysprompt_mode
    if key not in llm_system_prompt.keys():
        llm_sysprompt_mode = "Default"
    else:
        llm_sysprompt_mode = key
    return gr.update(value=get_llm_sysprompt())
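# Functions to get and set the output language for LLM responses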
def get_llm_language():
    return llm_languages

def set_llm_language(lang: str):
    global llm_output_language
    llm_output_language = lang
    return gr.update(value=get_llm_sysprompt())
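# Function to format a Markdown link to the model repository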
def get_llm_model_info(model_name):
    return f'Repo: [{model_name}](https://huggingface.co/{model_name})'

# Function to extract text from a webpage
def get_text_from_html(html_content):
    soup = BeautifulSoup(html_content, 'html.parser')
    for tag in soup(["script", "style", "header", "footer"]):
        tag.extract()
    return soup.get_text(strip=True)
# Function to resolve the language code used for web search and output
def get_language_code(s):
    from langdetect import detect
    lang = "en"
    if llm_output_language == "language same as user input":
        lang = detect(s)
    elif llm_output_language in language_codes.keys():
        lang = language_codes[llm_output_language]
    return lang
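# Function to perform a web search and fetch visible text from the top results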
def perform_search(query):
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    search_term = query
    lang = get_language_code(search_term)
    all_results = []
    max_chars_per_page = 8000
    with requests.Session() as session:
        response = session.get(
            url="https://www.google.com/search",
            headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.0.0"},
            params={"q": search_term, "num": 3, "udm": 14, "hl": f"{lang}", "lr": f"lang_{lang}", "safe": "off", "pws": 0},
            timeout=5,
            verify=False,
        )
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        result_block = soup.find_all("div", attrs={"class": "g"})
        for result in result_block:
            anchor = result.find("a", href=True)
            if anchor is None:  # skip results without a link instead of raising TypeError
                continue
            link = anchor["href"]
            try:
                webpage_response = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.0.0"}, timeout=5, verify=False)
                webpage_response.raise_for_status()
                visible_text = get_text_from_html(webpage_response.text)
                if len(visible_text) > max_chars_per_page:
                    visible_text = visible_text[:max_chars_per_page]
                all_results.append({"link": link, "text": visible_text})
            except requests.exceptions.RequestException:
                all_results.append({"link": link, "text": None})
    return all_results
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
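# Function to wrap a Hub model in a gr.Interface backed by an InferenceClient (adapted from gradio.external)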
def load_from_model(model_name: str, hf_token: str = None):
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError
    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    print(f"Fetching model from: {model_url}")
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {}
    response = httpx.request("GET", api_url, headers=headers)
    if response.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )
    headers["X-Wait-For-Model"] = "true"
    client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
                                             token=hf_token, timeout=server_timeout)
    inputs = [
        gr.components.Textbox(render=False),
        gr.components.State(render=False),
    ]
    outputs = [
        gr.components.Chatbot(render=False),
        gr.components.State(render=False),
    ]
    fn = client.chat_completion
    def query_huggingface_inference_endpoints(*data, **kwargs):
        return fn(*data, **kwargs)
    interface_info = {
        "fn": query_huggingface_inference_endpoints,
        "inputs": inputs,
        "outputs": outputs,
        "title": model_name,
    }
    return gr.Interface(**interface_info)
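# Function to query the serving status of a model on the serverless inference API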
def get_status(model_name: str):
    client = InferenceClient(timeout=10)
    return client.get_model_status(model_name)
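# Function to pre-load inference clients for all configured models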
def load_clients():
    global llm_clients
    for model in llm_serverless_models:
        status = get_status(model)
        #print(f"HF model status: {status}")
        if status is None or status.state not in ["Loadable", "Loaded"]:
            print(f"Failed to load by serverless inference API: {model}. Model state is {status.state if status else 'None'}")
            continue
        try:
            print(f"Fetching model by serverless inference API: {model}")
            llm_clients[model] = InferenceClient(model)
        except Exception as e:
            print(e)
            print(f"Failed to load by serverless inference API: {model}")
            continue
        print(f"Loaded by serverless inference API: {model}")
    for model in llm_models:
        if model in llm_clients.keys(): continue
        status = get_status(model)
        #print(f"HF model status: {status}")
        if status is None or status.state not in ["Loadable", "Loaded"]:
            print(f"Failed to load: {model}. Model state is {status.state if status else 'None'}")
            continue
        try:
            llm_clients[model] = load_from_model(model)
        except Exception as e:
            print(e)
            print(f"Failed to load: {model}")
            continue
        print(f"Loaded: {model}")
def add_client(model_name: str):
    global llm_clients
    try:
        status = get_status(model_name)
        #print(f"HF model status: {status}")
        if status is None or status.state not in ["Loadable", "Loaded"]:
            print(f"Failed to load: {model_name}. Model state is {status.state if status else 'None'}")
            new_client = None
        else:
            new_client = InferenceClient(model_name)
    except Exception as e:
        print(e)
        new_client = None
    if new_client:
        print(f"Loaded by serverless inference API: {model_name}")
        llm_clients[model_name] = new_client
        return new_client
    else:
        print(f"Failed to load: {model_name}")
        return llm_clients.get(llm_serverless_models[0], None)
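# Function to switch the active model used for chat completion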
def set_llm_model(model_name: str = llm_serverless_models[0]):
    global client_main
    global current_model
    if model_name in llm_clients.keys():
        client_main = llm_clients.get(model_name, None)
    else:
        client_main = add_client(model_name)
    if client_main is not None:
        current_model = model_name
        print(f"Model selected: {model_name}")
        print(f"HF model status: {get_status(model_name)}")
        return model_name, get_llm_model_info(model_name)
    else:
        return None, "None"
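# Function to list the models that currently have a client loaded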
def get_llm_model():
    return list(llm_clients.keys())
# Initialize inference clients
load_clients()
set_llm_model()
# Client used for the function-calling step that decides whether to run a web search
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
# https://huggingface.co/docs/huggingface_hub/v0.24.5/en/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion
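# Function to stream a chat completion, optionally grounding the reply in web search results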
def chat_body(message, history, query, tokens, temperature, top_p, fpenalty, web_summary):
    system_prompt = get_llm_sysprompt()
    if query and web_summary:
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        for msg in history:
            messages.append({"role": "user", "content": str(msg[0])})
            messages.append({"role": "assistant", "content": str(msg[1])})
        messages.append({"role": "user", "content": f"{message}\nweb_result\n{web_summary}"})
        messages.append({"role": "assistant", "content": ""})
        try:
            if isinstance(client_main, gr.Interface):
                stream = client_main.fn(messages=messages, max_tokens=tokens, temperature=temperature,
                                        top_p=top_p, frequency_penalty=fpenalty, stream=True)
            else:
                stream = client_main.chat_completion(messages=messages, max_tokens=tokens, temperature=temperature,
                                                     top_p=top_p, stream=True)
        except Exception as e:
            print(e)
            stream = []
        output = ""
        for response in stream:
            if response and response.choices and response.choices[0].delta.content is not None:
                output += response.choices[0].delta.content
                yield [(output, None)]
    else:
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        for msg in history:
            messages.append({"role": "user", "content": str(msg[0])})
            messages.append({"role": "assistant", "content": str(msg[1])})
        messages.append({"role": "user", "content": message})
        messages.append({"role": "assistant", "content": ""})
        try:
            if isinstance(client_main, gr.Interface):
                stream = client_main.fn(messages=messages, max_tokens=tokens, temperature=temperature,
                                        top_p=top_p, stream=True)
            else:
                stream = client_main.chat_completion(messages=messages, max_tokens=tokens, temperature=temperature,
                                                     top_p=top_p, stream=True)
        except Exception as e:
            print(e)
            stream = []
        output = ""
        for response in stream:
            if response and response.choices and response.choices[0].delta.content is not None:
                output += response.choices[0].delta.content
                yield [(output, None)]
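# Function to decide, via tool calling, whether a web search is needed and return a summary of the results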
def get_web_summary(history, query_message):
    if not query_message: return ""
    func_calls = []
    functions_metadata = [
        {"type": "function", "function": {"name": "web_search", "description": "Search query on Google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Web search query"}}, "required": ["query"]}}},
    ]
    for msg in history:
        func_calls.append({"role": "user", "content": f"{str(msg[0])}"})
        func_calls.append({"role": "assistant", "content": f"{str(msg[1])}"})
    func_calls.append({"role": "user", "content": f'[SYSTEM] You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_2": "value_2", ... }} }} </functioncall> [USER] {query_message}'})
    response = client_gemma.chat_completion(func_calls, max_tokens=200)
    response = str(response)
    # Extract the JSON object between the first "{" and the last "}"
    try:
        response = response[response.find("{"):response.rindex("}") + 1]
    except ValueError:
        response = response[response.find("{"):response.rfind("}") + 1]
    response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
    #print(f"\n{response}")
    try:
        json_data = json.loads(str(response))
        if json_data["name"] == "web_search":
            query = json_data["arguments"]["query"]
            #gr.Info("Searching Web")
            web_results = perform_search(query)
            #gr.Info("Extracting relevant Info")
            web_summary = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
            return web_summary
        else:
            return ""
    except Exception:
        return ""
# Function to handle responses
def chat_response(message, history, query, tokens, temperature, top_p, fpenalty):
    if history is None: history = []
    yield from chat_body(message, history, query, tokens, temperature, top_p, fpenalty, get_web_summary(history, query))
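
# A minimal usage sketch, not part of the original app wiring: it assumes this module is
# normally imported by a separate Gradio UI file, and the component labels and slider
# ranges below are illustrative assumptions rather than the app's real layout.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()  # displays the [(output, None)] pairs yielded by chat_response
        msg = gr.Textbox(label="Message")
        query = gr.Textbox(label="Web search query (leave empty to skip search)")
        tokens = gr.Slider(64, 4096, value=1024, step=1, label="Max tokens")
        temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
        fpenalty = gr.Slider(0.0, 2.0, value=0.0, label="Frequency penalty")
        # Stream chat_response into the chatbot; inputs match its signature
        msg.submit(chat_response, [msg, chatbot, query, tokens, temperature, top_p, fpenalty], [chatbot])
    demo.launch()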