File size: 5,171 Bytes
2ad079a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a720a1d
 
 
 
2ad079a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a720a1d
 
2ad079a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
from urllib.parse import quote

import gradio as gr

from langchain_openai import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector
from langchain_openai import ChatOpenAI
from langchain.schema import HumanMessage
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core import SimpleDirectoryReader
from llama_index.core.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# Shared OpenAI clients: the embedding model backs the text vector store and
# the chat model is used both to rewrite questions and to answer them.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
chat_llm = ChatOpenAI(model="gpt-4-turbo", temperature=0.5)

# System instructions. The trailing backslashes join the source lines, so each
# prompt reaches the model as one paragraph (the QA prompt keeps a single line
# break before its "context:" section).
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \

context: {context}"""

# Both templates share one layout: system instructions, then the prior
# conversation ("chat_history"), then the newest user turn ("input").
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", qa_system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Answering chain: fills {context} in the QA prompt with retrieved documents
# and sends the result to the chat model.
question_answer_chain = create_stuff_documents_chain(chat_llm, qa_prompt)

def build_pg_connection(password, host):
    """Return the SQLAlchemy URL for the ``postgres`` database on *host*.

    The password is percent-encoded before it is placed in the URL, so
    characters that are URL delimiters (``@``, ``/``, ``:``, ``#`` ...) cannot
    corrupt the user/host split when the URL is parsed.

    Args:
        password: Plain-text password for the ``postgres`` user. ``None`` or
            an empty string produces a URL with an empty password.
        host: Host name or IP address of the PostgreSQL server.

    Returns:
        A ``postgresql+psycopg://`` connection string using port 5432.
    """
    # safe="" also encodes "/", which quote() leaves untouched by default.
    escaped_password = quote(password or "", safe="")
    return f"postgresql+psycopg://postgres:{escaped_password}@{host}:5432/postgres"


# Connection settings come from the environment so no secret lives in the
# source. Missing variables are not rejected here (the module still imports);
# the resulting URL simply will not connect.
pg_password = os.getenv("PG_PASSWORD")
aws_ec2_ip = os.getenv("AWS_EC2_IP")
pg_connection = build_pg_connection(pg_password, aws_ec2_ip)
# Image side of the app: a Qdrant client opened on the local "qdrant_db" path,
# one collection that holds the image vectors, and the storage context that
# hands that collection to the multi-modal index.
qd_client = qdrant_client.QdrantClient(path="qdrant_db")
image_store = QdrantVectorStore(
    collection_name="image_collection",
    client=qd_client,
)
storage_context = StorageContext.from_defaults(image_store=image_store)

# Multi-modal model used by the image query engine.
openai_mm_llm = OpenAIMultiModal(max_new_tokens=1500, model="gpt-4o")

def _history_to_messages(history):
    """Convert Gradio chat history into LangChain chat messages.

    Two history layouts are accepted: ``[user_text, assistant_text]`` pairs
    and ``{"role": ..., "content": ...}`` dictionaries. Entries whose content
    is empty or is not plain text are skipped, as are unknown roles.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            entries = [(turn.get("role"), turn.get("content"))]
        else:
            # Pair layout: first element is the user turn, second the reply.
            entries = zip(("user", "assistant"), turn)
        for role, content in entries:
            if not isinstance(content, str) or not content:
                continue
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
    return messages


def response(message, history, doc_label):
    """Answer a chat message using the text of the selected document.

    The conversation context is rebuilt from ``history`` on every call rather
    than kept in a module-level list, so separate chat sessions do not see
    each other's turns and clearing the chat really clears the context.

    Args:
        message: The latest user message.
        history: Earlier turns of this conversation as passed in by the chat
            interface (see ``_history_to_messages`` for accepted layouts).
        doc_label: Name of the PGVector collection to search; this is the
            value of the document dropdown.

    Returns:
        The ``"answer"`` entry of the retrieval chain's output.

    Raises:
        gr.Error: If no document has been selected.
    """
    if not doc_label:
        raise gr.Error("Please select a document before asking a question.")

    text_store = PGVector(collection_name=doc_label,
                          embeddings=embeddings,
                          connection=pg_connection)
    retriever = text_store.as_retriever()
    # The question is first rewritten into a standalone query using the chat
    # history, then that query drives retrieval.
    history_aware_retriever = create_history_aware_retriever(chat_llm,
                                                             retriever,
                                                             contextualize_q_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    result = rag_chain.invoke({
        "input": message,
        "chat_history": _history_to_messages(history),
    })
    return result["answer"]

# Image indexes already built in this process, keyed by document label.
_image_indexes = {}


def _get_image_index(doc_label):
    """Return the multi-modal index for *doc_label*, building it on first use."""
    index = _image_indexes.get(doc_label)
    if index is None:
        # Building reads every file in the folder and, presumably, embeds and
        # inserts each image into the shared vector store, so do it only once
        # per label instead of on every button click.
        doc_imgs = SimpleDirectoryReader(f"./{doc_label}").load_data()
        index = MultiModalVectorStoreIndex.from_documents(
            doc_imgs, storage_context=storage_context
        )
        _image_indexes[doc_label] = index
    return index


def img_retrieve(query, doc_label):
    """Return file paths of the images that best match *query*.

    Args:
        query: Free-text description of the images to look for.
        doc_label: Selected document; its images are read from the
            ``./<doc_label>`` directory.

    Returns:
        A list of image file paths (at most three are requested from the
        query engine); empty if the engine reports no image nodes.

    Raises:
        gr.Error: If no document is selected or the query is blank.
    """
    if not doc_label:
        raise gr.Error("Please select a document before retrieving images.")
    if not query or not query.strip():
        raise gr.Error("Please enter a query to retrieve images.")

    img_query_engine = _get_image_index(doc_label).as_query_engine(
        llm=openai_mm_llm,
        image_similarity_top_k=3,
    )
    response_mm = img_query_engine.query(query)
    # Tolerate a response without metadata or without the "image_nodes" entry.
    image_nodes = (response_mm.metadata or {}).get("image_nodes") or []
    return [node.metadata["file_path"] for node in image_nodes]

# Module-level list of chat messages; it lives for the whole process.
chat_history = []

# Page layout: a title row, a short instruction row, then a two-column row
# with the document chat on the left and the image search on the right.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        gr.Markdown("""
    # 🎨 Multi-modal RAG Chatbot
    """)

    with gr.Row():
        gr.Markdown("""Select document from the menu, and interact with the text and images in the document.
                    """)

    with gr.Row():
        # Left column: document picker plus the chat. The dropdown value is
        # forwarded to `response` as its extra input.
        with gr.Column(scale=2):
            doc_label = gr.Dropdown(
                choices=["LLaVA", "Interior"],
                label="Select a document:",
            )
            chatbot = gr.ChatInterface(
                fn=response,
                additional_inputs=[doc_label],
                fill_height=True,
            )

        # Right column: gallery (pre-filled with sample pictures) and the
        # controls that replace its content with retrieved images.
        with gr.Column(scale=1):
            sample_1 = "https://i.pinimg.com/originals/e3/44/d7/e344d7631cd515edd36cc6930deaedec.jpg"
            sample_2 = "https://live.staticflickr.com/5307/5765340890_e386f42a99_b.jpg"
            sample_3 = "https://blog.kakaocdn.net/dn/nqcUB/btrzYjTgjWl/jFFlIBrdkoKv4jbSyZbiEk/img.jpg"
            gallery = gr.Gallery(
                value=[
                    (sample_1, 'sample_1'),
                    (sample_2, 'sample_2'),
                    (sample_3, 'sample_3'),
                ],
                label="Retrieved images",
                show_label=True,
                preview=True,
                object_fit="contain",
            )
            query = gr.Textbox(label="Enter query")
            button = gr.Button(value="Retrieve images")
            button.click(
                fn=img_retrieve,
                inputs=[query, doc_label],
                outputs=gallery,
            )

# Start the web app and request a public share link.
demo.launch(share=True)