import json
import requests
import gradio as gr
from bs4 import BeautifulSoup
from huggingface_hub import InferenceClient
from model import llm_models, llm_serverless_models
from prompt import llm_system_prompt

llm_clients = {}
client_main = None
current_model = None
language_codes = {"English": "en", "Japanese": "ja", "Chinese": "zh"}
llm_languages = ["language same as user input"] + list(language_codes.keys())
llm_output_language = "language same as user input"
llm_sysprompt_mode = "Default"
server_timeout = 300
def get_llm_sysprompt():
    import re
    prompt = re.sub('<LANGUAGE>', llm_output_language, llm_system_prompt.get(llm_sysprompt_mode, ""))
    return prompt

def get_llm_sysprompt_mode():
    return list(llm_system_prompt.keys())

def set_llm_sysprompt_mode(key: str):
    global llm_sysprompt_mode
    if key not in llm_system_prompt.keys():
        llm_sysprompt_mode = "Default"
    else:
        llm_sysprompt_mode = key
    return gr.update(value=get_llm_sysprompt())

def get_llm_language():
    return llm_languages

def set_llm_language(lang: str):
    global llm_output_language
    llm_output_language = lang
    return gr.update(value=get_llm_sysprompt())

def get_llm_model_info(model_name):
    return f'Repo: [{model_name}](https://huggingface.co/{model_name})'
# Function to extract visible text from a webpage
def get_text_from_html(html_content):
    soup = BeautifulSoup(html_content, 'html.parser')
    for tag in soup(["script", "style", "header", "footer"]):
        tag.extract()
    return soup.get_text(strip=True)
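# Illustrative example (comment only, not executed); the exact output depends on the HTML passed in:
#   get_text_from_html("<html><body><script>track()</script><p>Hello, world.</p></body></html>")
#   -> "Hello, world."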
# Function to pick the language code used for the web search
def get_language_code(s):
    from langdetect import detect
    lang = "en"
    if llm_output_language == "language same as user input":
        lang = detect(s)
    elif llm_output_language in language_codes.keys():
        lang = language_codes[llm_output_language]
    return lang
# Function to perform a web search
def perform_search(query):
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    search_term = query
    lang = get_language_code(search_term)
    all_results = []
    max_chars_per_page = 8000
    user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.0.0"
    with requests.Session() as session:
        response = session.get(
            url="https://www.google.com/search",
            headers={"User-Agent": user_agent},
            params={"q": search_term, "num": 3, "udm": 14, "hl": f"{lang}", "lr": f"lang_{lang}", "safe": "off", "pws": 0},
            timeout=5,
            verify=False,
        )
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        result_block = soup.find_all("div", attrs={"class": "g"})
        for result in result_block:
            anchor = result.find("a", href=True)
            if anchor is None:
                continue
            link = anchor["href"]
            try:
                webpage_response = session.get(link, headers={"User-Agent": user_agent}, timeout=5, verify=False)
                webpage_response.raise_for_status()
                visible_text = get_text_from_html(webpage_response.text)
                if len(visible_text) > max_chars_per_page:
                    visible_text = visible_text[:max_chars_per_page]
                all_results.append({"link": link, "text": visible_text})
            except requests.exceptions.RequestException:
                all_results.append({"link": link, "text": None})
    return all_results
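# Illustrative shape of the return value (actual links and text depend on the live search results):
#   perform_search("gradio streaming chatbot")
#   -> [{"link": "https://...", "text": "page text..."},
#       {"link": "https://...", "text": None}]   # text is None when the page could not be fetched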
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
def load_from_model(model_name: str, hf_token: str = None):
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError
    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    print(f"Fetching model from: {model_url}")
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {}
    response = httpx.request("GET", api_url, headers=headers)
    if response.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token "
            f"(https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )
    headers["X-Wait-For-Model"] = "true"
    client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
                                             token=hf_token, timeout=server_timeout)
    inputs = [
        gr.components.Textbox(render=False),
        gr.components.State(render=False),
    ]
    outputs = [
        gr.components.Chatbot(render=False),
        gr.components.State(render=False),
    ]
    fn = client.chat_completion
    def query_huggingface_inference_endpoints(*data, **kwargs):
        return fn(*data, **kwargs)
    interface_info = {
        "fn": query_huggingface_inference_endpoints,
        "inputs": inputs,
        "outputs": outputs,
        "title": model_name,
    }
    return gr.Interface(**interface_info)
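# Illustrative usage (comment only, not executed). Wrapping the client in a gr.Interface mirrors what
# gr.load() does in gradio/external.py, and lets client_main duck-type between the two cases in
# chat_body() below. The model id here is hypothetical, not necessarily one listed in llm_models:
#   iface = load_from_model("meta-llama/Meta-Llama-3-8B-Instruct", hf_token="hf_...")
#   stream = iface.fn(messages=[{"role": "user", "content": "Hi"}], max_tokens=64, stream=True)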
def get_status(model_name: str):
    client = InferenceClient(timeout=10)
    return client.get_model_status(model_name)
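# get_model_status() returns a huggingface_hub ModelStatus dataclass; only its `state` field is
# checked below. Illustrative value (other fields are not used here):
#   ModelStatus(loaded=True, state="Loaded", compute_type="gpu", framework="text-generation-inference")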
def load_clients():
    global llm_clients
    for model in llm_serverless_models:
        status = get_status(model)
        #print(f"HF model status: {status}")
        if status is None or status.state not in ["Loadable", "Loaded"]:
            print(f"Failed to load by serverless inference API: {model}. Model state is {status.state if status else 'unknown'}")
            continue
        try:
            print(f"Fetching model by serverless inference API: {model}")
            llm_clients[model] = InferenceClient(model)
        except Exception as e:
            print(e)
            print(f"Failed to load by serverless inference API: {model}")
            continue
        print(f"Loaded by serverless inference API: {model}")
    for model in llm_models:
        if model in llm_clients.keys(): continue
        status = get_status(model)
        #print(f"HF model status: {status}")
        if status is None or status.state not in ["Loadable", "Loaded"]:
            print(f"Failed to load: {model}. Model state is {status.state if status else 'unknown'}")
            continue
        try:
            llm_clients[model] = load_from_model(model)
        except Exception as e:
            print(e)
            print(f"Failed to load: {model}")
            continue
        print(f"Loaded: {model}")
def add_client(model_name: str):
    global llm_clients
    try:
        status = get_status(model_name)
        #print(f"HF model status: {status}")
        if status is None or status.state not in ["Loadable", "Loaded"]:
            print(f"Failed to load: {model_name}. Model state is {status.state if status else 'unknown'}")
            new_client = None
        else:
            new_client = InferenceClient(model_name)
    except Exception as e:
        print(e)
        new_client = None
    if new_client:
        print(f"Loaded by serverless inference API: {model_name}")
        llm_clients[model_name] = new_client
        return new_client
    else:
        print(f"Failed to load: {model_name}")
        return llm_clients.get(llm_serverless_models[0], None)
def set_llm_model(model_name: str = llm_serverless_models[0]):
    global client_main
    global current_model
    if model_name in llm_clients.keys():
        client_main = llm_clients.get(model_name, None)
    else:
        client_main = add_client(model_name)
    if client_main is not None:
        current_model = model_name
        print(f"Model selected: {model_name}")
        print(f"HF model status: {get_status(model_name)}")
        return model_name, get_llm_model_info(model_name)
    else:
        return None, "None"

def get_llm_model():
    return list(llm_clients.keys())
# Initialize inference clients
load_clients()
set_llm_model()
# Despite the variable name, this client points at Mistral and is used for the function-calling step in get_web_summary() below.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
# https://huggingface.co/docs/huggingface_hub/v0.24.5/en/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion
def chat_body(message, history, query, tokens, temperature, top_p, fpenalty, web_summary):
    system_prompt = get_llm_sysprompt()
    if query and web_summary:
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        for msg in history:
            messages.append({"role": "user", "content": str(msg[0])})
            messages.append({"role": "assistant", "content": str(msg[1])})
        messages.append({"role": "user", "content": f"{message}\nweb_result\n{web_summary}"})
        messages.append({"role": "assistant", "content": ""})
        try:
            if isinstance(client_main, gr.Interface):
                stream = client_main.fn(messages=messages, max_tokens=tokens, temperature=temperature,
                                        top_p=top_p, frequency_penalty=fpenalty, stream=True)
            else:
                stream = client_main.chat_completion(messages=messages, max_tokens=tokens, temperature=temperature,
                                                     top_p=top_p, stream=True)
        except Exception as e:
            print(e)
            stream = []
        output = ""
        for response in stream:
            if response and response.choices and response.choices[0].delta.content is not None:
                output += response.choices[0].delta.content
                yield [(output, None)]
    else:
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        for msg in history:
            messages.append({"role": "user", "content": str(msg[0])})
            messages.append({"role": "assistant", "content": str(msg[1])})
        messages.append({"role": "user", "content": message})
        messages.append({"role": "assistant", "content": ""})
        try:
            if isinstance(client_main, gr.Interface):
                stream = client_main.fn(messages=messages, max_tokens=tokens, temperature=temperature,
                                        top_p=top_p, stream=True)
            else:
                stream = client_main.chat_completion(messages=messages, max_tokens=tokens, temperature=temperature,
                                                     top_p=top_p, stream=True)
        except Exception as e:
            print(e)
            stream = []
        output = ""
        for response in stream:
            if response and response.choices and response.choices[0].delta.content is not None:
                output += response.choices[0].delta.content
                yield [(output, None)]
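# chat_body() is a generator that yields the partially accumulated answer as a list of
# (assistant_text, None) tuples, which is the shape a gr.Chatbot expects for streaming updates.
# Illustrative use (comment only, not executed; sampling values are arbitrary):
#   for partial in chat_body("Hello", [], "", tokens=512, temperature=0.7, top_p=0.9,
#                            fpenalty=0.0, web_summary=""):
#       print(partial[-1][0])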
def get_web_summary(history, query_message):
    if not query_message: return ""
    func_calls = []
    functions_metadata = [
        {"type": "function", "function": {"name": "web_search", "description": "Search query on Google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Web search query"}}, "required": ["query"]}}},
    ]
    for msg in history:
        func_calls.append({"role": "user", "content": f"{str(msg[0])}"})
        func_calls.append({"role": "assistant", "content": f"{str(msg[1])}"})
    func_calls.append({"role": "user", "content": f'[SYSTEM] You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_2": "value_2", ... }} }} </functioncall> [USER] {query_message}'})
    response = client_gemma.chat_completion(func_calls, max_tokens=200)
    response = str(response)
    # Keep only the outermost JSON object from the model output.
    try:
        response = response[response.find("{"):response.rindex("}") + 1]
    except ValueError:
        response = response[response.find("{"):response.rfind("}") + 1]
    response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
    #print(f"\n{response}")
    try:
        json_data = json.loads(str(response))
        if json_data["name"] == "web_search":
            query = json_data["arguments"]["query"]
            #gr.Info("Searching Web")
            web_results = perform_search(query)
            #gr.Info("Extracting relevant Info")
            web_summary = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
            return web_summary
        else:
            return ""
    except (json.JSONDecodeError, KeyError):
        return ""
# Function to handle responses
def chat_response(message, history, query, tokens, temperature, top_p, fpenalty):
    if history is None: history = []
    yield from chat_body(message, history, query, tokens, temperature, top_p, fpenalty, get_web_summary(history, query))
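# A minimal sketch (commented out, assumptions only) of wiring chat_response() into a Gradio UI;
# component names and default values are illustrative, not the Space's actual layout, and note
# that chat_body() yields only the latest turn, so the Chatbot shows the current answer rather
# than an appended history here:
#   with gr.Blocks() as demo:
#       chatbot = gr.Chatbot()
#       msg = gr.Textbox(label="Message")
#       query = gr.Textbox(label="Web search query (leave empty to skip search)")
#       tokens = gr.Slider(1, 4096, value=512, label="Max tokens")
#       temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
#       top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
#       fpenalty = gr.Slider(0.0, 2.0, value=0.0, label="Frequency penalty")
#       msg.submit(chat_response, [msg, chatbot, query, tokens, temperature, top_p, fpenalty], [chatbot])
#   demo.launch()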