import os

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS

# Hugging Face API token, required by HuggingFaceEndpoint for model access
api_token = os.getenv("HF_TOKEN")

# Instruction-tuned chat models served through the Hugging Face Inference API
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat",
]
# Short display names for the UI, e.g. "Meta-Llama-3-8B-Instruct"
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

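# Note: the Meta-Llama models are gated on the Hugging Face Hub; the account
# behind HF_TOKEN must have accepted the model license for the endpoint calls
# to succeed.
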
def load_doc(list_file_path):
    """Load and split PDF documents into chunks"""
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    # Split into ~1024-character chunks with a 64-character overlap so that
    # context is preserved across chunk boundaries
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64,
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

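# Example (hypothetical file name): load_doc(["calibration_report.pdf"]) returns
# a list of Document chunks, each carrying the "source" and "page" metadata that
# PyPDFLoader attaches to every page.
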
def create_db(splits):
    """Create vector database from document splits"""
    # Uses the HuggingFaceEmbeddings default model (sentence-transformers/all-mpnet-base-v2)
    embeddings = HuggingFaceEmbeddings()
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb

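# The FAISS index above lives in memory only. Should persistence be needed, the
# langchain_community FAISS wrapper offers save_local()/load_local(), e.g.
# vectordb.save_local("faiss_index") -- shown for reference, not used here.
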
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Initialize the language model chain"""
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        task="text-generation",
    )

    # output_key='answer' is required because the chain returns source
    # documents alongside the answer
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

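# as_retriever() falls back to the vector store's default top-k (4 documents).
# A different context width could be requested with, for example,
# vector_db.as_retriever(search_kwargs={"k": 3}) -- an optional variation, not
# required by the chain above.
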
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Initialize the document database"""
    # Gradio's Files component supplies temp-file objects; .name is the on-disk path
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"

def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Initialize the Language Model"""
    # llm_option is an integer because the Radio component uses type="index"
    llm_name = list_llm[llm_option]
    print("Selected LLM model:", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Analysis Assistant initialized and ready!"

def format_chat_history(message, chat_history):
    """Format chat history for the model"""
    # Note: `message` is unused but kept so the call site in conversation() stays stable
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

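# Example: for chat_history [("Hi", "Hello!")], the function returns
# ["User: Hi", "Assistant: Hello!"].
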
def conversation(qa_chain, message, history):
    """Handle conversation and document analysis"""
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Strip any "Helpful Answer:" preamble the model may emit
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # The default retriever returns up to four documents; indexing three of them
    # assumes the corpus holds at least three chunks
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    # PyPDFLoader page metadata is zero-based; convert to one-based for display
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
    # Nine values, matching the outputs wired to msg.submit and submit_btn.click
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

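# gr.update(value="") clears the query textbox after each turn, while passing
# qa_chain back through gr.State keeps the same chain (and its conversation
# memory) alive across turns.
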
def demo():
    """Main demo application with enhanced layout"""
    theme = gr.themes.Default(
        primary_hue="indigo",
        secondary_hue="blue",
        neutral_hue="slate",
    )

    custom_css = """
    #app-header {
        text-align: center;
        padding: 2rem;
        background: linear-gradient(to right, #1a365d, #2c5282);
        color: white;
        margin-bottom: 2rem;
        border-radius: 0 0 1rem 1rem;
    }
    #app-header h1 {
        font-size: 2.5rem;
        margin-bottom: 0.5rem;
        color: white;
    }
    #app-header p {
        font-size: 1.2rem;
        opacity: 0.9;
    }
    .container {
        max-width: 1400px;
        margin: 0 auto;
        padding: 0 1rem;
    }
    .features-grid {
        display: grid;
        grid-template-columns: repeat(2, 1fr);
        gap: 1rem;
        margin-bottom: 2rem;
    }
    .feature-card {
        background: #f8fafc;
        padding: 1.5rem;
        border-radius: 0.5rem;
        border: 1px solid #e2e8f0;
    }
    .section-title {
        font-size: 1.5rem;
        color: #1a365d;
        margin-bottom: 1rem;
        padding-bottom: 0.5rem;
        border-bottom: 2px solid #e2e8f0;
    }
    .control-panel {
        background: #f8fafc;
        padding: 1.5rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
    }
    .chat-container {
        background: white;
        border-radius: 0.5rem;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
    }
    .reference-panel {
        background: #f8fafc;
        padding: 1rem;
        border-radius: 0.5rem;
        margin-top: 1rem;
    }
    """

    with gr.Blocks(theme=theme, css=custom_css) as demo:
        # Persistent per-session state shared across callbacks
        vector_db = gr.State()
        qa_chain = gr.State()

        with gr.Row(elem_id="app-header"):
            with gr.Column():
                gr.HTML(
                    """
                    <h1>MetroAssist AI</h1>
                    <p>Expert System for Metrology Report Analysis</p>
                    """
                )

        with gr.Row(equal_height=True):
            # Left column: document upload and model configuration
            with gr.Column(scale=1):
                with gr.Group(visible=True) as control_panel:
                    gr.Markdown("## Document Processing", elem_classes="section-title")

                    with gr.Group(elem_classes="control-panel"):
                        gr.Markdown("### 📄 Upload Documents")
                        document = gr.Files(
                            label="Metrology Reports (PDF)",
                            file_count="multiple",
                            file_types=[".pdf"],
                        )
                        db_btn = gr.Button("Process Documents", elem_classes="primary-btn")
                        db_progress = gr.Textbox(
                            value="Ready for documents",
                            label="Processing Status",
                        )

                    with gr.Group(elem_classes="control-panel"):
                        gr.Markdown("### 🤖 Model Configuration")
                        llm_btn = gr.Radio(
                            choices=list_llm_simple,
                            label="Select AI Model",
                            value=list_llm_simple[0],
                            type="index",  # callbacks receive an index into list_llm
                        )

                    with gr.Accordion("Advanced Settings", open=False):
                        slider_temperature = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=0.5,
                            step=0.1,
                            label="Analysis Precision",
                        )
                        slider_maxtokens = gr.Slider(
                            minimum=128,
                            maximum=8192,
                            value=4096,
                            step=128,
                            label="Response Length",
                        )
                        slider_topk = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=3,
                            step=1,
                            label="Analysis Diversity",
                        )

                    qachain_btn = gr.Button("Initialize Assistant")
                    llm_progress = gr.Textbox(
                        value="Not initialized",
                        label="Assistant Status",
                    )

            # Right column: chat interface and reference panel
            with gr.Column(scale=2):
                with gr.Group() as chat_interface:
                    gr.Markdown("## Interactive Analysis", elem_classes="section-title")

                    with gr.Row(equal_height=True) as feature_grid:
                        with gr.Column():
                            gr.Markdown(
                                """
                                ### 📊 Capabilities
                                - Calibration Analysis
                                - Standards Compliance
                                - Uncertainty Evaluation
                                """
                            )
                        with gr.Column():
                            gr.Markdown(
                                """
                                ### 💡 Best Practices
                                - Ask specific questions
                                - Include measurement context
                                - Specify standards
                                """
                            )

                    with gr.Group(elem_classes="chat-container"):
                        chatbot = gr.Chatbot(
                            height=400,
                            label="Analysis Conversation",
                        )
                        with gr.Row():
                            msg = gr.Textbox(
                                placeholder="Ask about your metrology report...",
                                label="Query",
                                scale=4,
                            )
                            submit_btn = gr.Button("Send")
                            clear_btn = gr.ClearButton([msg, chatbot], value="Clear")

with gr.Accordion("Document References", open=False, elem_classes="reference-panel"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
doc_source1 = gr.Textbox(label="Reference 1", lines=2) |
|
source1_page = gr.Number(label="Page") |
|
with gr.Column(): |
|
doc_source2 = gr.Textbox(label="Reference 2", lines=2) |
|
source2_page = gr.Number(label="Page") |
|
with gr.Column(): |
|
doc_source3 = gr.Textbox(label="Reference 3", lines=2) |
|
source3_page = gr.Number(label="Page") |
|
|
|
|
|
        with gr.Row():
            gr.Markdown(
                """
                ---
                ### About MetroAssist AI

                A specialized tool for metrology professionals, providing advanced analysis
                of calibration certificates, measurement data, and technical standards compliance.

                **Version 1.0** | © 2024 MetroAssist AI
                """
            )

        # Event wiring
        db_btn.click(
            initialize_database,
            inputs=[document],
            outputs=[vector_db, db_progress],
        )

        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress],
        ).then(
            # Clear the chat and reference panels whenever the assistant is re-initialized
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )

        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )

        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )

        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )

    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()