|
import os |
|
from bs4 import BeautifulSoup |
|
from llama_index.core import Document |
|
from llama_index.core import Settings |
|
from llama_index.core import SimpleDirectoryReader |
|
from llama_index.core import StorageContext |
|
from llama_index.core import VectorStoreIndex |
|
from llama_index.readers.web import SimpleWebPageReader |
|
|
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
|
|
import chromadb |
|
import re |
|
from llama_index.llms.gemini import Gemini |
|
from llama_index.embeddings.gemini import GeminiEmbedding |
|
|
|
from llama_index.core import PromptTemplate |
|
from llama_index.core.llms import ChatMessage |
|
import gradio as gr |
|
import uuid |
|
|
|
# Gemini credentials come from the environment; if API_KEY is unset this is
# None and the client calls below will fail at request time.
api_key = os.environ.get("API_KEY")



# LLM used for answer generation and the embedding model used for indexing.
llm = Gemini(api_key=api_key, model_name="models/gemini-1.5-flash-latest")

gemini_embedding_model = GeminiEmbedding(api_key=api_key, model_name="models/embedding-001")



# Register both as llama-index global defaults so VectorStoreIndex and the
# query engines pick them up implicitly (no per-call wiring needed).
Settings.llm = llm

Settings.embed_model = gemini_embedding_model
|
|
|
|
|
|
|
def extract_web(url):
    """Download a web page and return its paragraph text as Documents.

    Args:
        url: The page URL to fetch.

    Returns:
        (documents, option): a one-element list containing a Document with
        the concatenated <p> text, and the literal source tag "web".
    """
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text

    # Keep only paragraph text; navigation, scripts etc. are dropped.
    soup = BeautifulSoup(html_content, 'html.parser')
    # find_all is the current bs4 spelling; findAll is the deprecated alias.
    # Join once instead of quadratic += concatenation in a loop.
    text_content = "".join(p.text + "\n" for p in soup.find_all('p'))

    documents = [Document(text=text_content)]
    option = "web"
    return documents, option
|
|
|
def extract_doc(path):
    """Load local files into llama-index Documents.

    Args:
        path: list of file paths to read.

    Returns:
        (documents, option): the parsed documents and the literal tag "doc".
    """
    loaded = SimpleDirectoryReader(input_files=path).load_data()
    return loaded, "doc"
|
|
|
|
|
def create_col(documents):
    """Embed *documents* into a fresh persistent Chroma collection.

    Args:
        documents: llama-index Document objects to index.

    Returns:
        The on-disk path of the new Chroma database so callers can reopen it.
    """
    # BUG FIX: uuid.UUID objects are not subscriptable — the original
    # `uuid.uuid4()[:4]` raised TypeError. Convert to str before slicing.
    db_path = f'database/{str(uuid.uuid4())[:4]}'
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")

    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Embeddings are produced via the global Settings.embed_model; the data is
    # persisted by Chroma, so the returned index object is not needed here.
    VectorStoreIndex.from_documents(
        documents, storage_context=storage_context
    )
    return db_path
|
|
|
def infer(message: str, history: list):
    """Gradio ChatInterface callback: ingest the source, then answer queries.

    Args:
        message: multimodal payload dict with "text" and "files" keys.
        history: prior (user, assistant) turns; uploaded files appear as
            tuple prompts, plain questions/URLs as strings.

    Returns:
        A status string on the first web ingestion, otherwise the query
        engine's response for the user's question.

    Raises:
        gr.Error: when neither a file nor a URL has been supplied yet.
    """
    print(f'message: {message}')
    print(f'history: {history}')
    # NOTE(review): collected but not consumed downstream yet — kept for a
    # future chat-memory feature; confirm before removing.
    messages = []
    files_list = list(message["files"])
    last_url = None  # most recent URL seen in history, for follow-up questions

    for prompt, answer in history:
        # BUG FIX: the original `prompt is tuple` compared identity against
        # the type object and was always False; isinstance is correct.
        if isinstance(prompt, tuple):
            # BUG FIX: `files_list += prompt[0]` extended the list with the
            # individual characters of the path string; append the path.
            files_list.append(prompt[0])
        else:
            if isinstance(prompt, str) and prompt.startswith(("http://", "https://")):
                last_url = prompt  # so follow-ups can re-ingest the same page
            messages.append(ChatMessage(role="user", content=prompt))
            messages.append(ChatMessage(role="assistant", content=answer))

    if files_list:
        documents, option = extract_doc(files_list)
    elif message["text"].startswith(("http://", "https://")):
        documents, option = extract_web(message["text"])
    elif last_url is not None:
        # Follow-up question in a web session: re-use the URL given earlier.
        documents, option = extract_web(last_url)
    else:
        # BUG FIX: the original constructed gr.Error without raising it and
        # then crashed with NameError on the unbound `documents`.
        raise gr.Error("Please input an url or upload file at first.")

    print(documents)
    # A fresh collection per call keeps turns independent (at the cost of
    # re-embedding the source on every message).
    db_path = create_col(documents)

    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(
        vector_store
    )

    template = (
        """ You are an assistant for question-answering tasks.
Use the following context to answer the question.
If you don't know the answer, just say that you don't know.
Use five sentences maximum and keep the answer concise.\n
Question: {query_str} \nContext: {context_str} \nAnswer:"""
    )
    llm_prompt = PromptTemplate(template)
    print(llm_prompt)

    if option == "web" and len(history) == 0:
        # First web ingestion: acknowledge instead of querying immediately.
        response = "Get the web data! You can ask it."
    else:
        question = message['text']
        query_engine = index.as_query_engine(text_qa_template=llm_prompt)
        response = query_engine.query(question)

    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared chatbot widget the ChatInterface renders messages into.
chatbot = gr.Chatbot()



with gr.Blocks(theme="soft") as demo:

    gr.ChatInterface(

        fn = infer,

        title = "RAG demo",

        multimodal = True,  # enables file uploads alongside text input

        chatbot=chatbot,

    )



if __name__ == "__main__":

    # Local-only demo: queue without a public API, no sharing link.
    demo.queue(api_open=False).launch(show_api=False, share=False)