import os
import gradio as gr
from langchain_community.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
# The Hugging Face API token is read from the environment (e.g., a Space secret)
HF_TOKEN = os.environ['MY_HF_TOKEN']
ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"
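# Note: meta-llama/Llama-2-70b-chat-hf is a gated model, so the token above must belong
# to an account that has been granted access to it.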
# Setup for the document loader and retriever
loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
pages = loader.load_and_split()
disease_pages = pages[54:72]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
splits = text_splitter.split_documents(disease_pages)
modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = HuggingFaceEmbeddings(model_name=modelPath, model_kwargs={'device':'cpu'}, encode_kwargs={'normalize_embeddings': False})
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
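# Illustrative retrieval check (hypothetical query string): the retriever returns the
# k=4 chunks most similar to the question, e.g.
#   docs = retriever.get_relevant_documents("콜레라의 신고 기준은?")
#   print(len(docs))  # -> up to 4 matching chunks from the selected pages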
# Setup for the language model
llm = HuggingFaceTextGenInference(
    inference_server_url=ENDPOINT_URL,
    max_new_tokens=1024,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
    server_kwargs={
        "headers": {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        }
    },
)
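# The wrapper above sends text-generation requests to the TGI-compatible endpoint over HTTP;
# the low temperature (0.1) keeps answers close to the retrieved context.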
# Template for the question-answering
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
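# RetrievalQA fills {context} with the retrieved chunks and {question} with the user query;
# e.g. QA_CHAIN_PROMPT.format(context="...", question="...") yields the final prompt string.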
def predict(message):
    question = message
    # Build a RetrievalQA chain; "stuff" packs the retrieved chunks into the prompt's {context}.
    # Note: the chain is rebuilt on every call; it could also be constructed once at module level.
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
    )
    # Execute the query
    result = chain({"query": question})
    # Yield the answer character by character to simulate streaming in the Gradio UI
    partial_message = ""
    for chunk in result['result']:
        partial_message += chunk
        yield partial_message
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(placeholder="Chat with me!", label="Your Message"),
    outputs=gr.Text(label="Response"),
    live=False,
    title="Infectious-Disease-Diagnosis-Chatbot",
    description="This is the demo for a Gradio UI consuming a TGI endpoint with the Llama-2-70B-Chat model.",
    examples=[["발열과 구토 증상이 있는데, 어떤 감염병이야?"]],
    theme="default"  # You can choose a theme that fits your UI preference
)
iface.launch()