anubhav77 commited on
Commit
ebd06cc
1 Parent(s): 1ee9c28
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src/.env
2
+ pytest.ini
3
+ logger.log
4
+ .pytest_cache/**
5
+ **/__pycache__/**
6
+ .vscode/**
7
+ **/.env
8
+ src/baseInfra/*.env
9
+ **/*.env*
10
+ **/*.pyc
11
+ src/baseInfra/**.env**
12
+ src/baseInfra/__pycache__/*.pyc
13
+ src/*.old.py
14
+ src/toolSelector.py
15
+ test/__pycache__/*.pyc
16
+ src/baseInfra/memoryManager.py
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Runtime image for the Maya FastAPI service (HuggingFace Spaces friendly).
FROM python:3.11.0

# Stream stdout/stderr unbuffered so container logs appear immediately.
ENV PYTHONUNBUFFERED 1

EXPOSE 8000

# Run as an unprivileged user (uid 1000), as HF Spaces expects.
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
ENV TZ="Asia/Kolkata"
WORKDIR $HOME/app

# Install dependencies in their own layer so code edits don't bust the cache.
COPY requirements.txt ./
RUN pip install --upgrade pip && \
    pip install -r requirements.txt


COPY --chown=user . $HOME/app
ENV PYTHONPATH=.:/home/user/app/src:$PYTHONPATH
RUN ls -al
RUN python --version
#CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]

CMD ["python", "./src/main.py", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: pink
5
  colorTo: purple
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
5
  colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
logging_config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 1
2
+ formatters:
3
+ simple:
4
+ format: '%(asctime)s - %(filename)s - %(funcName)10s() - %(lineno)d - %(levelname)s - %(message)s'
5
+ handlers:
6
+ console:
7
+ class: logging.StreamHandler
8
+ level: DEBUG
9
+ formatter: simple
10
+ stream: ext://sys.stdout
11
+ logfile:
12
+ class: logging.FileHandler
13
+ level: DEBUG
14
+ formatter: simple
15
+ filename: logger.log
16
+ encoding: utf8
17
+ loggers: #These are for individual files and then finally for root
18
+ objectiveHandler:
19
+ level: DEBUG
20
+ handlers: [console,logfile]
21
+ propagate: no
22
+ taskExecutor:
23
+ level: DEBUG
24
+ handlers: [console,logfile]
25
+ propagate: no
26
+ taskPlanner:
27
+ level: DEBUG
28
+ handlers: [console,logfile]
29
+ propagate: no
30
+ connectionpool:
31
+ level: ERROR
32
+ handlers: [console,logfile]
33
+ propagate: no
34
+ root:
35
+ level: DEBUG
36
+ handlers: [console,logfile]
37
+
main.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""FastAPI entry point exposing the vector-store document-matching endpoint."""
from datetime import datetime
import logging
import fastapi
from fastapi import Body, Depends
import uvicorn
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# BUG FIX: the original did `from fastapi import FastAPI as Response`, which
# bound the *application class* to the name Response and made the middleware's
# return annotation nonsense. Import the real Response type instead.
from fastapi.responses import Response
from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.requests import Request
from pydantic import BaseModel
from enum import Enum
from typing import List, Dict, Any, Generator, Optional, cast, Callable
from indexer import *


async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """Top-level safety net: convert any unhandled exception into a JSON 500.

    Broad `except Exception` is deliberate here — this is the outermost
    boundary of the HTTP stack and must never let the server crash.
    """
    try:
        return await call_next(request)
    except Exception as e:
        return JSONResponse(content={"error": repr(e)}, status_code=500)


app = fastapi.FastAPI(title="Maya Persistet")
app.middleware("http")(catch_exceptions_middleware)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
api_base = "/api/v1"


@app.post(api_base + "/getMatchingDocs")
async def get_matching_docs(inStr: str) -> Any:
    """Return documents from the vector store relevant to *inStr*."""
    return getRelevantDocs(inStr)


print(__name__)

# Started either directly (script) or as `src.main` inside the Docker image.
if __name__ == '__main__' or __name__ == "src.main":
    uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ uvicorn
2
+ bs4
3
+ lxml
4
+ fastapi
5
+ loguru
6
+ chromadb
7
+ langchain
8
+ sentence_transformers
9
+ InstructorEmbedding
10
+ sse_starlette
11
+ dropbox
12
+ firebase
13
+ together
14
+ python-dotenv
15
+ pytest
16
+ pytest-env
17
+ pydantic==1.10.9
18
+ openapi-schema-pydantic
19
+ markdown
20
+ google-generativeai
src/baseInfra/dbInterface.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from firebase import firebase
import os, json
import baseInfra.firebase_handler as fbh

# Module-level Firebase connection shared by every DbInterface instance.
secret = os.environ['FIREBASE_TOKEN']
user_id = "131251"
auth = firebase.FirebaseAuthentication(secret=secret, email="anubhav77@gmail.com", extra={"id": user_id})
fb_url = "https://device-1a455.firebaseio.com"
fb = firebase.FirebaseApplication(fb_url, auth)
base_dir = "/users/" + user_id
class DbInterface:
    """
    Thin wrapper around the shared module-level Firebase connection.

    Provides config lookup (`get_config`) and a simple input->value cache
    (`get_matching_cache` / `add_to_cache`) stored under the user's node.
    """

    def __init__(self):
        self.base_dir = base_dir    # user root node in Firebase
        self.conf_loc = "config"    # sub-node holding executor/planner/tool configs
        self.cache_loc = "vs_cache" # sub-node holding the vector-store cache

    def _set_base_dir(self, sub_dir):
        """
        This internal method is primarily for unit testcases.
        Adds extra subdir to the path and returns the resulting base dir.
        """
        if sub_dir.startswith("/"):
            self.base_dir = base_dir + sub_dir
        else:
            self.base_dir = base_dir + "/" + sub_dir
        if self.base_dir.endswith("/"):
            self.base_dir = self.base_dir[:-1]
        return self.base_dir

    def get_config(self, path):
        """
        Gets the config from the database given path.
        This is used for getting configs of executor, planner or tools.

        Returns:
            The json config object, or {} when the path does not exist.
        """
        print(self.base_dir + "/" + self.conf_loc)
        print(path)
        config = fb.get(self.base_dir + "/" + self.conf_loc, path)
        if config is None:
            return {}
        return config

    def get_matching_cache(self, input):
        """
        Gets the matching cache entry from the database for *input*.
        This is used for getting similar documents from the vector store.

        Returns:
            The cached value list, or [] when nothing matches.
        """
        fb_key = fbh.convert_to_firebase_key(input)
        cache = fb.get(self.base_dir + "/" + self.cache_loc, fb_key)
        if cache is None:
            return []
        # cache is a list of dicts {'input': inStr, 'value': cached_items_list}
        for item in cache:
            if item['input'] == input:
                return item['value']
        return []

    def add_to_cache(self, input, value):
        """
        Adds (or updates) the cached *value* for *input*.

        Returns:
            True if an existing entry was updated, False if a new one was added.
        """
        retVal = False
        fb_key = fbh.convert_to_firebase_key(input)
        cache = fb.get(self.base_dir + "/" + self.cache_loc, fb_key)
        if cache is None:
            cache = []
        else:
            for item in cache:  # list of dicts {'input':..., 'value':...}
                if item['input'] == input:
                    item['value'] = value
                    retVal = True
        # BUG FIX: the original appended unconditionally, so updating an
        # existing entry also appended a duplicate record to the cache list.
        if not retVal:
            cache.append({'input': input, 'value': value})
        fb.patch(self.base_dir + "/" + self.cache_loc, {fb_key: cache})
        return retVal
src/baseInfra/dropbox_handler.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import datetime
import time
import dropbox
from dropbox.files import WriteMode
from dropbox.exceptions import ApiError, AuthError
import sys, os
import baseInfra.firebase_handler as fbh

# Dropbox OAuth material: tokens are persisted in Firebase, the app key and
# secret come from the environment.
TOKEN = fbh.fb_get("d2_accesstoken")
APP_KEY = os.environ['DROPBOX_APP_KEY']
APP_SECRET = os.environ['DROPBOX_APP_SECRET']
REFRESH_TOKEN = fbh.fb_get("d2_refreshtoken")

#os.environ['DROP_DIR2']="C:/dockers/chroma/chroma1/"
#os.environ['APP_PATH']="/"
#print("token::",TOKEN)

# Validate the stored access token at import time; if it has expired, refresh
# it and write the new token back to Firebase.
# NOTE(review): the functions below keep using `dbx` *after* this with-block
# has exited (and closed the client) — confirm this is intended, or keep the
# client open for the module's lifetime instead.
with dropbox.Dropbox(oauth2_access_token=TOKEN, app_key=APP_KEY, app_secret=APP_SECRET, oauth2_refresh_token=REFRESH_TOKEN) as dbx:
    try:
        dbx.users_get_current_account()
        if (TOKEN != dbx._oauth2_access_token):
            fbh.fb_update("d2_accesstoken", dbx._oauth2_access_token)
            TOKEN = dbx._oauth2_access_token
        print("dropbox connection ok,", dbx._oauth2_access_token)
        print(dbx._oauth2_refresh_token)
    except AuthError:
        try:
            dbx.check_and_refresh_access_token()
            fbh.fb_update("d2_accesstoken", dbx._oauth2_access_token)
            print("dropbox connection refreshed and updated", dbx._oauth2_access_token)
            print(dbx._oauth2_refresh_token)
        except Exception:
            sys.exit("ERROR: Invalid access token; try re-generating an "
                     "access token from the app console on the web.")
def normalizeFilename(filename):
    """Collapse every run of repeated '/' separators in *filename* to one."""
    collapsed = filename.replace('//', '/')
    while collapsed != filename:
        filename = collapsed
        collapsed = filename.replace('//', '/')
    return filename
def getDropboxFilename(localFilename):
    """Map a local path ($DROP_DIR2/<subpath>/<file>) to its Dropbox
    counterpart ($APP_PATH/<subpath>/<file>)."""
    localFilename = normalizeFilename(localFilename)
    mapped = localFilename.replace(os.environ['DROP_DIR2'], "/", 1)
    mapped = mapped.replace("/", os.environ['APP_PATH'], 1)
    return normalizeFilename(mapped)


def getLocalFilename(dropboxFilename):
    """Map a Dropbox path ($APP_PATH/<subpath>/<file>) back to its local
    counterpart ($DROP_DIR2/<subpath>/<file>)."""
    dropboxFilename = normalizeFilename(dropboxFilename)
    mapped = dropboxFilename.replace(os.environ['APP_PATH'], "/", 1)
    mapped = mapped.replace("/", os.environ['DROP_DIR2'], 1)
    return normalizeFilename(mapped)
def backupFile(localFilename):
    """Upload a single file to Dropbox, overwriting any existing copy.

    Dropbox creates intermediate directories as needed.
    Returns the upload response, or None when the API call fails.
    """
    global TOKEN
    localFilename = normalizeFilename(localFilename)
    dropboxFilename = getDropboxFilename(localFilename)
    print("backing file ", localFilename, " to ", dropboxFilename)
    mode = dropbox.files.WriteMode.overwrite
    mtime = os.path.getmtime(localFilename)
    with open(localFilename, 'rb') as f:
        data = f.read()
    try:
        res = dbx.files_upload(
            data, dropboxFilename, mode,
            client_modified=datetime.datetime(*time.gmtime(mtime)[:6]),
            mute=True)
        # Persist a rotated access token back to Firebase if the SDK refreshed it.
        if (TOKEN != dbx._oauth2_access_token):
            fbh.fb_update("d2_accesstoken", dbx._oauth2_access_token)
            TOKEN = dbx._oauth2_access_token
            print(dbx._oauth2_refresh_token)
    except dropbox.exceptions.ApiError as err:
        print('*** API error', err)
        return None
    print('uploaded as', res.name.encode('utf8'))
    return res
def restoreFile(dropboxFilename):
    """Download a file from Dropbox to its mapped local path.

    Creates the local directory tree when missing.
    Returns the bytes of the file, or None when the download fails
    (e.g. the path does not exist on Dropbox).
    """
    global TOKEN
    dropboxFilename = normalizeFilename(dropboxFilename)
    localFilename = getLocalFilename(dropboxFilename)
    print("restoring file ", localFilename, " from ", dropboxFilename)
    try:
        md, res = dbx.files_download(dropboxFilename)
        # Persist a rotated access token back to Firebase if the SDK refreshed it.
        if (TOKEN != dbx._oauth2_access_token):
            fbh.fb_update("d2_accesstoken", dbx._oauth2_access_token)
            TOKEN = dbx._oauth2_access_token
            print(dbx._oauth2_refresh_token)
    except dropbox.exceptions.ApiError as err:
        # BUG FIX: a missing path makes files_download raise ApiError, which
        # the original never caught — the documented "None if it doesn't
        # exist" behavior could not happen.
        print('*** API error', err)
        return None
    except dropbox.exceptions.HttpError as err:
        print('*** HTTP error', err)
        return None
    data = res.content
    print(len(data), 'bytes; md:', md)
    localdir = os.path.dirname(localFilename)
    if not os.path.exists(localdir):
        os.makedirs(localdir)
    with open(localFilename, 'wb') as f:
        f.write(data)
    return data
def backupFolder(localFolder):
    """Recursively upload every file under *localFolder* to Dropbox."""
    print("backup folder called for ", localFolder)
    if not localFolder.startswith(os.environ['DROP_DIR2']):
        localFolder = os.environ['DROP_DIR2'] + localFolder
    for root, dirs, files in os.walk(localFolder, topdown=True):
        print(root)
        for filename in files:
            path = root + "/" + filename
            print(path)
            backupFile(path)
def restoreFolder(dropboxFolder):
    """Recursively download every file under *dropboxFolder* from Dropbox.

    A folder that cannot be listed is treated as empty; a sub-folder that
    fails to restore is reported but does not abort the rest.
    """
    global TOKEN
    if not dropboxFolder.startswith(os.environ['APP_PATH']):
        dropboxFolder = os.environ['APP_PATH'] + dropboxFolder
    try:
        res = dbx.files_list_folder(dropboxFolder)
        # Persist a rotated access token back to Firebase if the SDK refreshed it.
        if (TOKEN != dbx._oauth2_access_token):
            fbh.fb_update("d2_accesstoken", dbx._oauth2_access_token)
            TOKEN = dbx._oauth2_access_token
            print(dbx._oauth2_refresh_token)
    except dropbox.exceptions.ApiError as err:
        print('Folder listing failed for', dropboxFolder, '-- assumed empty:', err)
        return
    except dropbox.exceptions.AuthError as err1:
        print('Folder listing failed for', dropboxFolder, '-- assumed empty:', err1)
        return
    for entry in res.entries:
        if isinstance(entry, dropbox.files.FileMetadata):
            restoreFile(entry.path_display)
        else:
            try:
                restoreFolder(entry.path_display)
            except Exception:
                print("Error restoring folder,", entry.path_display)
            print(entry.path_display)
src/baseInfra/firebase_handler.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from dotenv import load_dotenv
load_dotenv()  # must run before os.environ is read below
from firebase import firebase
import os

# Firebase connection used purely for Dropbox token management.
secret = os.environ['FIREBASE_TOKEN']
user_id = "db" #this file is for dropbox token management through firebase only
auth = firebase.FirebaseAuthentication(secret=secret, email="anubhav77@gmail.com", extra={"id": user_id})
fb_url = "https://device-1a455.firebaseio.com"
fb = firebase.FirebaseApplication(fb_url, auth)
base_dir = "/users/" + user_id
def fb_get(item):
    """Read *item* from the user's Firebase node.

    *item* follows the 'd2_refreshtoken' naming used in Firebase; 'd2'
    means the second dropbox instance or app.
    """
    return fb.get(base_dir, item)
+
19
+ def fb_update(item,value):
20
+ """ item need to be in format d2_refreshtoken same name will be present in the firebase
21
+ d2 means second dropbox instance or app.
22
+ """
23
+ return fb.patch(base_dir,{item:value})
import re


def convert_to_firebase_key(inStr):
    """Converts a string to a Firebase key.

    Args:
        inStr: The string to convert.

    Returns:
        A sanitized key. Inputs that are not already valid Firebase keys,
        or that collide with a reserved Firebase name, come back with an
        'fb_' prefix; over-long keys are truncated with an 'fbx_' prefix.
    """
    # Firebase keys must be between 3 and 128 characters long, start with a
    # letter, number, or underscore, and may contain only letters, numbers,
    # underscores, hyphens, and periods — no spaces or other specials.

    # Normalize: lowercase, spaces -> underscores, trim edge underscores,
    # then drop every character outside [A-Za-z0-9_\-.].
    outStr = re.sub(r'[^\w\-.]', '', inStr.lower().replace(' ', '_').strip('_'))

    firebase_key_regex = re.compile(r'^[a-zA-Z0-9_][a-zA-Z0-9-_.]*$')
    # Names Firebase reserves for itself — such inputs get the fb_ prefix.
    reserved_names = ['firebase', 'google', 'android', 'ios', 'web', 'console', 'auth', 'database', 'storage', 'hosting', 'functions', 'firestore', 'realtime-database', 'remote-config', 'analytics', 'crashlytics', 'performance-monitoring', 'test-lab', 'cloud-messaging', 'dynamic-links', 'identity-toolkit', 'cloud-functions', 'cloud-firestore', 'cloud-storage', 'cloud-hosting', 'cloud-functions-v2', 'cloud-firestore-v2', 'cloud-storage-v2', 'cloud-hosting-v2', 'cloud-functions-v3', 'cloud-firestore-v3', 'cloud-storage-v3', 'cloud-hosting-v3']

    # NOTE: validity and the reserved-name check are applied to the *raw
    # input* (not the sanitized form), matching the original contract.
    if not firebase_key_regex.match(inStr) or inStr in reserved_names:
        return "fb_" + outStr

    if len(outStr) > 120:
        outStr = "fbx_" + outStr[:120]

    return outStr
src/indexer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Builds the demo wine vector store and the self-querying retriever."""
from langchain.vectorstores import Chroma
from chromadb.api.fastapi import requests
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo
from llm.llmFactory import LLMFactory

model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}  # set True to compute cosine similarity

embedding = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs={'device': 'cpu'},
    encode_kwargs=encode_kwargs
)

persist_directory = 'db'

# Small demo corpus indexed at import time.
docs = [
    Document(
        page_content="Complex, layered, rich red with dark fruit flavors",
        metadata={"name": "Opus One", "year": 2018, "rating": 96, "grape": "Cabernet Sauvignon", "color": "red", "country": "USA"},
    ),
    Document(
        page_content="Luxurious, sweet wine with flavors of honey, apricot, and peach",
        metadata={"name": "Château d'Yquem", "year": 2015, "rating": 98, "grape": "Sémillon", "color": "white", "country": "France"},
    ),
    Document(
        page_content="Full-bodied red with notes of black fruit and spice",
        metadata={"name": "Penfolds Grange", "year": 2017, "rating": 97, "grape": "Shiraz", "color": "red", "country": "Australia"},
    ),
    Document(
        page_content="Elegant, balanced red with herbal and berry nuances",
        metadata={"name": "Sassicaia", "year": 2016, "rating": 95, "grape": "Cabernet Franc", "color": "red", "country": "Italy"},
    ),
    Document(
        page_content="Highly sought-after Pinot Noir with red fruit and earthy notes",
        metadata={"name": "Domaine de la Romanée-Conti", "year": 2018, "rating": 100, "grape": "Pinot Noir", "color": "red", "country": "France"},
    ),
    Document(
        page_content="Crisp white with tropical fruit and citrus flavors",
        metadata={"name": "Cloudy Bay", "year": 2021, "rating": 92, "grape": "Sauvignon Blanc", "color": "white", "country": "New Zealand"},
    ),
    Document(
        page_content="Rich, complex Champagne with notes of brioche and citrus",
        # DATA FIX: Krug is a Champagne house; the original metadata said
        # "New Zealand", contradicting the page_content.
        metadata={"name": "Krug Grande Cuvée", "year": 2010, "rating": 93, "grape": "Chardonnay blend", "color": "sparkling", "country": "France"},
    ),
    Document(
        page_content="Intense, dark fruit flavors with hints of chocolate",
        metadata={"name": "Caymus Special Selection", "year": 2018, "rating": 96, "grape": "Cabernet Sauvignon", "color": "red", "country": "USA"},
    ),
    Document(
        page_content="Exotic, aromatic white with stone fruit and floral notes",
        metadata={"name": "Jermann Vintage Tunina", "year": 2020, "rating": 91, "grape": "Sauvignon Blanc blend", "color": "white", "country": "Italy"},
    ),
]

vectorstore = Chroma.from_documents(documents=docs,
                                    embedding=embedding,
                                    persist_directory=persist_directory)

# Schema description the self-query retriever uses to build metadata filters.
metadata_field_info = [
    AttributeInfo(
        name="grape",
        description="The grape used to make the wine",
        type="string or list[string]",
    ),
    AttributeInfo(
        name="name",
        description="The name of the wine",
        type="string or list[string]",
    ),
    AttributeInfo(
        name="color",
        description="The color of the wine",
        type="string or list[string]",
    ),
    AttributeInfo(
        name="year",
        description="The year the wine was released",
        type="integer",
    ),
    AttributeInfo(
        name="country",
        description="The name of the country the wine comes from",
        type="string",
    ),
    AttributeInfo(
        name="rating", description="The Robert Parker rating for the wine 0-100", type="integer"  # float
    ),
]
document_content_description = "Brief description of the wine"
lf = LLMFactory()
llm = lf.get_llm("executor2")

retriever = SelfQueryRetriever.from_llm(
    llm,
    vectorstore,
    document_content_description,
    metadata_field_info,
    verbose=True
)


def getRelevantDocs(query: str):
    """Return documents relevant to *query* via the self-query retriever."""
    return retriever.get_relevant_documents(query)
src/llm/hostedLLM.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import Any, List, Mapping, Optional, Dict
from pydantic import Extra, Field  # , root_validator, model_validator
import os, json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import requests


class HostedLLM(LLM):
    """
    Hosted LLMs in huggingface spaces with fastAPI. Interface is primarily
    a rest call with huggingface token.

    Attributes:
        url: the url of the endpoint
        http_method: which http method to invoke ("get" or "post")
        model_name: which model is being hosted
        temperature: temperature between 0 and 1
        max_tokens: amount of output to generate, 512 by default
        api_token: bearer token; defaults to the HUGGINGFACE_API env variable
        verbose: for extra logging
    """
    url: str = ""
    http_method: Optional[str] = "post"
    model_name: Optional[str] = "bard"
    # BUG FIX: the original used os.environ["HUGGINGFACE_API"], which raises
    # KeyError at import time when the variable is unset; .get keeps the
    # module importable (token can still be passed explicitly).
    api_token: Optional[str] = os.environ.get("HUGGINGFACE_API", "")
    temperature: float = 0.7
    max_tokens: int = 512
    verbose: Optional[bool] = False

    class Config:
        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """POST (or GET) the prompt to the hosted endpoint and return
        the 'content' field of its JSON response."""
        if run_manager:
            run_manager.on_text([prompt])
        payload = {"prompt": prompt}
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        if (self.http_method == "post"):
            response = requests.post(self.url, json=payload, headers=headers)
        else:
            response = requests.get(self.url, json=payload, headers=headers)
        val = json.loads(response.text)['content']
        if run_manager:
            run_manager.on_llm_end(val)
        return val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "hosted"}

    def extractJson(self, val: str) -> Any:
        """Helper function to extract json from this LLM's output.

        First tries the raw output; falls back to the first ``` fenced
        block (the model sometimes sends the json directly).
        """
        # BUG FIX: narrowed the original bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) to the decode errors actually expected.
        try:
            v3 = val.replace("\n", "").replace("\r", "")
            v4 = json.loads(v3)
        except (ValueError, TypeError):
            v2 = val.replace("```json", "```").split("```")[1]
            v3 = v2.replace("\n", "").replace("\r", "")
            v4 = json.loads(v3)
        return v4

    def extractPython(self, val: str) -> Any:
        """Helper function to extract python from this LLM's output."""
        # This is assuming the python is the first item within ```
        v2 = val.replace("```python", "```").split("```")[1]
        return v2
src/llm/llmFactory.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging
from baseInfra.dbInterface import DbInterface
from llm.hostedLLM import HostedLLM
from llm.togetherLLM import TogetherLLM
from llm.palmLLM import PalmLLM


class LLMFactory:
    """
    Factory class for creating LLM objects from configs stored in the DB.
    """

    def __init__(self):
        """Create the factory with its own DbInterface for config lookups."""
        self._db_interface = DbInterface()

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type configured at *llm_path*.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            The LLM object.

        Raises:
            ValueError: when the config is missing or names an unknown type.
        """
        logger = logging.getLogger(__name__)
        llm_type = ""
        llm_config = {}
        try:
            config = self._db_interface.get_config(llm_path)
            logger.debug(llm_path)
            logger.debug(config)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Config fetch failures fall through to the unknown-type error below.
            logger.exception("Exception in getLLM")
            logger.exception(ex)

        # Dispatch table instead of an if/elif chain.
        constructors = {
            "hostedLLM": HostedLLM,
            "togetherLLM": TogetherLLM,
            "palmLLM": PalmLLM,
        }
        if llm_type in constructors:
            return constructors[llm_type](**llm_config)
        logger.error(f"Invalid LLM type: {llm_type}")
        raise ValueError(f"Invalid LLM type: {llm_type}")
src/llm/palmLLM.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import Any, List, Mapping, Optional, Dict
from pydantic import Extra, Field  # , root_validator, model_validator
import os, json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import google.generativeai as palm
from google.generativeai import types
import ast
# from langchain.llms import GooglePalm
import requests


class PalmLLM(LLM):
    """LangChain wrapper around Google PaLM text generation.

    Remembers the last prompt/stop/run_manager so extractJson can re-ask
    the model to finish a truncated fenced-JSON response.
    """

    model_name: str = "text-bison-001"
    temperature: float = 0
    max_tokens: int = 2048
    stop: Optional[List] = []
    prev_prompt: Optional[str] = ""   # last prompt, kept for continuation calls
    prev_stop: Optional[str] = ""
    prev_run_manager: Optional[Any] = None

    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        palm.configure()
        # self.model = palm.Text2Text(self.model_name)

    @property
    def _llm_type(self) -> str:
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate text for *prompt*, with all safety categories relaxed."""
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager
        if stop is None:  # fixed: was `== None`
            stop = self.stop
        # BUG FIX: the original passed stop_sequences=self.stop, silently
        # ignoring any caller-supplied stop list; use the resolved local.
        text = palm.generate_text(prompt=prompt, stop_sequences=stop,
                                  temperature=self.temperature, max_output_tokens=self.max_tokens,
                                  safety_settings=[{"category": 0, "threshold": 4},
                                                   {"category": 1, "threshold": 4},
                                                   {"category": 2, "threshold": 4},
                                                   {"category": 3, "threshold": 4},
                                                   {"category": 4, "threshold": 4},
                                                   {"category": 5, "threshold": 4},
                                                   {"category": 6, "threshold": 4}]
                                  )
        print("Response from palm", text)
        val = text.result
        if run_manager:
            run_manager.on_llm_end(val)
        return val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    def extractJson(self, val: str) -> Any:
        """Helper function to extract json from this LLM's output.

        PaLM usually fences JSON with ```json ... ```; when the closing fence
        is missing the response was truncated, so generation is re-invoked
        with the previous prompt plus the partial result (up to 7 retries).
        """
        try:
            count = 0
            while val.startswith("```json") and not val.endswith("```") and count < 7:
                val = self._call(prompt=self.prev_prompt + " " + val, stop=self.prev_stop, run_manager=self.prev_run_manager)
                count += 1
            v2 = val.replace("```json", "```").split("```")[1]
            # literal_eval tolerates single-quoted pseudo-JSON, then round-trip
            # through json to get real JSON semantics.
            v3 = json.dumps(ast.literal_eval(v2))
            v4 = json.loads(v3)
        except Exception:
            # BUG FIX: was a bare `except:` (also caught KeyboardInterrupt etc.).
            # Fallback: treat the whole output as a Python/JSON literal.
            v3 = json.dumps(ast.literal_eval(val))
            v4 = json.loads(v3)
        return v4

    def extractPython(self, val: str) -> Any:
        """Helper function to extract python from this LLM's output."""
        # This is assuming the python is the first item within ```
        v2 = val.replace("```python", "```").split("```")[1]
        return v2
src/llm/togetherLLM.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import together
import os
import logging, json
from typing import Any, Dict, List, Mapping, Optional

from pydantic import Extra, Field  # , root_validator, model_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env


class TogetherLLM(LLM):
    """Together large language models."""

    model_name: str = "togethercomputer/llama-2-70b-chat"
    """model endpoint to use"""

    # BUG FIX: the original used os.environ["TOGETHER_API_KEY"], which raises
    # KeyError at import time when the variable is unset; .get keeps the
    # module importable (the key can still be passed to the constructor).
    together_api_key: str = os.environ.get("TOGETHER_API_KEY", "")
    """Together API key"""

    temperature: float = 0.7
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    class Config:
        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call the Together completion endpoint and return the first choice."""
        together.api_key = self.together_api_key
        output = together.Complete.create(prompt,
                                          model=self.model_name,
                                          max_tokens=self.max_tokens,
                                          temperature=self.temperature,
                                          )
        text = output['output']['choices'][0]['text']
        return text

    def extractJson(self, val: str) -> Any:
        """Helper function to extract json from this LLM's output."""
        # This is assuming the json is the first item within ```
        v2 = val.replace("```json", "```").split("```")[1]
        v3 = v2.replace("\n", "").replace("\r", "")
        v4 = json.loads(v3)
        return v4

    def extractPython(self, val: str) -> Any:
        """Helper function to extract python from this LLM's output."""
        # This is assuming the python is the first item within ```
        v2 = val.replace("```python", "```").split("```")[1]
        return v2