# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
import os
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, List, Tuple

import gradio as gr
from langchain.chains import LLMChain
from langchain.chat_models import PromptLayerChatOpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from langchain.vectorstores import Chroma

from util import SyncStreamingLLMCallbackHandler, CustomOpenAIEmbeddings
def I(x):
    "Identity function; returns its argument unchanged."
    return x
class PreprocessingPromptTemplate(PromptTemplate):
    """Prompt template that runs per-variable preprocessing before formatting."""

    # Maps an input-variable name to a function applied to its value before the
    # template is formatted; variables without an entry pass through unchanged.
    arg_preprocessing: Dict[str, Callable[[Any], Any]] = {}

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        kwargs = self._preprocess_args(kwargs)
        return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    def _preprocess_args(self, args: dict) -> dict:
        return {k: self.arg_preprocessing.get(k, I)(v) for k, v in args.items()}
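# Usage sketch (illustrative only, not part of the app): a preprocessor is
# looked up by input-variable name and applied before formatting, e.g.
#
#     template = PreprocessingPromptTemplate(
#         template="Hello, {name}!", input_variables=["name"]
#     )
#     template.arg_preprocessing["name"] = str.upper
#     template.format(name="world")  # -> "Hello, WORLD!"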
def top_results_to_string(x: List[Tuple[Document, float]]) -> str:
    return "\n~~~\n".join(
        f"Result {i} Title: {doc.metadata['title']}\nResult {i} Content: {doc.page_content}"
        for i, (doc, score) in enumerate(x, 1)
    )
| PROMPT = """You are a helpful AI assistant that summarizes search results for users. | |
| --- | |
| A user has searched for the following query: | |
| {query} | |
| --- | |
| The search engine returned the following 5 search results: | |
| {top_results} | |
| --- | |
| Based on the search results, answer the user's query, and use the same language as the user's query. | |
| Say which search result you used. | |
| Do not use information other than the search results. | |
| Say 'No answer found.' if there are no relevant results. | |
| Afterwards, say how confident you are in your answer as a percentage. | |
| """ | |
PROMPT_TEMPLATE = PreprocessingPromptTemplate(
    template=PROMPT, input_variables=["query", "top_results"]
)
PROMPT_TEMPLATE.arg_preprocessing["top_results"] = top_results_to_string
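# Illustrative only: with the preprocessor attached, format() accepts the raw
# (Document, score) pairs and stringifies them before substitution, e.g.
#
#     PROMPT_TEMPLATE.format(
#         query="How do I connect to the VPN?",
#         top_results=[(Document(page_content="...", metadata={"title": "..."}), 0.2)],
#     )
#
# renders the {top_results} slot as "Result 1 Title: ...\nResult 1 Content: ...".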
# TODO give relevance value in prompt
# TODO ask gpt to say which sources it used
# TODO azure?
COLLECTION = Chroma(
    embedding_function=CustomOpenAIEmbeddings(api_key=os.environ.get("OPENAI_API_KEY", None)),
    persist_directory="./.chroma",
    collection_name="CUHK",
)
# COLLECTION = CHROMA_CLIENT.get_collection(name='CUHK')
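# The ./.chroma directory is assumed to have been populated offline. A
# hypothetical indexing step (not run here) might look like:
#
#     docs = [Document(page_content="...", metadata={"title": "...", "url": "..."})]
#     Chroma.from_documents(
#         docs,
#         CustomOpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"]),
#         collection_name="CUHK",
#         persist_directory="./.chroma",
#     )
#
# Each document's metadata must carry 'title' and 'url', since both are read
# in find_top_results and top_results_to_string below.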
def load_chain(api_type):
    shared_args = {
        "temperature": 0,
        "model_name": "gpt-3.5-turbo",
        "pl_tags": ["cuhk-demo"],
        "streaming": True,
    }
    if api_type == "OpenAI":
        chat = PromptLayerChatOpenAI(
            **shared_args,
            api_key=os.environ.get("OPENAI_API_KEY", None),
        )
    elif api_type == "Azure OpenAI":
        chat = PromptLayerChatOpenAI(
            **shared_args,
            api_type="azure",
            api_key=os.environ.get("AZURE_OPENAI_API_KEY", None),
            api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
            api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
            engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
        )
    else:
        raise ValueError(f"Unknown api_type: {api_type!r}")
    chain = LLMChain(llm=chat, prompt=PROMPT_TEMPLATE)
    return chat, chain
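# Expected environment variables (inferred from the lookups above; the
# PromptLayer key is an assumption based on how the promptlayer client reads
# its configuration):
#   OpenAI:        OPENAI_API_KEY (plus, presumably, PROMPTLAYER_API_KEY for logging)
#   Azure OpenAI:  AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE,
#                  AZURE_OPENAI_API_VERSION, AZURE_OPENAI_DEPLOYMENT_NAME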
def initialize_chain(api_type):
    "Runs at app start."
    chat, chain = load_chain(api_type)
    return chat, chain

def change_chain(api_type, old_chain):
    "Rebuilds the chain when the user switches servers; old_chain is unused but kept for the Gradio event signature."
    chat, chain = load_chain(api_type)
    return chat, chain
def find_top_results(query):
    results = COLLECTION.similarity_search_with_score(query, k=4)  # TODO filter by device (windows, mac, android, ios)
    # "1." on every line is intentional: Markdown renumbers ordered-list items automatically.
    output = "\n".join(
        f"1. [{d.metadata['title']}]({d.metadata['url']}) <small>(dist: {s})</small>"
        for d, s in results
    )
    return results, output
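# Note: similarity_search_with_score returns (Document, distance) pairs, where
# a lower distance means a closer match under Chroma's default metric, which is
# why the raw value is surfaced as "dist:" rather than a similarity percentage.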
def ask_gpt(chain, query, top_results):  # top_results: List[Tuple[Document, float]]
    # Producer/consumer streaming: the chain runs in a background thread and
    # pushes tokens onto a queue; this generator drains the queue so Gradio can
    # render the answer incrementally.
    q = Queue()
    job_done = object()  # sentinel signalling that the chain has finished

    def task():
        chain.run(
            query=query,
            top_results=top_results,
            callbacks=[SyncStreamingLLMCallbackHandler(q)],
        )
        q.put(job_done)

    Thread(target=task, daemon=True).start()
    content = ""
    while True:
        next_token = q.get(block=True, timeout=15)
        if next_token is job_done:
            break
        content += next_token
        yield content
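# SyncStreamingLLMCallbackHandler comes from util; a minimal sketch of the
# assumed implementation (the real one may differ) is a callback that forwards
# each streamed token to the queue consumed above:
#
#     from langchain.callbacks.base import BaseCallbackHandler
#
#     class SyncStreamingLLMCallbackHandler(BaseCallbackHandler):
#         def __init__(self, q):
#             self.q = q
#
#         def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
#             self.q.put(token)  # hand each token to the consumer loop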
demo = gr.Blocks(css="""
#sidebar {
    max-width: 300px;
}
""")
with demo:
    with gr.Row():
        # sidebar
        with gr.Column(elem_id="sidebar"):
            api_type = gr.Radio(
                ["OpenAI", "Azure OpenAI"],
                value="OpenAI",
                label="Server",
                info="You can try changing this if responses are slow.",
            )
        # main
        with gr.Column():
            # Company img
            gr.HTML(r'<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>')
            chat = gr.State()
            chain = gr.State()
            query = gr.Textbox(label="Search Query:")
            top_results_data = gr.State()
            top_results = gr.Markdown(label="Search Results")
            response = gr.Textbox(label="AI Response")

    load_event = demo.load(initialize_chain, [api_type], [chat, chain])
    query_event = query.submit(find_top_results, [query], [top_results_data, top_results])
    ask_event = query_event.then(ask_gpt, [chain, query, top_results_data], [response])
    api_type.change(
        change_chain,
        [api_type, chain],
        [chat, chain],
        cancels=[load_event, query_event, ask_event],
    )
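# Event flow: on submit, find_top_results renders the result links first, then
# ask_gpt streams the LLM answer via .then(); changing the server rebuilds the
# chain and cancels any events still in flight.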
demo.queue()  # required so generator outputs (the streamed response) reach the UI

if __name__ == "__main__":
    demo.launch()