Chandranshu Jain commited on
Commit
04eb61d
1 Parent(s): 214739d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -68,10 +68,19 @@ def get_conversational_chain():
68
  Answer:
69
  """
70
  #model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3, google_api_key=GOOGLE_API_KEY)
71
- model = ChatGoogleGenerativeAI(model="gemini-1.0-pro-latest", temperature=0.3, google_api_key=GOOGLE_API_KEY)
72
- prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
73
- chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
74
- return chain
 
 
 
 
 
 
 
 
 
75
 
76
  def embedding(chunk,query):
77
  #embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
@@ -80,10 +89,9 @@ def embedding(chunk,query):
80
  db = Chroma.from_documents(chunk,embeddings)
81
  doc = db.similarity_search(query)
82
  print(doc)
83
- chain = get_conversational_chain()
84
- response = chain({"input_documents": doc, "question": query}, return_only_outputs=True)
85
  print(response)
86
- return response["output_text"]
87
  #st.write("Reply: ", response["output_text"])
88
 
89
  if 'messages' not in st.session_state:
 
68
  Answer:
69
  """
70
  #model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3, google_api_key=GOOGLE_API_KEY)
71
+ repo_id='meta-llama/Meta-Llama-3-70B'
72
+ llm = HuggingFaceEndpoint(
73
+ repo_id=repo_id, max_length=512, temperature=0.5, token=userdata.get('HUGGING_FACE_API_KEY'))
74
+ pt = ChatPromptTemplate.from_template(template)
75
+ # Retrieve and generate using the relevant snippets of the blog.
76
+ retriever = db.as_retriever()
77
+ rag_chain = (
78
+ {"context": retriever, "question": RunnablePassthrough()}
79
+ | pt
80
+ | llm
81
+ | StrOutputParser()
82
+ )
83
+ return rag_chain
84
 
85
  def embedding(chunk,query):
86
  #embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
 
89
  db = Chroma.from_documents(chunk,embeddings)
90
  doc = db.similarity_search(query)
91
  print(doc)
92
+ response = rag_chain.invoke(query)
 
93
  print(response)
94
+ return response
95
  #st.write("Reply: ", response["output_text"])
96
 
97
  if 'messages' not in st.session_state: