jarif commited on
Commit
a144f48
·
verified ·
1 Parent(s): 575af7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import streamlit as st
2
  import os
3
  import logging
 
 
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
- from langchain_community.vectorstores import Chroma
7
- from langchain_community.llms import HuggingFacePipeline
8
  from langchain.chains import RetrievalQA
9
- from ingest import create_chroma_db
10
 
11
  # Set up logging
12
  logging.basicConfig(level=logging.INFO)
@@ -29,29 +29,24 @@ def load_llm():
29
  )
30
  return HuggingFacePipeline(pipeline=pipe)
31
 
32
- def load_chroma_db():
33
- chroma_dir = "chroma_db"
34
- if not os.path.exists(chroma_dir):
35
- st.warning("Chroma database not found. Creating a new one...")
36
- create_chroma_db()
37
-
38
- if not os.path.exists(chroma_dir):
39
- st.error("Failed to create the Chroma database. Please check the 'docs' directory and try again.")
40
- raise RuntimeError("Chroma database creation failed.")
41
-
42
  try:
43
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
44
- db = Chroma.load_local(chroma_dir, embeddings)
45
- logger.info(f"Chroma database loaded successfully from {chroma_dir}")
46
- return db.as_retriever()
47
  except Exception as e:
48
- st.error(f"Failed to load Chroma database: {e}")
49
- logger.exception("Exception in load_chroma_db")
50
  raise
51
 
52
  def process_answer(instruction):
53
  try:
54
- retriever = load_chroma_db()
55
  llm = load_llm()
56
  qa = RetrievalQA.from_chain_type(
57
  llm=llm,
 
1
  import streamlit as st
2
  import os
3
  import logging
4
+ import faiss
5
+ import numpy as np
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
8
  from langchain.chains import RetrievalQA
9
+ from langchain.vectorstores import FAISS
10
 
11
  # Set up logging
12
  logging.basicConfig(level=logging.INFO)
 
29
  )
30
  return HuggingFacePipeline(pipeline=pipe)
31
 
32
+ def load_faiss_index():
33
+ index_path = "faiss_index.index"
34
+ if not os.path.exists(index_path):
35
+ st.warning("FAISS index not found. Please create the index first.")
36
+ raise RuntimeError("FAISS index not found.")
37
+
 
 
 
 
38
  try:
39
+ faiss_index = faiss.read_index(index_path)
40
+ logger.info(f"FAISS index loaded successfully from {index_path}")
41
+ return FAISS(faiss_index)
 
42
  except Exception as e:
43
+ st.error(f"Failed to load FAISS index: {e}")
44
+ logger.exception("Exception in load_faiss_index")
45
  raise
46
 
47
  def process_answer(instruction):
48
  try:
49
+ retriever = load_faiss_index()
50
  llm = load_llm()
51
  qa = RetrievalQA.from_chain_type(
52
  llm=llm,