Hyeonseo committed on
Commit a54ee03
1 Parent(s): 3768e5d

Update app.py

Files changed (1)
  1. app.py +69 -21
app.py CHANGED
@@ -1,22 +1,70 @@
 
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- client = InferenceClient(model="http://127.0.0.1:8080")
-
- def inference(message, history):
-     partial_message = ""
-     for token in client.text_generation(message, max_new_tokens=20, stream=True):
-         partial_message += token
-         yield partial_message
-
- gr.ChatInterface(
-     inference,
-     chatbot=gr.Chatbot(height=300),
-     textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
-     description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
-     title="Gradio 🤝 TGI",
-     examples=["Are tomatoes vegetables?"],
-     retry_btn="Retry",
-     undo_btn="Undo",
-     clear_btn="Clear",
- ).queue().launch()
+ import os
  import gradio as gr
+ from langchain_community.llms import HuggingFaceTextGenInference
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ # Assuming you have the necessary setup for userdata
+ HF_TOKEN = os.environ['MY_HF_TOKEN']
+ ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"
+
+ # Setup for the document loader and retriever
+ loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
+ pages = loader.load_and_split()
+ disease_pages = pages[54:72]
+
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
+ splits = text_splitter.split_documents(disease_pages)
+
+ modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+ embeddings = HuggingFaceEmbeddings(model_name=modelPath, model_kwargs={'device':'cpu'}, encode_kwargs={'normalize_embeddings': False})
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
+
+ # Setup for the language model
+ llm = HuggingFaceTextGenInference(
+     inference_server_url=ENDPOINT_URL,
+     max_new_tokens=1024,
+     top_k=50,
+     temperature=0.1,
+     repetition_penalty=1.03,
+     server_kwargs={
+         "headers": {
+             "Authorization": f"Bearer {HF_TOKEN}",
+             "Content-Type": "application/json",
+         }
+     },
+ )
+
+ # Template for the question-answering
+ template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
+ {context}
+ Question: {question}
+ Helpful Answer:"""
+ QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
+
+ def predict(message, history):
+     question = message
+     context = ""  # Add context manually if needed; the RetrievalQA chain fills {context} from the retriever
+
+     # Create a RetrievalQA instance
+     chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=retriever,
+         return_source_documents=True,
+         chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
+     )
+
+     # Execute the query
+     result = chain({"query": question})
+
+     # Stream the response by yielding the answer text incrementally
+     partial_message = ""
+     for chunk in result['result']:
+         partial_message += chunk
+         yield partial_message
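
Note: the hunk ends inside predict, so the updated app.py shown in this diff defines the streaming generator but no longer shows the gr.ChatInterface(...).queue().launch() call that served the chatbot in the previous version. A minimal sketch of how predict could be wired back into a Gradio chat UI, assuming the same kind of ChatInterface setup as the removed code (the placeholder text, height, and examples below are illustrative, not part of this commit):

import gradio as gr

# Hook the streaming predict generator into a chat UI (illustrative, not from this commit)
gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
    title="Gradio 🤝 TGI",
    examples=["Are tomatoes vegetables?"],
).queue().launch()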