v0.1
- .gitignore +16 -0
- Dockerfile +25 -0
- README.md +1 -0
- logging_config.yaml +37 -0
- main.py +53 -0
- requirements.txt +20 -0
- src/baseInfra/dbInterface.py +97 -0
- src/baseInfra/dropbox_handler.py +152 -0
- src/baseInfra/firebase_handler.py +97 -0
- src/indexer.py +106 -0
- src/llm/hostedLLM.py +92 -0
- src/llm/llmFactory.py +55 -0
- src/llm/palmLLM.py +93 -0
- src/llm/togetherLLM.py +72 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
+src/.env
+pytest.ini
+logger.log
+.pytest_cache/**
+**/__pycache__/**
+.vscode/**
+**/.env
+src/baseInfra/*.env
+**/*.env*
+**/*.pyc
+src/baseInfra/**.env**
+src/baseInfra/__pycache__/*.pyc
+src/*.old.py
+src/toolSelector.py
+test/__pycache__/*.pyc
+src/baseInfra/memoryManager.py
Dockerfile
ADDED
@@ -0,0 +1,25 @@
+FROM python:3.11.0
+
+ENV PYTHONUNBUFFERED 1
+
+EXPOSE 8000
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+ENV TZ="Asia/Kolkata"
+WORKDIR $HOME/app
+
+COPY requirements.txt ./
+RUN pip install --upgrade pip && \
+    pip install -r requirements.txt
+
+
+COPY --chown=user . $HOME/app
+ENV PYTHONPATH=.:/home/user/app/src:$PYTHONPATH
+RUN ls -al
+RUN python --version
+#CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
+
+CMD ["python", "./src/main.py", "--host", "0.0.0.0", "--port", "8000"]
README.md
CHANGED
@@ -5,6 +5,7 @@ colorFrom: pink
 colorTo: purple
 sdk: docker
 pinned: false
+app_port: 8000
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
logging_config.yaml
ADDED
@@ -0,0 +1,37 @@
+version: 1
+formatters:
+  simple:
+    format: '%(asctime)s - %(filename)s - %(funcName)10s() - %(lineno)d - %(levelname)s - %(message)s'
+handlers:
+  console:
+    class: logging.StreamHandler
+    level: DEBUG
+    formatter: simple
+    stream: ext://sys.stdout
+  logfile:
+    class: logging.FileHandler
+    level: DEBUG
+    formatter: simple
+    filename: logger.log
+    encoding: utf8
+loggers: #These are for individual files and then finally for root
+  objectiveHandler:
+    level: DEBUG
+    handlers: [console,logfile]
+    propagate: no
+  taskExecutor:
+    level: DEBUG
+    handlers: [console,logfile]
+    propagate: no
+  taskPlanner:
+    level: DEBUG
+    handlers: [console,logfile]
+    propagate: no
+  connectionpool:
+    level: ERROR
+    handlers: [console,logfile]
+    propagate: no
+root:
+  level: DEBUG
+  handlers: [console,logfile]
+
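This file follows the standard logging.config dictConfig schema. A minimal sketch of loading it at startup (assuming the file sits in the working directory; the repo's entry point may wire it up differently):

import logging.config
import yaml  # PyYAML is not pinned in requirements.txt; assumed available via langchain's dependencies

with open("logging_config.yaml") as f:
    logging.config.dictConfig(yaml.safe_load(f))

logging.getLogger("taskPlanner").debug("logging configured")  # goes to both console and logger.log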
main.py
ADDED
@@ -0,0 +1,53 @@
+from datetime import datetime
+import logging
+import fastapi
+from fastapi import Body, Depends
+import uvicorn
+from fastapi import HTTPException, status
+from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi import Response
+from sse_starlette.sse import EventSourceResponse
+from starlette.responses import StreamingResponse
+from starlette.requests import Request
+from pydantic import BaseModel
+from enum import Enum
+from typing import List, Dict, Any, Generator, Optional, cast, Callable
+from indexer import *
+
+
+
+async def catch_exceptions_middleware(
+    request: Request, call_next: Callable[[Request], Any]
+) -> Response:
+    try:
+        return await call_next(request)
+    except Exception as e:
+        return JSONResponse(content={"error": repr(e)}, status_code=500)
+
+
+
+app = fastapi.FastAPI(title="Maya Persistet")
+app.middleware("http")(catch_exceptions_middleware)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+api_base="/api/v1"
+
+@app.post(api_base+"/getMatchingDocs")
+async def get_matching_docs(inStr: str) -> Any:
+    """Return documents from the vector store that are relevant to inStr.
+    """
+    return getRelevantDocs(inStr)
+
+
+
+
+print(__name__)
+
+if __name__ == '__main__' or __name__ == "src.main":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
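Because inStr is a plain (non-Body) parameter, FastAPI exposes it as a query parameter on the POST route. A minimal client call, assuming the service is reachable on localhost:8000:

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/getMatchingDocs",
    params={"inStr": "a full-bodied red with dark fruit flavors"},  # hypothetical query
)
print(resp.status_code, resp.json())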
requirements.txt
ADDED
@@ -0,0 +1,20 @@
+uvicorn
+bs4
+lxml
+fastapi
+loguru
+chromadb
+langchain
+sentence_transformers
+InstructorEmbedding
+sse_starlette
+dropbox
+firebase
+together
+python-dotenv
+pytest
+pytest-env
+pydantic==1.10.9
+openapi-schema-pydantic
+markdown
+google-generativeai
src/baseInfra/dbInterface.py
ADDED
@@ -0,0 +1,97 @@
+from firebase import firebase
+import os,json
+import baseInfra.firebase_handler as fbh
+
+secret=os.environ['FIREBASE_TOKEN']
+user_id="131251"
+auth=firebase.FirebaseAuthentication(secret=secret,email="anubhav77@gmail.com",extra={"id":user_id})
+fb_url="https://device-1a455.firebaseio.com"
+fb= firebase.FirebaseApplication(fb_url,auth)
+base_dir="/users/"+user_id
+
+class DbInterface:
+    """
+    A thin interface over the Firebase Realtime Database.
+
+    It exposes `get_config()` for reading configuration documents and `get_matching_cache()` / `add_to_cache()` for the vector-store cache.
+    """
+    def __init__(self):
+        self.base_dir=base_dir
+        self.conf_loc="config"
+        self.cache_loc="vs_cache"
+
+    def _set_base_dir(self,sub_dir):
+        """
+        This internal method is primarily for unit testcases.
+        Adds an extra subdir to the path.
+        """
+        if sub_dir.startswith("/"):
+            self.base_dir=base_dir+sub_dir
+        else:
+            self.base_dir=base_dir+"/"+sub_dir
+        if self.base_dir.endswith("/"):
+            self.base_dir=self.base_dir[:-1]
+        return self.base_dir
+
+
+    def get_config(self,path):
+        """
+        Gets the config from the database for the given path.
+        This is used for getting the configs of the executor, planner or tools.
+
+        Returns:
+            The json config object.
+
+        """
+        print(self.base_dir+"/"+self.conf_loc)
+        print(path)
+        config = fb.get(self.base_dir+"/"+self.conf_loc, path)
+        if config is None:
+            return {}
+        return config
+
+    def get_matching_cache(self,input):
+        """
+        Gets the matching cache entry from the database for the given input.
+        This is used for getting similar documents from the vector store.
+
+        Returns:
+            The json cache object.
+
+        """
+        fb_key=fbh.convert_to_firebase_key(input)
+        cache = fb.get(self.base_dir+"/"+self.cache_loc, fb_key)
+        if cache is None:
+            return []
+        else:
+            for item in cache: #ie cache is a list of dicts {'input':inStr,'value':cached_items_list}
+                if item['input'] == input:
+                    return item['value']
+
+        return []
+
+    def add_to_cache(self,input,value):
+        """
+        Adds the input and value to the cache.
+        This is used for adding documents to the vector store cache.
+
+        Returns:
+            True if an existing cache item was updated, False if a new item was added.
+
+        """
+        retVal=False
+        fb_key=fbh.convert_to_firebase_key(input)
+        cache = fb.get(self.base_dir+"/"+self.cache_loc, fb_key)
+        if cache is None:
+            cache = []
+        else:
+            for item in cache: #ie cache is a list of dicts {'input':inStr,'value':cached_items_list}
+                if item['input'] == input:
+                    item['value']=value
+                    retVal=True
+        cache.append({'input':input,'value':value})
+        fb.patch(self.base_dir+"/"+self.cache_loc, {fb_key:cache})
+        return retVal
+
+
+
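Each vs_cache node, keyed by the sanitised input, holds a list of {'input': ..., 'value': ...} dicts. A sketch of the intended round trip (assumes FIREBASE_TOKEN is set; the cached payload below is hypothetical):

from baseInfra.dbInterface import DbInterface

db = DbInterface()
query = "red wines from the USA"
hits = db.get_matching_cache(query)              # [] on a cache miss
if not hits:
    hits = [{"name": "Opus One", "rating": 96}]  # hypothetical cached value
    db.add_to_cache(query, hits)                 # returns False when a new entry is appended
print(db.get_matching_cache(query))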
src/baseInfra/dropbox_handler.py
ADDED
@@ -0,0 +1,152 @@
+import datetime
+import time
+import dropbox
+from dropbox.files import WriteMode
+from dropbox.exceptions import ApiError, AuthError
+import sys,os
+import baseInfra.firebase_handler as fbh
+
+TOKEN=fbh.fb_get("d2_accesstoken")
+APP_KEY=os.environ['DROPBOX_APP_KEY']
+APP_SECRET=os.environ['DROPBOX_APP_SECRET']
+REFRESH_TOKEN=fbh.fb_get("d2_refreshtoken")
+
+#os.environ['DROP_DIR2']="C:/dockers/chroma/chroma1/"
+#os.environ['APP_PATH']="/"
+#print("token::",TOKEN)
+
+with dropbox.Dropbox(oauth2_access_token=TOKEN,app_key=APP_KEY,app_secret=APP_SECRET,oauth2_refresh_token=REFRESH_TOKEN) as dbx:
+    # Check that the access token is valid
+    try:
+        dbx.users_get_current_account()
+        if (TOKEN != dbx._oauth2_access_token):
+            fbh.fb_update("d2_accesstoken",dbx._oauth2_access_token)
+            TOKEN=dbx._oauth2_access_token
+        print("dropbox connection ok,",dbx._oauth2_access_token)
+        print(dbx._oauth2_refresh_token)
+    except AuthError:
+        try:
+            dbx.check_and_refresh_access_token()
+            fbh.fb_update("d2_accesstoken",dbx._oauth2_access_token)
+            print("dropbox connection refreshed and updated",dbx._oauth2_access_token)
+            print(dbx._oauth2_refresh_token)
+        except Exception:
+            sys.exit("ERROR: Invalid access token; try re-generating an "
+                     "access token from the app console on the web.")
+
+def normalizeFilename(filename):
+    while '//' in filename:
+        filename = filename.replace('//', '/')
+    return filename
+
+def getDropboxFilename(localFilename):
+    """ localFilename is $DROP_DIR2/<subpath>/<filename> """
+    """ dropboxFilename is $APP_PATH/<subpath>/<filename> """
+    #if not localFilename.startswith(os.environ['DROP_DIR2']):
+    #    localFilename=os.environ['DROP_DIR2']+localFilename
+    localFilename=normalizeFilename(localFilename)
+    return normalizeFilename(localFilename.replace(os.environ['DROP_DIR2'],"/",1).replace("/",os.environ['APP_PATH'],1))
+
+def getLocalFilename(dropboxFilename):
+    """ localFilename is $DROP_DIR2/<subpath>/<filename> """
+    """ dropboxFilename is $APP_PATH/<subpath>/<filename> """
+    #if not dropboxFilename.startswith(os.environ['APP_PATH']):
+    #    dropboxFilename=os.environ['APP_PATH']+dropboxFilename
+    dropboxFilename=normalizeFilename(dropboxFilename)
+    return normalizeFilename(dropboxFilename.replace(os.environ['APP_PATH'],"/",1).replace("/",os.environ['DROP_DIR2'],1))
+
+def backupFile(localFilename):
+    """Upload a file.
+    Return the request response, or None in case of error.
+    This will also create the directory on Dropbox if needed.
+    """
+    global TOKEN
+    localFilename=normalizeFilename(localFilename)
+    dropboxFilename=getDropboxFilename(localFilename)
+    print("backing file ",localFilename," to ",dropboxFilename)
+    mode = dropbox.files.WriteMode.overwrite
+    mtime = os.path.getmtime(localFilename)
+    with open(localFilename, 'rb') as f:
+        data = f.read()
+    try:
+        res = dbx.files_upload(
+            data, dropboxFilename, mode,
+            client_modified=datetime.datetime(*time.gmtime(mtime)[:6]),
+            mute=True)
+        if (TOKEN != dbx._oauth2_access_token):
+            fbh.fb_update("d2_accesstoken",dbx._oauth2_access_token)
+            TOKEN=dbx._oauth2_access_token
+            print(dbx._oauth2_refresh_token)
+    except dropbox.exceptions.ApiError as err:
+        print('*** API error', err)
+        return None
+    print('uploaded as', res.name.encode('utf8'))
+    return res
+
+def restoreFile(dropboxFilename):
+    """Download a file.
+    Return the bytes of the file, or None if it doesn't exist.
+    Will create dir+subdirs if possible.
+    """
+    global TOKEN
+    dropboxFilename=normalizeFilename(dropboxFilename)
+    localFilename=getLocalFilename(dropboxFilename)
+    print("restoring file ",localFilename," from ",dropboxFilename)
+    try:
+        md, res = dbx.files_download(dropboxFilename)
+        if (TOKEN != dbx._oauth2_access_token):
+            fbh.fb_update("d2_accesstoken",dbx._oauth2_access_token)
+            TOKEN=dbx._oauth2_access_token
+            print(dbx._oauth2_refresh_token)
+    except dropbox.exceptions.HttpError as err:
+        print('*** HTTP error', err)
+        return None
+    data = res.content
+    print(len(data), 'bytes; md:', md)
+    localdir=os.path.dirname(localFilename)
+    if not os.path.exists(localdir):
+        os.makedirs(localdir)
+    with open(localFilename, 'wb') as f:
+        f.write(data)
+    return data
+
+def backupFolder(localFolder):
+    """ list all files in the folder and subfolders and upload them """
+    print("backup folder called for ",localFolder)
+    if not localFolder.startswith(os.environ['DROP_DIR2']):
+        localFolder=os.environ['DROP_DIR2']+localFolder
+    filenames=[]
+    for (root,dirs,files) in os.walk(localFolder, topdown=True):
+        print(root)
+        for filename in files:
+            filenames.append(root+"/"+filename)
+            print(root+"/"+filename)
+            backupFile(root+"/"+filename)
+
+
+def restoreFolder(dropboxFolder):
+    """ list all files in the dropbox folder and subfolders and restore them """
+    global TOKEN
+    if not dropboxFolder.startswith(os.environ['APP_PATH']):
+        dropboxFolder=os.environ['APP_PATH']+dropboxFolder
+    try:
+        res=dbx.files_list_folder(dropboxFolder)
+        if (TOKEN != dbx._oauth2_access_token):
+            fbh.fb_update("d2_accesstoken",dbx._oauth2_access_token)
+            TOKEN=dbx._oauth2_access_token
+            print(dbx._oauth2_refresh_token)
+    except dropbox.exceptions.ApiError as err:
+        print('Folder listing failed for', dropboxFolder, '-- assumed empty:', err)
+        return
+    except dropbox.exceptions.AuthError as err1:
+        print('Folder listing failed for', dropboxFolder, '-- assumed empty:', err1)
+        return
+    for entry in res.entries:
+        if (isinstance(entry, dropbox.files.FileMetadata)):
+            restoreFile(entry.path_display)
+        else:
+            try:
+                restoreFolder(entry.path_display)
+            except Exception:
+                print("Error restoring folder,",entry.path_display)
+        print(entry.path_display)
src/baseInfra/firebase_handler.py
ADDED
@@ -0,0 +1,97 @@
+from dotenv import load_dotenv
+load_dotenv()
+from firebase import firebase
+import os
+
+secret=os.environ['FIREBASE_TOKEN']
+user_id="db" #this file is for dropbox token management through firebase only
+auth=firebase.FirebaseAuthentication(secret=secret,email="anubhav77@gmail.com",extra={"id":user_id})
+fb_url="https://device-1a455.firebaseio.com"
+fb= firebase.FirebaseApplication(fb_url,auth)
+base_dir="/users/"+user_id
+
+def fb_get(item):
+    """ item needs to be in the format d2_refreshtoken; the same name will be present in Firebase.
+    d2 means the second dropbox instance or app.
+    """
+    return fb.get(base_dir,item)
+
+def fb_update(item,value):
+    """ item needs to be in the format d2_refreshtoken; the same name will be present in Firebase.
+    d2 means the second dropbox instance or app.
+    """
+    return fb.patch(base_dir,{item:value})
+
+import re
+
+def convert_to_firebase_key(inStr):
+    """Converts a string to a Firebase key.
+
+    Args:
+        inStr: The string to convert.
+
+    Returns:
+        A Firebase-safe key derived from inStr.
+    """
+
+    # Firebase keys must be between 3 and 128 characters long.
+    # They must start with a letter, number, or underscore.
+    # They can only contain letters, numbers, underscores, hyphens, and periods.
+    # They cannot contain spaces or other special characters.
+
+    # Convert the inStr to lowercase.
+    outStr = inStr.lower()
+
+    # Replace all spaces with underscores.
+    outStr = outStr.replace(' ', '_')
+
+    # Remove any leading or trailing underscores.
+    outStr = outStr.lstrip('_').rstrip('_')
+
+    # Remove any character not matching firebase_key_regex.
+    outStr=re.sub(r'[^\w\-.]', '', outStr)
+
+
+    firebase_key_regex = re.compile(r'^[a-zA-Z0-9_][a-zA-Z0-9-_.]*$')
+    # The following string values are not allowed for Firebase keys:
+    # - "firebase"
+    # - "google"
+    # - "android"
+    # - "ios"
+    # - "web"
+    # - "console"
+    # - "auth"
+    # - "database"
+    # - "storage"
+    # - "hosting"
+    # - "functions"
+    # - "firestore"
+    # - "realtime-database"
+    # - "remote-config"
+    # - "analytics"
+    # - "crashlytics"
+    # - "performance-monitoring"
+    # - "test-lab"
+    # - "cloud-messaging"
+    # - "dynamic-links"
+    # - "identity-toolkit"
+    # - "cloud-functions"
+    # - "cloud-firestore"
+    # - "cloud-storage"
+    # - "cloud-hosting"
+    # - "cloud-functions-v2"
+    # - "cloud-firestore-v2"
+    # - "cloud-storage-v2"
+    # - "cloud-hosting-v2"
+    # - "cloud-functions-v3"
+    # - "cloud-firestore-v3"
+    # - "cloud-storage-v3"
+    # - "cloud-hosting-v3"
+
+    if not firebase_key_regex.match(inStr) or inStr in ['firebase', 'google', 'android', 'ios', 'web', 'console', 'auth', 'database', 'storage', 'hosting', 'functions', 'firestore', 'realtime-database', 'remote-config', 'analytics', 'crashlytics', 'performance-monitoring', 'test-lab', 'cloud-messaging', 'dynamic-links', 'identity-toolkit', 'cloud-functions', 'cloud-firestore', 'cloud-storage', 'cloud-hosting', 'cloud-functions-v2', 'cloud-firestore-v2', 'cloud-storage-v2', 'cloud-hosting-v2', 'cloud-functions-v3', 'cloud-firestore-v3', 'cloud-storage-v3', 'cloud-hosting-v3']:
+        return "fb_"+outStr
+
+    if len(outStr) > 120:
+        outStr = "fbx_"+outStr[:120]
+
+    return outStr
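For illustration, the sanitiser above maps a free-text query onto a Realtime-Database-safe key (note the module creates its Firebase client at import time, so FIREBASE_TOKEN must be set):

from baseInfra.firebase_handler import convert_to_firebase_key

print(convert_to_firebase_key("Which red wines rate above 95?"))
# spaces become underscores, '?' is stripped, and because the raw input is not
# itself a valid key the result is prefixed: 'fb_which_red_wines_rate_above_95'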
src/indexer.py
ADDED
@@ -0,0 +1,106 @@
+from langchain.vectorstores import Chroma
+from chromadb.api.fastapi import requests
+from langchain.schema import Document
+from langchain.chains import RetrievalQA
+from langchain.embeddings import HuggingFaceBgeEmbeddings
+from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain.chains.query_constructor.base import AttributeInfo
+from llm.llmFactory import LLMFactory
+
+model_name = "BAAI/bge-base-en"
+encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
+
+embedding = HuggingFaceBgeEmbeddings(
+    model_name=model_name,
+    model_kwargs={'device': 'cpu'},
+    encode_kwargs=encode_kwargs
+)
+
+persist_directory = 'db'
+docs = [
+    Document(
+        page_content="Complex, layered, rich red with dark fruit flavors",
+        metadata={"name":"Opus One", "year": 2018, "rating": 96, "grape": "Cabernet Sauvignon", "color":"red", "country":"USA"},
+    ),
+    Document(
+        page_content="Luxurious, sweet wine with flavors of honey, apricot, and peach",
+        metadata={"name":"Château d'Yquem", "year": 2015, "rating": 98, "grape": "Sémillon", "color":"white", "country":"France"},
+    ),
+    Document(
+        page_content="Full-bodied red with notes of black fruit and spice",
+        metadata={"name":"Penfolds Grange", "year": 2017, "rating": 97, "grape": "Shiraz", "color":"red", "country":"Australia"},
+    ),
+    Document(
+        page_content="Elegant, balanced red with herbal and berry nuances",
+        metadata={"name":"Sassicaia", "year": 2016, "rating": 95, "grape": "Cabernet Franc", "color":"red", "country":"Italy"},
+    ),
+    Document(
+        page_content="Highly sought-after Pinot Noir with red fruit and earthy notes",
+        metadata={"name":"Domaine de la Romanée-Conti", "year": 2018, "rating": 100, "grape": "Pinot Noir", "color":"red", "country":"France"},
+    ),
+    Document(
+        page_content="Crisp white with tropical fruit and citrus flavors",
+        metadata={"name":"Cloudy Bay", "year": 2021, "rating": 92, "grape": "Sauvignon Blanc", "color":"white", "country":"New Zealand"},
+    ),
+    Document(
+        page_content="Rich, complex Champagne with notes of brioche and citrus",
+        metadata={"name":"Krug Grande Cuvée", "year": 2010, "rating": 93, "grape": "Chardonnay blend", "color":"sparkling", "country":"New Zealand"},
+    ),
+    Document(
+        page_content="Intense, dark fruit flavors with hints of chocolate",
+        metadata={"name":"Caymus Special Selection", "year": 2018, "rating": 96, "grape": "Cabernet Sauvignon", "color":"red", "country":"USA"},
+    ),
+    Document(
+        page_content="Exotic, aromatic white with stone fruit and floral notes",
+        metadata={"name":"Jermann Vintage Tunina", "year": 2020, "rating": 91, "grape": "Sauvignon Blanc blend", "color":"white", "country":"Italy"},
+    ),
+]
+
+vectorstore = Chroma.from_documents(documents=docs,
+                                    embedding=embedding,
+                                    persist_directory=persist_directory)
+
+metadata_field_info = [
+    AttributeInfo(
+        name="grape",
+        description="The grape used to make the wine",
+        type="string or list[string]",
+    ),
+    AttributeInfo(
+        name="name",
+        description="The name of the wine",
+        type="string or list[string]",
+    ),
+    AttributeInfo(
+        name="color",
+        description="The color of the wine",
+        type="string or list[string]",
+    ),
+    AttributeInfo(
+        name="year",
+        description="The year the wine was released",
+        type="integer",
+    ),
+    AttributeInfo(
+        name="country",
+        description="The name of the country the wine comes from",
+        type="string",
+    ),
+    AttributeInfo(
+        name="rating", description="The Robert Parker rating for the wine 0-100", type="integer" #float
+    ),
+]
+document_content_description = "Brief description of the wine"
+lf=LLMFactory()
+llm=lf.get_llm("executor2")
+
+retriever = SelfQueryRetriever.from_llm(
+    llm,
+    vectorstore,
+    document_content_description,
+    metadata_field_info,
+    verbose=True
+)
+
+def getRelevantDocs(query:str):
+    return retriever.get_relevant_documents(query)
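The self-query retriever lets the configured LLM translate natural-language constraints into metadata filters before the similarity search. A sketch of a query against the sample documents above (assumes the "executor2" config and the corresponding API keys are available):

from indexer import getRelevantDocs

for doc in getRelevantDocs("I want a red wine from the USA rated above 95"):
    print(doc.metadata.get("name"), doc.metadata.get("rating"))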
src/llm/hostedLLM.py
ADDED
@@ -0,0 +1,92 @@
+from typing import Any, List, Mapping, Optional, Dict
+from pydantic import Extra, Field #, root_validator, model_validator
+import os,json
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+import requests
+
+
+class HostedLLM(LLM):
+    """
+    Hosted LLMs in Hugging Face Spaces with FastAPI. The interface is primarily a REST call with a Hugging Face token.
+
+    Attributes:
+        url: the url of the endpoint
+        http_method: which http method to invoke [get or post]
+        model_name: which model is being hosted
+        temperature: temperature between 0 and 1
+        max_tokens: amount of output to generate, 512 by default
+        api_token: api_token to be passed for bearer authorization. Defaults to the HUGGINGFACE_API environment variable.
+        verbose: for extra logging
+    """
+    url: str = ""
+    http_method: Optional[str] = "post"
+    model_name: Optional[str] = "bard"
+    api_token: Optional[str] = os.environ["HUGGINGFACE_API"]
+    temperature: float = 0.7
+    max_tokens: int = 512
+    verbose: Optional[bool] = False
+    class Config:
+        extra = Extra.forbid
+
+    #@model_validator(mode="after")
+    #def validate_environment(cls, values: Dict) -> Dict:
+    #    if values["http_method"].strip() == "GET" or values["http_method"].strip() == "get":
+    #        values["http_method"]="get"
+    #    else:
+    #        values["http_method"]="post"
+    #    if values["api_token"] == "":
+    #        values["api_token"] = os.environ["HUGGINGFACE_API"]
+    #
+    #    return values
+
+    @property
+    def _llm_type(self) -> str:
+        return "text2text-generation"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        if run_manager:
+            run_manager.on_text([prompt])
+        #messages={"messages":[{"role":"user","content":prompt}]}
+        prompt={"prompt":prompt}
+        headers = {
+            "Authorization": f"Bearer {self.api_token}",
+            "Content-Type": "application/json",
+        }
+        if(self.http_method=="post"):
+            response=requests.post(self.url,json=prompt,headers=headers)
+        else:
+            response=requests.get(self.url,json=prompt,headers=headers)
+        val=json.loads(response.text)['content']
+        if run_manager:
+            run_manager.on_llm_end(val)
+        return val
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {"name": self.model_name, "type": "hosted"}
+
+    def extractJson(self,val:str) -> Any:
+        """Helper function to extract json from this LLM's output"""
+        #This is assuming the json is the first item within ```
+        #the hosted LLM will sometimes send the json directly
+        try:
+            v3=val.replace("\n","").replace("\r","")
+            v4=json.loads(v3)
+        except:
+            v2=val.replace("```json","```").split("```")[1]
+            v3=v2.replace("\n","").replace("\r","")
+            v4=json.loads(v3)
+        return v4
+
+    def extractPython(self,val:str) -> Any:
+        """Helper function to extract python from this LLM's output"""
+        #This is assuming the python is the first item within ```
+        v2=val.replace("```python","```").split("```")[1]
+        return v2
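The extractJson helper tolerates both bare JSON and JSON wrapped in a fenced code block. A small sketch (the URL is hypothetical, and HUGGINGFACE_API must be set because the class default reads it when the module is imported):

import os
os.environ.setdefault("HUGGINGFACE_API", "hf_dummy_token")  # placeholder; real token needed only for _call()
from llm.hostedLLM import HostedLLM

llm = HostedLLM(url="https://example-space.hf.space/api/chat")  # hypothetical endpoint
raw = '```json\n{"action": "search", "query": "wine"}\n```'
print(llm.extractJson(raw))  # -> {'action': 'search', 'query': 'wine'}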
src/llm/llmFactory.py
ADDED
@@ -0,0 +1,55 @@
+import logging
+from baseInfra.dbInterface import DbInterface
+from llm.hostedLLM import HostedLLM
+from llm.togetherLLM import TogetherLLM
+from llm.palmLLM import PalmLLM
+
+
+class LLMFactory:
+    """
+    Factory class for creating LLM objects.
+    """
+
+    def __init__(self):
+        """
+        Constructor for the LLMFactory class.
+
+        Creates the DbInterface object used for reading LLM configs.
+        """
+        self._db_interface = DbInterface()
+
+    def get_llm(self, llm_path: str) -> object:
+        """
+        Gets an LLM object of the specified type.
+
+        Args:
+            llm_path: The path to the LLM config.
+
+        Returns:
+            The LLM object.
+        """
+        logger = logging.getLogger(__name__)
+        try:
+            config = self._db_interface.get_config(llm_path)
+            logger.debug(llm_path)
+            logger.debug(config)
+            llm_type = config["llm_type"]
+            llm_config=config["llm_config"]
+        except Exception as ex:
+            logger.exception("Exception in getLLM")
+            logger.exception(ex)
+            config={}
+            llm_type=""
+            llm_config={}
+
+        if llm_type == "hostedLLM":
+            return HostedLLM(**llm_config)
+        elif llm_type == "togetherLLM":
+            return TogetherLLM(**llm_config)
+        elif llm_type == "palmLLM":
+            return PalmLLM(**llm_config)
+        else:
+            logger.error(f"Invalid LLM type: {llm_type}")
+            raise ValueError(f"Invalid LLM type: {llm_type}")
+
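get_llm() only reads two top-level keys from the config document stored under /users/<user_id>/config/<llm_path>. A hypothetical document and its use (the keys inside llm_config are simply constructor arguments of the chosen class):

# Hypothetical Firebase document at /users/131251/config/executor2
executor2 = {
    "llm_type": "palmLLM",      # one of: hostedLLM, togetherLLM, palmLLM
    "llm_config": {             # passed straight to the LLM class as keyword arguments
        "model_name": "text-bison-001",
        "temperature": 0.2,
    },
}

# from llm.llmFactory import LLMFactory
# llm = LLMFactory().get_llm("executor2")   # requires FIREBASE_TOKEN and the chosen LLM's API key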
src/llm/palmLLM.py
ADDED
@@ -0,0 +1,93 @@
+from typing import Any, List, Mapping, Optional, Dict
+from pydantic import Extra, Field #, root_validator, model_validator
+import os,json
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+import google.generativeai as palm
+from google.generativeai import types
+import ast
+#from langchain.llms import GooglePalm
+import requests
+
+
+class PalmLLM(LLM):
+
+    model_name: str = "text-bison-001"
+    temperature: float = 0
+    max_tokens: int = 2048
+    stop: Optional[List] = []
+    prev_prompt: Optional[str]=""
+    prev_stop: Optional[str]=""
+    prev_run_manager:Optional[Any]=None
+
+    def __init__(
+        self,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        palm.configure()
+        #self.model = palm.Text2Text(self.model_name)
+
+    @property
+    def _llm_type(self) -> str:
+        return "text2text-generation"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        self.prev_prompt=prompt
+        self.prev_stop=stop
+        self.prev_run_manager=run_manager
+        if stop == None:
+            stop=self.stop
+        text=palm.generate_text(prompt=prompt,stop_sequences=self.stop,
+                                temperature=self.temperature, max_output_tokens=self.max_tokens,
+                                safety_settings=[{"category":0,"threshold":4},
+                                                 {"category":1,"threshold":4},
+                                                 {"category":2,"threshold":4},
+                                                 {"category":3,"threshold":4},
+                                                 {"category":4,"threshold":4},
+                                                 {"category":5,"threshold":4},
+                                                 {"category":6, "threshold":4}]
+                                )
+        print("Response from palm",text)
+        val=text.result
+        if run_manager:
+            run_manager.on_llm_end(val)
+        return val
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {"name": self.model_name, "type": "palm"}
+
+    def extractJson(self,val:str) -> Any:
+        """Helper function to extract json from this LLMs output"""
+        #This is assuming the json is the first item within ```
+        # palm is responding always with ```json and ending with ```, however sometimes response is not complete
+        # in case trailing ``` is not seen, we will call generation again with prev_prompt and result appended to it
+        try:
+            count=0
+            while val.startswith("```json") and not val.endswith("```") and count<7:
+                val=self._call(prompt=self.prev_prompt+" "+val,stop=self.prev_stop,run_manager=self.prev_run_manager)
+                count+=1
+            v2=val.replace("```json","```").split("```")[1]
+            #v3=v2.replace("\n","").replace("\r","").replace("'","\"")
+            v3=json.dumps(ast.literal_eval(v2))
+            v4=json.loads(v3)
+        except:
+            v2=val.replace("\n","").replace("\r","")
+            v3=json.dumps(ast.literal_eval(val))
+            #v3=v2.replace("'","\"")
+            v4=json.loads(v3)
+            #v4=json.loads(v2)
+        return v4
+
+    def extractPython(self,val:str) -> Any:
+        """Helper function to extract python from this LLMs output"""
+        #This is assuming the python is the first item within ```
+        v2=val.replace("```python","```").split("```")[1]
+        return v2
src/llm/togetherLLM.py
ADDED
@@ -0,0 +1,72 @@
+import together
+import os
+import logging,json
+from typing import Any, Dict, List, Mapping, Optional
+
+from pydantic import Extra, Field #, root_validator, model_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.utils import get_from_dict_or_env
+
+class TogetherLLM(LLM):
+    """Together large language models."""
+
+    model_name: str = "togethercomputer/llama-2-70b-chat"
+    """model endpoint to use"""
+
+    together_api_key: str = os.environ["TOGETHER_API_KEY"]
+    """Together API key"""
+
+    temperature: float = 0.7
+    """What sampling temperature to use."""
+
+    max_tokens: int = 512
+    """The maximum number of tokens to generate in the completion."""
+
+    class Config:
+        extra = Extra.forbid
+
+    #@model_validator(mode="after")
+    #def validate_environment(cls, values: Dict) -> Dict:
+    #    """Validate that the API key is set."""
+    #    api_key = get_from_dict_or_env(
+    #        values, "together_api_key", "TOGETHER_API_KEY"
+    #    )
+    #    values["together_api_key"] = api_key
+    #    return values
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of LLM."""
+        return "together"
+
+    def _call(
+        self,
+        prompt: str,
+        **kwargs: Any,
+    ) -> str:
+        """Call to Together endpoint."""
+        together.api_key = self.together_api_key
+        output = together.Complete.create(prompt,
+                                          model=self.model_name,
+                                          max_tokens=self.max_tokens,
+                                          temperature=self.temperature,
+                                          )
+        text = output['output']['choices'][0]['text']
+        return text
+
+    def extractJson(self,val:str) -> Any:
+        """Helper function to extract json from this LLMs output"""
+        #This is assuming the json is the first item within ```
+        v2=val.replace("```json","```").split("```")[1]
+        v3=v2.replace("\n","").replace("\r","")
+        v4=json.loads(v3)
+        return v4
+
+    def extractPython(self,val:str) -> Any:
+        """Helper function to extract python from this LLMs output"""
+        #This is assuming the python is the first item within ```
+        v2=val.replace("```python","```").split("```")[1]
+        return v2
|