Carlosito16 commited on
Commit
5ef51e2
1 Parent(s): e29cf0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -4
app.py CHANGED
@@ -46,6 +46,7 @@ def load_scraped_web_info():
46
 
47
 
48
 
 
49
  @st.cache_resource
50
  def load_embedding_model():
51
  embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
@@ -57,6 +58,31 @@ def load_faiss_index():
57
  vector_database = FAISS.load_local("faiss_index", embedding_model)
58
  return vector_database
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  #--------------
62
 
@@ -65,19 +91,36 @@ def load_faiss_index():
65
  load_scraped_web_info()
66
  embedding_model = load_embedding_model()
67
  vector_database = load_faiss_index()
68
- print("load done")
 
69
 
70
 
 
 
71
 
72
 
73
  query_input = st.text_input(label= 'your question')
 
 
74
  def retrieve_document(query_input):
75
  related_doc = vector_database.similarity_search(query_input)
76
  return related_doc
77
 
78
- output = st.text_area(label = "Here is the relevant documents",
 
 
 
 
79
  value = retrieve_document(query_input))
80
 
81
 
82
- faiss_retriever = vector_database.as_retriever()
83
- print("Succesfully had FAISS as retriever")
 
 
 
 
 
 
 
 
 
46
 
47
 
48
 
49
+
50
  @st.cache_resource
51
  def load_embedding_model():
52
  embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
 
58
  vector_database = FAISS.load_local("faiss_index", embedding_model)
59
  return vector_database
60
 
61
@st.cache_resource
def load_llm_model():
    """Load the FastChat-T5 text2text-generation pipeline.

    Cached by Streamlit so the 3B-parameter model is loaded once per
    session rather than on every script rerun.

    Returns:
        HuggingFacePipeline: a LangChain-wrapped HF pipeline ready for use
        as the LLM in a RetrievalQA chain.
    """
    # NOTE(review): an earlier variant ran with device_map="auto" and
    # load_in_8bit=True (GPU / quantized); the current config forces
    # float32 on the default device instead.
    generation_kwargs = {
        "max_length": 256,
        "temperature": 0,
        "torch_dtype": torch.float32,
        "repetition_penalty": 1.3,
    }
    return HuggingFacePipeline.from_model_id(
        model_id='lmsys/fastchat-t5-3b-v1.0',
        task='text2text-generation',
        model_kwargs=generation_kwargs,
    )
79
+
80
+
81
def load_retriever(llm, db):
    """Build a RetrievalQA chain over a vector store.

    Args:
        llm: the language model used to generate answers.
        db: a vector database exposing ``as_retriever()`` (e.g. FAISS).

    Returns:
        RetrievalQA: chain using the "stuff" strategy — all retrieved
        documents are stuffed into a single prompt for the LLM.
    """
    doc_retriever = db.as_retriever()
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=doc_retriever,
    )
86
 
87
  #--------------
88
 
 
91
# --- App startup: load every cached resource once per session. ---
# These module-level names are read by the retrieval helpers defined below,
# so they must stay at module scope with these exact names.
load_scraped_web_info()
embedding_model = load_embedding_model()
vector_database = load_faiss_index()
llm_model = load_llm_model()
qa_retriever = load_retriever(llm= llm_model, db= vector_database)


print("all load done")


# Free-text question box; returns "" until the user types something.
query_input = st.text_input(label= 'your question')
103
+
104
+
105
def retrieve_document(query_input):
    """Return the documents from the FAISS index most similar to the query.

    Args:
        query_input: the user's free-text question.

    Returns:
        The documents produced by ``vector_database.similarity_search``
        (module-level FAISS store loaded at startup).
    """
    return vector_database.similarity_search(query_input)
108
 
109
def retrieve_answer(query_input):
    """Answer the user's question via the module-level RetrievalQA chain.

    Args:
        query_input: the user's free-text question.

    Returns:
        The answer string produced by ``qa_retriever.run``.
    """
    return qa_retriever.run(query_input)
112
+
113
+ output_1 = st.text_area(label = "Here is the relevant documents",
114
  value = retrieve_document(query_input))
115
 
116
 
117
+ output_2 = st.text_area(label = "Here is the answer",
118
+ value = retrieve_answer(query_input))
119
+
120
+
121
+
122
+ # faiss_retriever = vector_database.as_retriever()
123
+ # print("Succesfully had FAISS as retriever")
124
+
125
+
126
+