Spaces:
Runtime error
Runtime error
from langchain.vectorstores import Qdrant | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
import qdrant_client | |
import os | |
from langchain.embeddings.huggingface import HuggingFaceInstructEmbeddings | |
from langchain.llms import HuggingFaceHub | |
from dotenv import load_dotenv | |
from langchain.chains import RetrievalQA | |
from langchain.llms import OpenAI | |
import streamlit as st | |
def get_vector_store(): | |
client = qdrant_client.QdrantClient( | |
os.getenv('QDRANT_HOST'), | |
api_key=os.getenv('QDRANT_API_KEY') | |
) | |
embeddings = HuggingFaceInstructEmbeddings(model_name = "hkunlp/instructor-xl") | |
# embeddings = OpenAIEmbeddings() | |
vectore_store = Qdrant( | |
client=client, | |
collection_name=os.getenv('QDRANT_COLLECTION_NAME'), | |
embeddings=embeddings, | |
) | |
return vectore_store | |
def main():
    """Streamlit entry point: answer questions against the remote Qdrant store.

    Streamlit re-executes this script on every widget interaction, so the
    expensive pieces (Qdrant client, embedding model, LLM wrapper) are built
    once via ``st.cache_resource`` and reused across reruns instead of being
    rebuilt each time the user types.
    """
    load_dotenv()
    st.set_page_config(page_title="Ask AI", page_icon=":robot:")
    st.header("Ask your remote database")

    # Defined inside main so importing the module triggers no Streamlit calls.
    # Streamlit keys its cache on the function's qualified name and source
    # code, so redefining it on each rerun still hits the same cache entry.
    @st.cache_resource
    def _build_qa_chain():
        """Create the retrieval-QA chain over the Qdrant store (cached)."""
        vectorstore = get_vector_store()
        return RetrievalQA.from_chain_type(
            # Alternative LLM previously tried here: OpenAI(temperature=0).
            llm=HuggingFaceHub(
                repo_id="google/flan-t5-xxl",
                model_kwargs={"temperature": 0.7, "max_length": 512},
            ),
            # "stuff": retrieved documents are placed into a single prompt.
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_type="similarity"),
        )

    qa = _build_qa_chain()

    user_question = st.text_input("Ask your question")
    if user_question:
        st.write(f"Question: {user_question}")
        # Retrieval + remote generation can take a while; show progress.
        with st.spinner("Searching and generating an answer..."):
            answer = qa.run(user_question)
        st.write(f"Answer: {answer}")


if __name__ == "__main__":
    main()