# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
import os
from queue import Queue
from typing import Any, Callable, Dict, List, Tuple

import gradio as gr
from anyio.from_thread import start_blocking_portal
from langchain.chains import LLMChain
from langchain.chat_models import PromptLayerChatOpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from langchain.vectorstores import Chroma

from util import SyncStreamingLLMCallbackHandler, CustomOpenAIEmbeddings

def I(x):
    "Identity function; returns its argument unchanged."
    return x
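
# PromptTemplate subclass that runs a per-variable preprocessing function on each
# input before the template is formatted, so structured objects (like search
# results) can be passed to the chain directly instead of pre-rendered strings.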
class PreprocessingPromptTemplate(PromptTemplate):
    # Maps input variable names to functions applied to that variable before formatting.
    arg_preprocessing: Dict[str, Callable[[Any], Any]] = {}

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs, preprocessing each one first.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        kwargs = self._preprocess_args(kwargs)
        return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    def _preprocess_args(self, args: dict) -> dict:
        # Apply the registered preprocessing function to each argument,
        # falling back to the identity function when none is registered.
        return {k: self.arg_preprocessing.get(k, I)(v) for k, v in args.items()}
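
# Render the (Document, score) pairs returned by the vector store as a numbered,
# '~~~'-separated block of titles and contents for interpolation into the prompt.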
def top_results_to_string(x: List[Tuple[Document, float]]) -> str:
    return "\n~~~\n".join(
        f"Result {i} Title: {doc.metadata['title']}\nResult {i} Content: {doc.page_content}"
        for i, (doc, score) in enumerate(x, 1)
    )

PROMPT = """You are a helpful AI assistant that summarizes search results for users.
---
A user has searched for the following query:
{query}
---
The search engine returned the following search results:
{top_results}
---
Based on the search results, answer the user's query in the same language as the query.
Say which search result you used.
Do not use information other than the search results.
Say 'No answer found.' if there are no relevant results.
Afterwards, say how confident you are in your answer as a percentage.
"""
PROMPT_TEMPLATE = PreprocessingPromptTemplate(template=PROMPT, input_variables=["query", "top_results"])
PROMPT_TEMPLATE.arg_preprocessing["top_results"] = top_results_to_string
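# e.g. PROMPT_TEMPLATE.format(query="...", top_results=[(doc, 0.12), ...]) will now
# run the (Document, score) pairs through top_results_to_string before filling
# {top_results}; {query} is passed through unchanged.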

# TODO give relevance value in prompt
# TODO ask gpt to say which sources it used
# TODO azure?
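
# Persistent Chroma vector store holding the pre-embedded "CUHK" document collection
# that incoming queries are matched against.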
COLLECTION = Chroma(
    embedding_function=CustomOpenAIEmbeddings(api_key=os.environ.get("OPENAI_API_KEY", None)),
    persist_directory="./.chroma",
    collection_name="CUHK",
)
# COLLECTION = CHROMA_CLIENT.get_collection(name='CUHK')
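
# Build the chat model and LLMChain for the selected backend ("OpenAI" or "Azure OpenAI").
# Both branches use a PromptLayer-wrapped ChatOpenAI with streaming enabled.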
def load_chain(api_type):
    shared_args = {
        "temperature": 0,
        "model_name": "gpt-3.5-turbo",
        "pl_tags": ["cuhk-demo"],
        "streaming": True,
    }
    if api_type == "OpenAI":
        chat = PromptLayerChatOpenAI(
            **shared_args,
            api_key=os.environ.get("OPENAI_API_KEY", None),
        )
    elif api_type == "Azure OpenAI":
        chat = PromptLayerChatOpenAI(
            api_type="azure",
            api_key=os.environ.get("AZURE_OPENAI_API_KEY", None),
            api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
            api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
            engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
            **shared_args,
        )
    else:
        raise ValueError(f"Unknown api_type: {api_type!r}")
    chain = LLMChain(llm=chat, prompt=PROMPT_TEMPLATE)
    return chat, chain

def initialize_chain(api_type):
    "Runs at app start."
    chat, chain = load_chain(api_type)
    return chat, chain

def change_chain(api_type, old_chain):
    "Runs when the user switches between the OpenAI and Azure OpenAI backends."
    chat, chain = load_chain(api_type)
    return chat, chain
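
# Retrieve the top-k most similar documents from the vector store and also format
# them as a Markdown list of linked titles with their distance scores.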
def find_top_results(query):
    # TODO filter by device (windows, mac, android, ios)
    results = COLLECTION.similarity_search_with_score(query, k=4)
    output = "\n".join(
        f"1. [{d.metadata['title']}]({d.metadata['url']}) <small>(dist: {s})</small>"
        for d, s in results
    )
    return results, output
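
# Run the chain in a background task (via an anyio blocking portal) while streaming
# tokens back through a queue; yields the accumulated answer so Gradio can update
# the response box incrementally.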
def ask_gpt(chain, query, top_results: List[Tuple[Document, float]]):
    q = Queue()
    job_done = object()

    def task():
        chain.run(
            query=query,
            top_results=top_results,
            callbacks=[SyncStreamingLLMCallbackHandler(q)],
        )
        q.put(job_done)

    with start_blocking_portal() as portal:
        portal.start_task_soon(task)
        content = ""
        while True:
            next_token = q.get(True, timeout=15)
            if next_token is job_done:
                break
            content += next_token
            yield content
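
# Gradio UI: a sidebar for choosing the backend, plus a main column with the search
# box, retrieved results, and streamed AI response, wired to the handlers above.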
demo = gr.Blocks(css="""
#sidebar {
    max-width: 300px;
}
""")
with demo:
    with gr.Row():
        # sidebar
        with gr.Column(elem_id="sidebar"):
            api_type = gr.Radio(
                ["OpenAI", "Azure OpenAI"],
                value="OpenAI",
                label="Server",
                info="You can try changing this if responses are slow.",
            )
        # main
        with gr.Column():
            # Company img
            gr.HTML(r'<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>')
            chat = gr.State()
            chain = gr.State()
            query = gr.Textbox(label="Search Query:")
            top_results_data = gr.State()
            top_results = gr.Markdown(label="Search Results")
            response = gr.Textbox(label="AI Response")

    load_event = demo.load(initialize_chain, [api_type], [chat, chain])
    query_event = query.submit(find_top_results, [query], [top_results_data, top_results])
    ask_event = query_event.then(ask_gpt, [chain, query, top_results_data], [response])
    api_type.change(
        change_chain,
        [api_type, chain],
        [chat, chain],
        cancels=[load_event, query_event, ask_event],
    )
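
# queue() is required so the generator output of ask_gpt can stream to the UI.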
demo.queue()

if __name__ == "__main__":
    demo.launch()