"""Streamlit front end for embedding documents and querying them.

The page has two sections: one uploads files and hands them to
``app.embed_documents``; the other lists the existing Chroma collections and
passes a user query for the selected collection to ``app.retrieve_documents``.
"""

import json
import os
import socket
from typing import List

import requests
import streamlit as st
from dotenv import load_dotenv
from urllib3.connection import HTTPConnection

from app import embed_documents, retrieve_documents

# Load .env BEFORE reading any environment variable below. Previously
# API_BASE_URL was read first, so a value defined only in .env came back None.
load_dotenv()

API_BASE_URL = os.environ.get("API_BASE_URL")

embeddings_model_name = "all-MiniLM-L6-v2"
persist_directory = "db"
model = "tiiuae/falcon-7b-instruct"

# Collection name passed to embed_documents(). The "Collection Name" text
# input is disabled (see main()), so every upload uses this fixed value.
DEFAULT_COLLECTION_NAME = "collection_name"

# These two imports intentionally stay after load_dotenv(), matching the
# original ordering. NOTE(review): presumably `constants` may read environment
# variables at import time - confirm before hoisting them to the top.
from constants import CHROMA_SETTINGS  # noqa: E402
import chromadb  # noqa: E402


def list_of_collections():
    """Return the collections reported by a Chroma client.

    A new client is created from ``CHROMA_SETTINGS`` on every call.
    """
    client = chromadb.Client(CHROMA_SETTINGS)
    return client.list_collections()


def main() -> None:
    """Render the page: a document upload section and a retrieval section."""
    st.title("PrivateGPT App: Document Embedding and Retrieval")

    # Document upload section
    st.header("Document Upload")
    files = st.file_uploader("Upload document", accept_multiple_files=True)
    # collection_name = st.text_input("Collection Name") not working for some reason
    if st.button("Embed"):
        if files:
            embed_documents(files, DEFAULT_COLLECTION_NAME)
        else:
            # Nothing was uploaded; skip the embedding call entirely.
            st.warning("Please upload at least one document before embedding.")

    # Query section
    st.header("Document Retrieval")
    collection_names = get_collection_names()
    selected_collection = st.selectbox("Select a document", collection_names)
    if selected_collection:
        query = st.text_input("Query")
        if st.button("Retrieve"):
            if query and query.strip():
                retrieve_documents(query, selected_collection)
            else:
                # A blank query has nothing to retrieve against.
                st.warning("Please enter a query before retrieving.")


# Legacy HTTP-based implementation, kept for reference. The active
# embed_documents is the one imported from `app` at the top of this file.
#
# def embed_documents(files:List[st.runtime.uploaded_file_manager.UploadedFile], collection_name:str):
#     endpoint = f"{API_BASE_URL}/embed"
#     files_data = [("files", file) for file in files]
#     data = {"collection_name": collection_name}
#
#     response = requests.post(endpoint, files=files_data, data=data)
#     if response.status_code == 200:
#         st.success("Documents embedded successfully!")
#     else:
#         st.error("Document embedding failed.")
#         st.write(response.text)


def get_collection_names() -> List[str]:
    """Return the ``name`` attribute of every collection from ``list_of_collections()``."""
    collections = list_of_collections()
    return [collection.name for collection in collections]


# Legacy HTTP-based implementation, kept for reference. The active
# retrieve_documents is the one imported from `app` at the top of this file.
#
# def retrieve_documents(query: str, collection_name: str):
#     endpoint = f"{API_BASE_URL}/retrieve"
#     data = {"query": query, "collection_name": collection_name}
#
#     # Modify socket options for the HTTPConnection class
#     HTTPConnection.default_socket_options = (
#         HTTPConnection.default_socket_options + [
#             (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
#             (socket.SOL_TCP, socket.TCP_KEEPIDLE, 45),
#             (socket.SOL_TCP, socket.TCP_KEEPINTVL, 10),
#             (socket.SOL_TCP, socket.TCP_KEEPCNT, 6)
#         ]
#     )
#
#     response = requests.post(endpoint, params=data)
#
#     if response.status_code == 200:
#         result = response.json()
#         st.subheader("Results")
#         st.text(result["results"])
#
#         st.subheader("Documents")
#         for doc in result["docs"]:
#             st.text(doc)
#     else:
#         st.error("Failed to retrieve documents.")
#         st.write(response.text)


if __name__ == "__main__":
    main()