import gradio as gr
import os
import nest_asyncio
import re
from pathlib import Path
import typing as t
import base64
from mimetypes import guess_type

from llama_parse import LlamaParse
from llama_index.core.schema import TextNode, ImageNode, NodeWithScore, MetadataMode
from llama_index.core import (
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
    Settings,
)
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.core.prompts import PromptTemplate
from llama_index.core.base.response.schema import Response
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from typing import Any, List, Optional

nest_asyncio.apply()

# Setting API keys
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["LLAMA_CLOUD_API_KEY"] = os.getenv("LLAMA_CLOUD_API_KEY")

# Initialize the parser
parser = LlamaParse(
    result_type="markdown",
    parsing_instruction="You are given a medical textbook on medicine",
    use_vendor_multimodal_model=True,
    vendor_multimodal_model_name="gpt-4o-mini-2024-07-18",
    show_progress=True,
    verbose=True,
    invalidate_cache=True,
    do_not_cache=True,
    num_workers=8,
    language="en",
)


# Encode a local image file as a base64 data URL
def local_image_to_data_url(image_path):
    mime_type, _ = guess_type(image_path)
    if mime_type is None:
        mime_type = "image/png"
    with open(image_path, "rb") as image_file:
        base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{base64_encoded_data}"


# Extract the page number from an image file name such as "...-page-3.jpg"
def get_page_number(file_name):
    match = re.search(r"-page-(\d+)\.jpg$", str(file_name))
    if match:
        return int(match.group(1))
    return 0


# Return the page image files sorted by page number
def _get_sorted_image_files(image_dir):
    raw_files = [f for f in Path(image_dir).iterdir() if f.is_file()]
    sorted_files = sorted(raw_files, key=get_page_number)
    return sorted_files


# Build one TextNode per parsed page, attaching the page image path as metadata
def get_text_nodes(md_json_objs, image_dir) -> t.List[TextNode]:
    nodes = []
    for result in md_json_objs:
        json_dicts = result["pages"]
        document_name = result["file_path"].split("/")[-1]
        docs = [doc["md"] for doc in json_dicts]
        image_files = _get_sorted_image_files(image_dir)
        for idx, doc in enumerate(docs):
            node = TextNode(
                text=doc,
                metadata={
                    "image_path": str(image_files[idx]),
                    "page_num": idx + 1,
                    "document_name": document_name,
                },
            )
            nodes.append(node)
    return nodes


# Gradio interface functions
def upload_and_process_file(uploaded_file):
    if uploaded_file is None:
        return "Please upload a medical textbook (pdf)"
    # Gradio hands us a temporary file; parse it directly from its path
    file_path = uploaded_file.name
    md_json_objs = parser.get_json_result([file_path])
    image_dicts = parser.get_images(md_json_objs, download_path="data_images")
    return md_json_objs


def ask_question(md_json_objs, query_text, uploaded_query_image=None):
    if not md_json_objs:
        return "No knowledge base loaded. Please upload a file first."
    text_nodes = get_text_nodes(md_json_objs, "data_images")

    # Set up the embedding model and LLM
    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    llm = OpenAI("gpt-4o-mini-2024-07-18")
    Settings.llm = llm
    Settings.embed_model = embed_model

    # Build the index on first use, then reuse the persisted copy
    if not os.path.exists("storage_manuals"):
        index = VectorStoreIndex(text_nodes, embed_model=embed_model)
        index.storage_context.persist(persist_dir="./storage_manuals")
    else:
        ctx = StorageContext.from_defaults(persist_dir="./storage_manuals")
        index = load_index_from_storage(ctx)
    retriever = index.as_retriever()

    # Encode the query image (if provided) as a data URL
    encoded_image_url = None
    if uploaded_query_image is not None:
        encoded_image_url = local_image_to_data_url(uploaded_query_image.name)

    # Set up the query engine
    QA_PROMPT_TMPL = """
You are a friendly medical chatbot designed to assist users by providing accurate and detailed responses to medical questions based on information from medical books.

### Context:
---------------------
{context_str}
---------------------

### Query Text:
{query_str}

### Query Image:
---------------------
{encoded_image_url}
---------------------

### Answer:
"""
    QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)
    gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")

    class MultimodalQueryEngine(CustomQueryEngine):
        qa_prompt: PromptTemplate
        retriever: BaseRetriever
        multi_modal_llm: OpenAIMultiModal
        node_postprocessors: Optional[List[BaseNodePostprocessor]]

        def __init__(
            self,
            qa_prompt: PromptTemplate,
            retriever: BaseRetriever,
            multi_modal_llm: OpenAIMultiModal,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        ):
            super().__init__(
                qa_prompt=qa_prompt,
                retriever=retriever,
                multi_modal_llm=multi_modal_llm,
                node_postprocessors=node_postprocessors or [],
            )

        def custom_query(self, query_str: str):
            # Retrieve the most relevant text nodes
            nodes = self.retriever.retrieve(query_str)

            # Create image nodes from the page images associated with those nodes
            image_nodes = [
                NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"]))
                for n in nodes
            ]

            # Create the context string from the parsed markdown text
            ctx_str = "\n\n".join(
                [r.node.get_content(metadata_mode=MetadataMode.LLM).strip() for r in nodes]
            )

            # Format the prompt for the LLM
            fmt_prompt = self.qa_prompt.format(
                context_str=ctx_str,
                query_str=query_str,
                encoded_image_url=encoded_image_url,
            )

            # Use the multimodal LLM to interpret the images and generate a response
            llm_response = self.multi_modal_llm.complete(
                prompt=fmt_prompt,
                image_documents=[image_node.node for image_node in image_nodes],
            )
            return Response(
                response=str(llm_response),
                source_nodes=nodes,
                metadata={"text_nodes": nodes, "image_nodes": image_nodes},
            )

    query_engine = MultimodalQueryEngine(QA_PROMPT, retriever, gpt_4o_mm)
    response = query_engine.custom_query(query_text)
    return response.response


# Shared knowledge base state and Gradio wrappers
md_json_objs = []


def upload_wrapper(uploaded_file):
    global md_json_objs
    md_json_objs = upload_and_process_file(uploaded_file)
    return "File successfully processed!"
# The query tab reads the globally stored knowledge base (populated by the upload
# tab), so it only needs the query text and an optional query image as inputs.
def ask_wrapper(query_text, uploaded_query_image=None):
    return ask_question(md_json_objs, query_text, uploaded_query_image)


iface = gr.Interface(
    fn=ask_wrapper,
    inputs=[
        gr.Textbox(label="Enter your query:"),
        gr.File(label="Upload a query image (if any):"),
    ],
    outputs="text",
    title="Medical Knowledge Base & Query System",
)

upload_iface = gr.Interface(
    fn=upload_wrapper,
    inputs=gr.File(label="Upload a medical textbook (pdf):"),
    outputs="text",
    title="Upload Knowledge Base",
)

app = gr.TabbedInterface([upload_iface, iface], ["Upload Knowledge Base", "Ask a Question"])
app.launch()