dkdaniz committed on
Commit
efd2848
•
1 Parent(s): 5375f78

Update frontend.py

Files changed (1)
  1. frontend.py +67 -118
frontend.py CHANGED
@@ -1,122 +1,71 @@
- import torch
- import subprocess
- import streamlit as st
- from run_localGPT import load_model
- from langchain.vectorstores import Chroma
- from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
- from langchain.embeddings import HuggingFaceInstructEmbeddings
- from langchain.chains import RetrievalQA
- from streamlit_extras.add_vertical_space import add_vertical_space
- from langchain.prompts import PromptTemplate
- from langchain.memory import ConversationBufferMemory
-
-
- def model_memory():
-     # Adding history to the model.
-     template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
- just say that you don't know, don't try to make up an answer.
-
-     {context}
-
-     {history}
-     Question: {question}
-     Helpful Answer:"""
-
-     prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
-     memory = ConversationBufferMemory(input_key="question", memory_key="history")
-
-     return prompt, memory
-
-
- # Sidebar contents
- with st.sidebar:
-     st.title("🤗💬 Converse with your Data")
-     st.markdown(
-         """
-     ## About
-     This app is an LLM-powered chatbot built using:
-     - [Streamlit](https://streamlit.io/)
-     - [LangChain](https://python.langchain.com/)
-     - [LocalGPT](https://github.com/PromtEngineer/localGPT)
-
-     """
-     )
-     add_vertical_space(5)
-     st.write("Made with ❤️ by [Prompt Engineer](https://youtube.com/@engineerprompt)")
-
-
- if torch.backends.mps.is_available():
-     DEVICE_TYPE = "mps"
- elif torch.cuda.is_available():
-     DEVICE_TYPE = "cuda"
- else:
-     DEVICE_TYPE = "cpu"
-
-
- # if "result" not in st.session_state:
- #     # Run the document ingestion process.
- #     run_langest_commands = ["python", "ingest.py"]
- #     run_langest_commands.append("--device_type")
- #     run_langest_commands.append(DEVICE_TYPE)
-
- #     result = subprocess.run(run_langest_commands, capture_output=True)
- #     st.session_state.result = result
-
- # Define the retriever
- # load the vectorstore
- if "EMBEDDINGS" not in st.session_state:
-     EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
-     st.session_state.EMBEDDINGS = EMBEDDINGS
-
- if "DB" not in st.session_state:
-     DB = Chroma(
-         persist_directory=PERSIST_DIRECTORY,
-         embedding_function=st.session_state.EMBEDDINGS,
-         client_settings=CHROMA_SETTINGS,
      )
-     st.session_state.DB = DB
-
- if "RETRIEVER" not in st.session_state:
-     RETRIEVER = DB.as_retriever()
-     st.session_state.RETRIEVER = RETRIEVER

- if "LLM" not in st.session_state:
-     LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
-     st.session_state["LLM"] = LLM

-
- if "QA" not in st.session_state:
-     prompt, memory = model_memory()
-
-     QA = RetrievalQA.from_chain_type(
-         llm=LLM,
-         chain_type="stuff",
-         retriever=RETRIEVER,
-         return_source_documents=True,
-         chain_type_kwargs={"prompt": prompt, "memory": memory},
      )
-     st.session_state["QA"] = QA
-
- st.title("LocalGPT App 💬")
- # Create a text input box for the user
- prompt = st.text_input("Input your prompt here")
- # while True:
-
- # If the user hits enter
- if prompt:
-     # Then pass the prompt to the LLM
-     response = st.session_state["QA"](prompt)
-     answer, docs = response["result"], response["source_documents"]
-     # ...and write it out to the screen
-     st.write(answer)
-
-     # With a streamlit expander
-     with st.expander("Document Similarity Search"):
-         # Find the relevant pages
-         search = st.session_state.DB.similarity_search_with_score(prompt)
-         # Write out the first
-         for i, doc in enumerate(search):
-             # print(doc)
-             st.write(f"Source Document # {i+1} : {doc[0].metadata['source'].split('/')[-1]}")
-             st.write(doc[0].page_content)
-             st.write("--------------------------------")
 
+ import argparse
+ import os
+ import sys
+ import tempfile
+
+ import requests
+ from flask import Flask, render_template, request
+ from werkzeug.utils import secure_filename
+
+ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+
+ app = Flask(__name__)
+ app.secret_key = "LeafmanZSecretKey"
+
+ API_HOST = "https://dkdaniz-katara.hf.space:5110/api"
+
+ # PAGES #
+ @app.route("/", methods=["GET", "POST"])
+ def home_page():
+     if request.method == "POST":
+         if "user_prompt" in request.form:
+             user_prompt = request.form["user_prompt"]
+             print(f"User Prompt: {user_prompt}")
+
+             main_prompt_url = f"{API_HOST}/prompt_route"
+             response = requests.post(main_prompt_url, data={"user_prompt": user_prompt})
+             print(response.status_code)  # print HTTP response status code for debugging
+             if response.status_code == 200:
+                 # print(response.json())  # Print the JSON data from the response
+                 return render_template("home.html", show_response_modal=True, response_dict=response.json())
+         elif "documents" in request.files:
+             delete_source_url = f"{API_HOST}/delete_source"  # URL of the /api/delete_source endpoint
+             if request.form.get("action") == "reset":
+                 response = requests.get(delete_source_url)
+
+             save_document_url = f"{API_HOST}/save_document"
+             run_ingest_url = f"{API_HOST}/run_ingest"  # URL of the /api/run_ingest endpoint
+             files = request.files.getlist("documents")
+             for file in files:
+                 print(file.filename)
+                 filename = secure_filename(file.filename)
+                 with tempfile.SpooledTemporaryFile() as f:
+                     f.write(file.read())
+                     f.seek(0)
+                     response = requests.post(save_document_url, files={"document": (filename, f)})
+                     print(response.status_code)  # print HTTP response status code for debugging
+             # Make a GET request to the /api/run_ingest endpoint
+             response = requests.get(run_ingest_url)
+             print(response.status_code)  # print HTTP response status code for debugging
+
+     # Display the form for GET request
+     return render_template(
+         "home.html",
+         show_response_modal=False,
+         response_dict={"Prompt": "None", "Answer": "None", "Sources": [("ewf", "wef")]},
      )


+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=5111, help="Port to run the UI on. Defaults to 5111.")
+     parser.add_argument(
+         "--host",
+         type=str,
+         default="0.0.0.0",
+         help="Host to run the UI on. Defaults to 0.0.0.0, "
+         "which makes the UI externally accessible "
+         "from other devices.",
      )
+     args = parser.parse_args()
+     app.run(debug=False, host=args.host, port=args.port)
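
For quick testing without the Flask UI, here is a minimal sketch of the same API round-trip the new frontend performs. The endpoint paths and the "user_prompt"/"document" field names come from the diff above; the response JSON keys ("Prompt", "Answer", "Sources") and the sample file name are assumptions based on the default response_dict, not something this commit confirms.

import requests

API_HOST = "https://dkdaniz-katara.hf.space:5110/api"  # same base URL as frontend.py

# Upload a document, then trigger ingestion (mirrors the "documents" branch).
with open("example.pdf", "rb") as f:  # hypothetical local file
    print(requests.post(f"{API_HOST}/save_document", files={"document": ("example.pdf", f)}).status_code)
print(requests.get(f"{API_HOST}/run_ingest").status_code)

# Ask a question against the ingested documents (mirrors the "user_prompt" branch).
resp = requests.post(f"{API_HOST}/prompt_route", data={"user_prompt": "What is this document about?"})
if resp.status_code == 200:
    print(resp.json())  # assumed keys: "Prompt", "Answer", "Sources"

The UI itself starts with: python frontend.py --port 5111 --host 0.0.0.0 (matching the argparse defaults above).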