ztor2 committed on
Commit
2ad079a
1 Parent(s): 2699e77

Upload app_open.py

Files changed (1)
  1. app_open.py +112 -0
app_open.py ADDED
@@ -0,0 +1,112 @@
import gradio as gr
import os

from langchain_openai import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector
from langchain_openai import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core import SimpleDirectoryReader
from llama_index.core.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

os.environ["OPENAI_API_KEY"] = "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"

# Embedding model for text retrieval and chat model for answer generation.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
chat_llm = ChatOpenAI(temperature=0.5, model="gpt-4-turbo")

# Prompt that rewrites the latest user question into a standalone question,
# using the chat history only for context.
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Prompt that answers the standalone question from the retrieved context.
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \

context: {context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(chat_llm, qa_prompt)

# Text chunks are stored in PGVector (Postgres); images are indexed in a local Qdrant store.
# pg_connection = "postgresql+psycopg://postgres:3434@localhost:5433/mmrag"
pg_connection = "postgresql+psycopg://postgres:3434@43.201.34.191:5432/postgres"
qd_client = qdrant_client.QdrantClient(path="qdrant_db")
image_store = QdrantVectorStore(client=qd_client, collection_name="image_collection")
storage_context = StorageContext.from_defaults(image_store=image_store)
openai_mm_llm = OpenAIMultiModal(model="gpt-4o", max_new_tokens=1500)


def response(message, history, doc_label):
    # Build a history-aware RAG chain over the text collection selected in the dropdown.
    text_store = PGVector(collection_name=doc_label,
                          embeddings=embeddings,
                          connection=pg_connection)
    retriever = text_store.as_retriever()
    history_aware_retriever = create_history_aware_retriever(chat_llm,
                                                             retriever,
                                                             contextualize_q_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    result = rag_chain.invoke({"input": message, "chat_history": chat_history})
    chat_history.extend([HumanMessage(content=message), AIMessage(content=result["answer"])])

    return result["answer"]


def img_retrieve(query, doc_label):
    # Index the selected document's images and return the file paths of the closest matches.
    doc_imgs = SimpleDirectoryReader(f"./{doc_label}").load_data()
    index = MultiModalVectorStoreIndex.from_documents(doc_imgs,
                                                      storage_context=storage_context)
    img_query_engine = index.as_query_engine(llm=openai_mm_llm,
                                             image_similarity_top_k=3)
    response_mm = img_query_engine.query(query)
    retrieved_imgs = [n.metadata["file_path"] for n in response_mm.metadata["image_nodes"]]
    return retrieved_imgs


chat_history = []

# Gradio UI: chat interface on the left, image retrieval gallery on the right.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        gr.Markdown(
            """
            # 🎨 Multi-modal RAG Chatbot
            """)
    with gr.Row():
        gr.Markdown("""Select a document from the menu, and interact with the text and images in the document.
        """)
    with gr.Row():
        with gr.Column(scale=2):
            doc_label = gr.Dropdown(["LLaVA", "Interior"], label="Select a document:")
            chatbot = gr.ChatInterface(fn=response, additional_inputs=[doc_label], fill_height=True)
        with gr.Column(scale=1):
            sample_1 = "https://i.ytimg.com/vi/bLj_mR4Fnls/maxresdefault.jpg"
            sample_2 = "https://i.ytimg.com/vi/bOJdHU99OO8/maxresdefault.jpg"
            sample_3 = "https://blog.kakaocdn.net/dn/nqcUB/btrzYjTgjWl/jFFlIBrdkoKv4jbSyZbiEk/img.jpg"
            gallery = gr.Gallery(label="Retrieved images",
                                 show_label=True, preview=True,
                                 object_fit="contain",
                                 value=[(sample_1, 'sample_1'),
                                        (sample_2, 'sample_2'),
                                        (sample_3, 'sample_3')])
            query = gr.Textbox(label="Enter query")
            button = gr.Button(value="Retrieve images")
            button.click(img_retrieve, [query, doc_label], gallery)

demo.launch(share=True)
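
The app only reads from its stores: response() expects a PGVector collection named after each dropdown label ("LLaVA", "Interior") to already exist, and img_retrieve() expects a matching local image folder (./LLaVA, ./Interior) that it indexes on the fly. This commit does not include that ingestion step, so the following is a minimal, hypothetical sketch for the text side only. It reuses the same connection string, embedding model, and PGVector constructor pattern as app_open.py; the PDF filename and chunking parameters are placeholders, not values taken from this repository.

# ingest_text.py -- minimal sketch, not part of this commit.
# Assumes OPENAI_API_KEY is set in the environment, the same Postgres connection
# as app_open.py, and a local source PDF whose name ("llava.pdf") is a placeholder.
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector
from langchain_text_splitters import RecursiveCharacterTextSplitter

pg_connection = "postgresql+psycopg://postgres:3434@43.201.34.191:5432/postgres"
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# Load and chunk the source PDF, then write the chunks into the PGVector
# collection whose name matches the dropdown label in app_open.py.
docs = PyPDFLoader("llava.pdf").load()
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

text_store = PGVector(collection_name="LLaVA",
                      embeddings=embeddings,
                      connection=pg_connection)
text_store.add_documents(chunks)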