Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Implemented async for performance gain
Browse files- app/api/userchat.py +1 -1
- app/main.py +4 -3
- app/utils/chat_rag.py +37 -34
    	
        app/api/userchat.py
    CHANGED
    
    | @@ -11,7 +11,7 @@ async def chat_with_llama(user_input: str = Body(..., embed=True), current_user: | |
| 11 | 
             
                # Example logic for model inference (pseudo-code, adjust as necessary)
         | 
| 12 | 
             
                try:
         | 
| 13 | 
             
                    user_id = current_user["user_id"]
         | 
| 14 | 
            -
                    model_response = llm_infer(user_collection_name=sanitize_collection_name(user_id), prompt=user_input)
         | 
| 15 | 
             
                    # Optionally, store chat history
         | 
| 16 | 
             
                    # chromadb_face_helper.store_chat_history(user_id=current_user["user_id"], user_input=user_input, model_response=model_response)
         | 
| 17 | 
             
                except Exception as e:
         | 
|  | |
| 11 | 
             
                # Example logic for model inference (pseudo-code, adjust as necessary)
         | 
| 12 | 
             
                try:
         | 
| 13 | 
             
                    user_id = current_user["user_id"]
         | 
| 14 | 
            +
                    model_response = await llm_infer(user_collection_name=sanitize_collection_name(user_id), prompt=user_input)
         | 
| 15 | 
             
                    # Optionally, store chat history
         | 
| 16 | 
             
                    # chromadb_face_helper.store_chat_history(user_id=current_user["user_id"], user_input=user_input, model_response=model_response)
         | 
| 17 | 
             
                except Exception as e:
         | 
    	
        app/main.py
    CHANGED
    
    | @@ -11,7 +11,7 @@ from admin import admin_functions as admin | |
| 11 | 
             
            from utils.db import UserFaceEmbeddingFunction,ChromaDBFaceHelper
         | 
| 12 | 
             
            from api import userlogin, userlogout, userchat, userupload
         | 
| 13 | 
             
            from utils.db import ChromaDBFaceHelper
         | 
| 14 | 
            -
            from utils.chat_rag import  | 
| 15 |  | 
| 16 | 
             
            app = FastAPI()
         | 
| 17 |  | 
| @@ -42,8 +42,9 @@ async def startup_event(): | |
| 42 | 
             
                chromadb_face_helper = ChromaDBFaceHelper(db_path) # Used by APIs
         | 
| 43 |  | 
| 44 | 
             
                # Perform any other startup tasks here
         | 
| 45 | 
            -
                #  | 
| 46 | 
            -
                 | 
|  | |
| 47 |  | 
| 48 | 
             
                print(f"MODEL_PATH in main.py = {os.getenv('MODEL_PATH')} ")
         | 
| 49 |  | 
|  | |
| 11 | 
             
            from utils.db import UserFaceEmbeddingFunction,ChromaDBFaceHelper
         | 
| 12 | 
             
            from api import userlogin, userlogout, userchat, userupload
         | 
| 13 | 
             
            from utils.db import ChromaDBFaceHelper
         | 
| 14 | 
            +
            from utils.chat_rag import LlamaModelSingleton
         | 
| 15 |  | 
| 16 | 
             
            app = FastAPI()
         | 
| 17 |  | 
|  | |
| 42 | 
             
                chromadb_face_helper = ChromaDBFaceHelper(db_path) # Used by APIs
         | 
| 43 |  | 
| 44 | 
             
                # Perform any other startup tasks here
         | 
| 45 | 
            +
                # Preload the LLM model
         | 
| 46 | 
            +
                await LlamaModelSingleton.get_instance()
         | 
| 47 | 
            +
                print("LLM model loaded and ready.")
         | 
| 48 |  | 
| 49 | 
             
                print(f"MODEL_PATH in main.py = {os.getenv('MODEL_PATH')} ")
         | 
| 50 |  | 
    	
        app/utils/chat_rag.py
    CHANGED
    
    | @@ -2,6 +2,7 @@ | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            import re
         | 
| 4 | 
             
            import hashlib
         | 
|  | |
| 5 |  | 
| 6 | 
             
            from langchain.document_loaders import PyPDFLoader
         | 
| 7 |  | 
| @@ -47,7 +48,7 @@ def sanitize_collection_name(email): | |
| 47 |  | 
| 48 |  | 
| 49 | 
             
            # Modify vectordb initialization to be dynamic based on user_id
         | 
| 50 | 
            -
            def get_vectordb_for_user(user_collection_name):
         | 
| 51 | 
             
                # Get Chromadb location
         | 
| 52 | 
             
                CHROMADB_LOC = os.getenv('CHROMADB_LOC')
         | 
| 53 |  | 
| @@ -60,9 +61,9 @@ def get_vectordb_for_user(user_collection_name): | |
| 60 |  | 
| 61 | 
             
            vectordb_cache = {}
         | 
| 62 |  | 
| 63 | 
            -
            def get_vectordb_for_user_cached(user_collection_name):
         | 
| 64 | 
             
                if user_collection_name not in vectordb_cache:
         | 
| 65 | 
            -
                    vectordb_cache[user_collection_name] = get_vectordb_for_user(user_collection_name)
         | 
| 66 | 
             
                return vectordb_cache[user_collection_name]
         | 
| 67 |  | 
| 68 |  | 
| @@ -93,42 +94,44 @@ def pdf_to_vec(filename, user_collection_name): | |
| 93 | 
             
                return(vectordb)
         | 
| 94 | 
             
                #return collection  # Return the collection as the asset
         | 
| 95 |  | 
|  | |
|  | |
| 96 | 
             
            class LlamaModelSingleton:
         | 
| 97 | 
             
                _instance = None
         | 
| 98 |  | 
| 99 | 
            -
                 | 
|  | |
| 100 | 
             
                    if cls._instance is None:
         | 
| 101 | 
            -
                         | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
             | 
| 117 | 
            -
             | 
| 118 | 
            -
             | 
| 119 | 
            -
             | 
| 120 | 
            -
             | 
| 121 | 
            -
             | 
| 122 | 
            -
                return LlamaModelSingleton()
         | 
| 123 |  | 
| 124 |  | 
| 125 |  | 
| 126 | 
             
            #step 5, to instantiate once to create default_chain,router_chain,destination_chains into chain and set vectordb. so will not re-create per prompt
         | 
| 127 | 
            -
            def default_chain(llm, user_collection_name):
         | 
| 128 | 
             
                # Get Chromadb location
         | 
| 129 | 
             
                CHROMADB_LOC = os.getenv('CHROMADB_LOC')
         | 
| 130 |  | 
| 131 | 
            -
                vectordb = get_vectordb_for_user_cached(user_collection_name)  # Use the dynamic vectordb based on user_id
         | 
| 132 | 
             
                sum_template = """
         | 
| 133 | 
             
                As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
         | 
| 134 |  | 
| @@ -209,13 +212,13 @@ def default_chain(llm, user_collection_name): | |
| 209 | 
             
                return default_chain,router_chain,destination_chains
         | 
| 210 |  | 
| 211 | 
             
            # Adjust llm_infer to accept user_id and use it for user-specific processing
         | 
| 212 | 
            -
            def llm_infer(user_collection_name, prompt):
         | 
| 213 |  | 
| 214 | 
            -
                llm = load_llm()  # load_llm is singleton for entire system
         | 
| 215 |  | 
| 216 | 
            -
                vectordb = get_vectordb_for_user_cached(user_collection_name) # Vector collection for each us. 
         | 
| 217 |  | 
| 218 | 
            -
                default_chain, router_chain, destination_chains = get_or_create_chain(user_collection_name, llm)  # Now user-specific
         | 
| 219 |  | 
| 220 | 
             
                chain = MultiPromptChain(
         | 
| 221 | 
             
                    router_chain=router_chain,
         | 
| @@ -231,13 +234,13 @@ def llm_infer(user_collection_name, prompt): | |
| 231 | 
             
            # Assuming a simplified caching mechanism for demonstration
         | 
| 232 | 
             
            chain_cache = {}
         | 
| 233 |  | 
| 234 | 
            -
            def get_or_create_chain(user_collection_name, llm):
         | 
| 235 | 
             
                if 'default_chain' in chain_cache and 'router_chain' in chain_cache:
         | 
| 236 | 
             
                    default_chain = chain_cache['default_chain']
         | 
| 237 | 
             
                    router_chain = chain_cache['router_chain']
         | 
| 238 | 
             
                    destination_chains = chain_cache['destination_chains']
         | 
| 239 | 
             
                else:
         | 
| 240 | 
            -
                    vectordb = get_vectordb_for_user_cached(user_collection_name)  # User-specific vector database
         | 
| 241 | 
             
                    sum_template = """
         | 
| 242 | 
             
                    As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
         | 
| 243 |  | 
|  | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            import re
         | 
| 4 | 
             
            import hashlib
         | 
| 5 | 
            +
            import asyncio
         | 
| 6 |  | 
| 7 | 
             
            from langchain.document_loaders import PyPDFLoader
         | 
| 8 |  | 
|  | |
| 48 |  | 
| 49 |  | 
| 50 | 
             
            # Modify vectordb initialization to be dynamic based on user_id
         | 
| 51 | 
            +
            async def get_vectordb_for_user(user_collection_name):
         | 
| 52 | 
             
                # Get Chromadb location
         | 
| 53 | 
             
                CHROMADB_LOC = os.getenv('CHROMADB_LOC')
         | 
| 54 |  | 
|  | |
| 61 |  | 
| 62 | 
             
            vectordb_cache = {}
         | 
| 63 |  | 
| 64 | 
            +
            async def get_vectordb_for_user_cached(user_collection_name):
         | 
| 65 | 
             
                if user_collection_name not in vectordb_cache:
         | 
| 66 | 
            +
                    vectordb_cache[user_collection_name] = await get_vectordb_for_user(user_collection_name)
         | 
| 67 | 
             
                return vectordb_cache[user_collection_name]
         | 
| 68 |  | 
| 69 |  | 
|  | |
| 94 | 
             
                return(vectordb)
         | 
| 95 | 
             
                #return collection  # Return the collection as the asset
         | 
| 96 |  | 
| 97 | 
            +
             | 
| 98 | 
            +
            # Assuming LlamaModelSingleton is updated to support async instantiation
         | 
| 99 | 
             
            class LlamaModelSingleton:
         | 
| 100 | 
             
                _instance = None
         | 
| 101 |  | 
| 102 | 
            +
                @classmethod
         | 
| 103 | 
            +
                async def get_instance(cls):
         | 
| 104 | 
             
                    if cls._instance is None:
         | 
| 105 | 
            +
                        cls._instance = cls._load_llm()  # Assuming _load_llm is synchronous, if not, use an executor
         | 
| 106 | 
            +
                    return cls._instance
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                @staticmethod
         | 
| 109 | 
            +
                def _load_llm():
         | 
| 110 | 
            +
                    print('Loading LLM model...')
         | 
| 111 | 
            +
                    model_path = os.getenv("MODEL_PATH")
         | 
| 112 | 
            +
                    llm = LlamaCpp(
         | 
| 113 | 
            +
                        model_path=model_path,
         | 
| 114 | 
            +
                        n_gpu_layers=-1,
         | 
| 115 | 
            +
                        n_batch=512,
         | 
| 116 | 
            +
                        temperature=0.1,
         | 
| 117 | 
            +
                        top_p=1,
         | 
| 118 | 
            +
                        max_tokens=2000,
         | 
| 119 | 
            +
                    )
         | 
| 120 | 
            +
                    print(f'Model loaded from {model_path}')
         | 
| 121 | 
            +
                    return llm
         | 
| 122 | 
            +
             | 
| 123 | 
            +
            async def load_llm():
         | 
| 124 | 
            +
                return await LlamaModelSingleton.get_instance()
         | 
| 125 | 
            +
             | 
|  | |
| 126 |  | 
| 127 |  | 
| 128 |  | 
| 129 | 
             
            #step 5, to instantiate once to create default_chain,router_chain,destination_chains into chain and set vectordb. so will not re-create per prompt
         | 
| 130 | 
            +
            async def default_chain(llm, user_collection_name):
         | 
| 131 | 
             
                # Get Chromadb location
         | 
| 132 | 
             
                CHROMADB_LOC = os.getenv('CHROMADB_LOC')
         | 
| 133 |  | 
| 134 | 
            +
                vectordb = await get_vectordb_for_user_cached(user_collection_name)  # Use the dynamic vectordb based on user_id
         | 
| 135 | 
             
                sum_template = """
         | 
| 136 | 
             
                As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
         | 
| 137 |  | 
|  | |
| 212 | 
             
                return default_chain,router_chain,destination_chains
         | 
| 213 |  | 
| 214 | 
             
            # Adjust llm_infer to accept user_id and use it for user-specific processing
         | 
| 215 | 
            +
            async def llm_infer(user_collection_name, prompt):
         | 
| 216 |  | 
| 217 | 
            +
                llm = await load_llm()  # load_llm is singleton for entire system
         | 
| 218 |  | 
| 219 | 
            +
                vectordb = await get_vectordb_for_user_cached(user_collection_name) # Vector collection for each us. 
         | 
| 220 |  | 
| 221 | 
            +
                default_chain, router_chain, destination_chains = await get_or_create_chain(user_collection_name, llm)  # Now user-specific
         | 
| 222 |  | 
| 223 | 
             
                chain = MultiPromptChain(
         | 
| 224 | 
             
                    router_chain=router_chain,
         | 
|  | |
| 234 | 
             
            # Assuming a simplified caching mechanism for demonstration
         | 
| 235 | 
             
            chain_cache = {}
         | 
| 236 |  | 
| 237 | 
            +
            async def get_or_create_chain(user_collection_name, llm):
         | 
| 238 | 
             
                if 'default_chain' in chain_cache and 'router_chain' in chain_cache:
         | 
| 239 | 
             
                    default_chain = chain_cache['default_chain']
         | 
| 240 | 
             
                    router_chain = chain_cache['router_chain']
         | 
| 241 | 
             
                    destination_chains = chain_cache['destination_chains']
         | 
| 242 | 
             
                else:
         | 
| 243 | 
            +
                    vectordb = await get_vectordb_for_user_cached(user_collection_name)  # User-specific vector database
         | 
| 244 | 
             
                    sum_template = """
         | 
| 245 | 
             
                    As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
         | 
| 246 |  |