captain-awesome
committed on
Commit
•
9d2bf07
1
Parent(s):
48e3505
Update app.py
Browse files
app.py
CHANGED
@@ -23,14 +23,16 @@ import torch
|
|
23 |
|
24 |
|
25 |
def get_vector_store_from_url(url):
|
26 |
-
model_name = "BAAI/bge-large-en"
|
27 |
-
model_kwargs = {'device': 'cpu'}
|
28 |
-
encode_kwargs = {'normalize_embeddings': False}
|
29 |
-
embeddings = HuggingFaceBgeEmbeddings(
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
)
|
|
|
|
|
34 |
|
35 |
loader = WebBaseLoader(url)
|
36 |
document = loader.load()
|
@@ -114,17 +116,23 @@ def get_response(user_input):
|
|
114 |
# lib="avx2", # for CPU
|
115 |
# )
|
116 |
|
117 |
-
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
118 |
-
# llm = HuggingFaceHub(
|
119 |
-
# repo_id=llm_model,
|
120 |
-
# model_kwargs={"temperature": 0.3, "max_new_tokens": 250, "top_k": 3}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
# )
|
122 |
|
123 |
-
llm =
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
device_map='auto'
|
128 |
)
|
129 |
retriever_chain = get_context_retriever_chain(st.session_state.vector_store,llm)
|
130 |
conversation_rag_chain = get_conversational_rag_chain(retriever_chain,llm)
|
|
|
23 |
|
24 |
|
25 |
def get_vector_store_from_url(url):
|
26 |
+
# model_name = "BAAI/bge-large-en"
|
27 |
+
# model_kwargs = {'device': 'cpu'}
|
28 |
+
# encode_kwargs = {'normalize_embeddings': False}
|
29 |
+
# embeddings = HuggingFaceBgeEmbeddings(
|
30 |
+
# model_name=model_name,
|
31 |
+
# model_kwargs=model_kwargs,
|
32 |
+
# encode_kwargs=encode_kwargs
|
33 |
+
# )
|
34 |
+
embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-large',
|
35 |
+
model_kwargs={'device': 'cpu'})
|
36 |
|
37 |
loader = WebBaseLoader(url)
|
38 |
document = loader.load()
|
|
|
116 |
# lib="avx2", # for CPU
|
117 |
# )
|
118 |
|
119 |
+
# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
120 |
+
# # llm = HuggingFaceHub(
|
121 |
+
# # repo_id=llm_model,
|
122 |
+
# # model_kwargs={"temperature": 0.3, "max_new_tokens": 250, "top_k": 3}
|
123 |
+
# # )
|
124 |
+
|
125 |
+
# llm = transformers.AutoModelForCausalLM.from_pretrained(
|
126 |
+
# model_name,
|
127 |
+
# trust_remote_code=True,
|
128 |
+
# torch_dtype=torch.bfloat16,
|
129 |
+
# device_map='auto'
|
130 |
# )
|
131 |
|
132 |
+
llm = HuggingFacePipeline.from_model_id(
|
133 |
+
model_id="google/flan-t5-base",
|
134 |
+
task="text2text-generation",
|
135 |
+
# model_kwargs={"temperature": 0.2},
|
|
|
136 |
)
|
137 |
retriever_chain = get_context_retriever_chain(st.session_state.vector_store,llm)
|
138 |
conversation_rag_chain = get_conversational_rag_chain(retriever_chain,llm)
|