captain-awesome committed on
Commit
49b8bd6
·
verified ·
1 Parent(s): b594097

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -11,13 +11,12 @@ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
11
  from langchain_community.llms import CTransformers
12
  from ctransformers import AutoModelForCausalLM
13
  from langchain.llms import HuggingFaceHub
14
- from transformers import AutoTokenizer
15
  import os
16
  # from dotenv import load_dotenv
17
 
18
  # load_dotenv()
19
 
20
- os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.getenv("HF_KEY")
21
 
22
 
23
  def get_vector_store_from_url(url):
@@ -85,11 +84,18 @@ def get_response(user_input):
85
  # lib="avx2", # for CPU
86
  # )
87
 
88
- llm_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
89
- llm = HuggingFaceHub(
90
- repo_id=llm_model,
91
- model_kwargs={"temperature": 0.3, "max_new_tokens": 250, "top_k": 3}
92
- )
 
 
 
 
 
 
 
93
  retriever_chain = get_context_retriever_chain(st.session_state.vector_store,llm)
94
  conversation_rag_chain = get_conversational_rag_chain(retriever_chain,llm)
95
 
 
11
  from langchain_community.llms import CTransformers
12
  from ctransformers import AutoModelForCausalLM
13
  from langchain.llms import HuggingFaceHub
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
  import os
16
  # from dotenv import load_dotenv
17
 
18
  # load_dotenv()
19
 
 
20
 
21
 
22
  def get_vector_store_from_url(url):
 
84
  # lib="avx2", # for CPU
85
  # )
86
 
87
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
88
+ # llm = HuggingFaceHub(
89
+ # repo_id=llm_model,
90
+ # model_kwargs={"temperature": 0.3, "max_new_tokens": 250, "top_k": 3}
91
+ # )
92
+
93
+ model = transformers.AutoModelForCausalLM.from_pretrained(
94
+ model_name,
95
+ trust_remote_code=True,
96
+ torch_dtype=torch.bfloat16,
97
+ device_map='auto'
98
+
99
  retriever_chain = get_context_retriever_chain(st.session_state.vector_store,llm)
100
  conversation_rag_chain = get_conversational_rag_chain(retriever_chain,llm)
101