import os
import time

# loaders
from langchain.document_loaders import PyPDFLoader, OnlinePDFLoader, Docx2txtLoader, UnstructuredPowerPointLoader
# splitters
from langchain.text_splitter import RecursiveCharacterTextSplitter
# embeddings
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
# vector stores
from langchain.vectorstores import Chroma
# models
from langchain import HuggingFaceHub
from langchain.chat_models import ChatOpenAI
# retrievers
from langchain.chains import RetrievalQA
import gradio as gr

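# The fallback open-source model needs a Hugging Face Hub token in the environment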
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

def build_context(openai_key, files, urls):
  # Use OpenAI embeddings when an API key is provided; otherwise fall back to a local CPU model
  if openai_key != "":
    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
  else:
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"}
    )
  documents = []
  if files is not None:
    print("Loading uploaded files")
    for file in files:
      # Choose a loader based on the file extension
      if file.name.endswith('.pdf'):
        loader = PyPDFLoader(file.name)
        documents.extend(loader.load())
      elif file.name.endswith('.docx'):
        loader = Docx2txtLoader(file.name)
        documents.extend(loader.load())
      elif file.name.endswith('.ppt') or file.name.endswith('.pptx'):
        loader = UnstructuredPowerPointLoader(file.name)
        documents.extend(loader.load())
  if urls != "":
    print("Loading online PDFs")
    list_urls = urls.split(sep=",")
    for url in list_urls:
      loader = OnlinePDFLoader(url.strip())
      documents.extend(loader.load())
  # Split the documents into 800-character chunks for retrieval
  text_splitter = RecursiveCharacterTextSplitter(
      chunk_size=800, chunk_overlap=0, length_function=len,
      separators=["\n\n", "\n", " ", ""]
  )
  chunked_documents = text_splitter.split_documents(documents)
  # Keep the vector store global so llm_response and clear_chromadb can reuse it
  global vectordb
  vectordb = Chroma.from_documents(
      documents=chunked_documents,
      embedding=embeddings,
  )
  return "loaded"

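# Answer a chat message with a RetrievalQA chain over the shared vector store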
def llm_response(openai_key, message, chat_history):
  if openai_key != "":
    # gpt-3.5-turbo is a chat model, so it needs the ChatOpenAI wrapper rather than OpenAI
    llm = ChatOpenAI(
        temperature=0, openai_api_key=openai_key, model_name="gpt-3.5-turbo", verbose=False
    )
  else:
    # Fallback: a small open-source instruction-tuned model served via the Hugging Face Hub
    llm = HuggingFaceHub(repo_id='MBZUAI/LaMini-Flan-T5-248M',
                         huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
                         model_kwargs={"max_length": 512, "do_sample": True,
                                       "temperature": 0.2})
  # "stuff" packs the top-k retrieved chunks directly into the prompt context
  qa_chain = RetrievalQA.from_chain_type(llm=llm,
                                         chain_type="stuff",
                                         retriever=vectordb.as_retriever(search_kwargs={"k": 10}),
                                         return_source_documents=False,
                                         verbose=True)
  result = qa_chain(message)["result"]
  chat_history.append((message, result))
  time.sleep(2)
  return "", chat_history
  
def loading():
    return "Loading..."

def clear_chromadb():
  # Delete every stored embedding in one call so a fresh set of documents can be loaded
  ids = vectordb.get()["ids"]
  if ids:
    vectordb._collection.delete(ids=ids)

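# Gradio UI: document-loading controls on top, chat interface below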
with gr.Blocks(theme=gr.themes.Soft()) as demo:
  with gr.Row():
    openai_key = gr.Textbox(label="Enter your OpenAI API key to use the gpt-3.5-turbo model. If left empty, the open-source LaMini-Flan-T5-248M model is used")
  with gr.Row():
      pdf_docs = gr.Files(label="Load pdf, docx or pptx files", file_types=['.pdf','.docx','.ppt','.pptx'], type="file")
      urls = gr.Textbox(label="Enter one or multiple online pdf urls (comma separated if multiple)")
  with gr.Row():
    load_docs = gr.Button("Load documents and urls", variant="primary", scale=1)
    loading_status = gr.Textbox(label="Loading status", placeholder="", interactive=False, scale=0)
  with gr.Row():
    with gr.Column(scale=1):
      msg = gr.Textbox(label="User message")
      chatbot = gr.Chatbot()
  with gr.Row():
      clearchat = gr.ClearButton([msg, chatbot], value="New chat")
      cleardb = gr.Button(value="Reset context (for loading new documents)", variant="secondary")
  # Show "Loading..." right away, then replace it with the result of build_context
  load_docs.click(loading, None, loading_status, queue=False)
  load_docs.click(build_context, inputs=[openai_key, pdf_docs, urls], outputs=[loading_status], queue=False)
  msg.submit(llm_response, [openai_key, msg, chatbot], [msg, chatbot])
  cleardb.click(clear_chromadb)
    
demo.queue(concurrency_count=3)
demo.launch()
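
# Note: this app targets the pre-0.1 langchain API (langchain.document_loaders,
# RetrievalQA, HuggingFaceHub) and Gradio 3.x (type="file", concurrency_count);
# newer releases of both libraries have moved or renamed these entry points.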