import os
import gradio as gr
from langchain_community.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# The Hugging Face API token is read from the environment; export MY_HF_TOKEN before launching.
HF_TOKEN = os.environ["MY_HF_TOKEN"]
ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"

# Load the source PDF ("2023 criteria for diagnosing and reporting notifiable
# infectious diseases") and keep only the disease-specific page range.
loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
pages = loader.load_and_split()
disease_pages = pages[54:72]

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
splits = text_splitter.split_documents(disease_pages)
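
# Optional sanity check (a debugging sketch; uncomment to run): confirm the
# splitter produced a reasonable number of chunks and peek at the first one.
# print(f"{len(splits)} chunks produced")
# print(splits[0].page_content[:200])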

modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = HuggingFaceEmbeddings(
    model_name=modelPath,
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
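
# Optional retrieval smoke test (a hedged sketch; the query below is a
# hypothetical Korean example and assumes matching content in the PDF):
# docs = retriever.get_relevant_documents("콜레라 진단 기준")  # "cholera diagnosis criteria"
# for d in docs:
#     print(d.metadata.get("page"), d.page_content[:100])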

# Setup for the language model
llm = HuggingFaceTextGenInference(
    inference_server_url=ENDPOINT_URL,
    max_new_tokens=1024,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
    server_kwargs={
        "headers": {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        }
    },
)
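
# Quick endpoint connectivity check (optional; requires a valid HF_TOKEN and
# network access to the inference endpoint):
# print(llm.invoke("Hello"))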

# Template for the question-answering
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
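
# The filled-in prompt can be previewed without calling the model:
# print(QA_CHAIN_PROMPT.format(context="<retrieved chunks>", question="<user question>"))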

def predict(message):
    question = message

    # Build the RetrievalQA chain; the retriever supplies the prompt's
    # {context} from the vector store, so no manual context is needed here.
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
    )

    # Execute the query
    result = chain.invoke({"query": question})

    # Replay the finished answer character by character so Gradio renders it
    # incrementally.
    partial_message = ""
    for chunk in result['result']:
        partial_message += chunk
        yield partial_message
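
# Note: the loop above only simulates streaming, since the full answer is
# available before the first yield. For token-level streaming one could
# construct HuggingFaceTextGenInference with streaming=True and a callback
# handler, at the cost of restructuring the chain invocation.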
        
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(placeholder="Chat with me!", label="Your Message"),
    outputs=gr.Text(label="Response"),
    live=False,
    title="Infectious-Disease-Diagnosis-Chatbot",
    description="Gradio UI demo that consumes a TGI endpoint serving the Llama-2-70B-Chat model.",
    # Example question (Korean): "I have fever and vomiting symptoms; which infectious disease could it be?"
    examples=[["발열과 구토 증상이 있는데, 어떤 감염병이야?"]],
    theme="default"  # You can choose a theme that fits your UI preference
)

iface.launch()
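
# When running in a notebook or behind a firewall, a temporary public link can
# be created instead: iface.launch(share=True)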