paulbricman commited on
Commit
0ac7b31
1 Parent(s): bf093f4

fix: closes #19

Browse files
Files changed (2) hide show
  1. backend/main.py +27 -25
  2. backend/security.py +1 -2
backend/main.py CHANGED
@@ -1,10 +1,11 @@
1
- from fastapi import FastAPI, Request, Header
2
  from security import auth
3
  from util import find, rank, save, get_authorized_thoughts, remove, dump
4
  from sentence_transformers import SentenceTransformer
5
  from fastapi.datastructures import UploadFile
6
  from fastapi import FastAPI, File, Form
7
  from fastapi.responses import FileResponse, ORJSONResponse
 
8
  from pathlib import Path
9
  from microverses import create_microverse, remove_microverse, list_microverses
10
  from slowapi import Limiter, _rate_limit_exceeded_handler
@@ -13,6 +14,7 @@ from slowapi.middleware import SlowAPIMiddleware
13
  from slowapi.errors import RateLimitExceeded
14
 
15
 
 
16
  limiter = Limiter(key_func=get_remote_address, default_limits=['30/minute'])
17
  app = FastAPI()
18
  app.state.limiter = limiter
@@ -33,7 +35,7 @@ async def find_text_handler(
33
  return_embeddings: bool = False,
34
  silent: bool = False,
35
  request: Request = None,
36
- authorization: str = Header(None)
37
  ):
38
  return find(
39
  'text',
@@ -42,7 +44,7 @@ async def find_text_handler(
42
  activation,
43
  noise,
44
  return_embeddings,
45
- auth(authorization),
46
  text_encoder,
47
  text_image_encoder,
48
  silent
@@ -58,7 +60,7 @@ async def find_image_handler(
58
  return_embeddings: bool = Form(False),
59
  silent: bool = Form(False),
60
  request: Request = None,
61
- authorization: str = Header(None)
62
  ):
63
  query = await query.read()
64
  return find(
@@ -68,7 +70,7 @@ async def find_image_handler(
68
  activation,
69
  noise,
70
  return_embeddings,
71
- auth(authorization),
72
  text_encoder,
73
  text_image_encoder,
74
  silent
@@ -76,59 +78,59 @@ async def find_image_handler(
76
 
77
 
78
  @app.get('/save')
79
- async def save_text_handler(query: str, request: Request, authorization: str = Header(None)):
80
- return save('text', query, auth(authorization),
81
  text_encoder, text_image_encoder)
82
 
83
 
84
  @app.post('/save')
85
- async def save_image_handler(query: UploadFile = File(...), request: Request = None, authorization: str = Header(None)):
86
  query = await query.read()
87
- results = save('image', query, auth(authorization),
88
  text_encoder, text_image_encoder)
89
  return results
90
 
91
 
92
  @app.get('/remove')
93
- async def remove_handler(filename: str, request: Request, authorization: str = Header(None)):
94
- return remove(auth(authorization), filename)
95
 
96
 
97
  @app.get('/dump')
98
- async def save_text_handler(request: Request, authorization: str = Header(None)):
99
- return dump(auth(authorization))
100
 
101
 
102
  @app.get('/static')
103
  @limiter.limit("200/minute")
104
- async def static_handler(filename: str, request: Request, authorization: str = Header(None)):
105
  knowledge_base_path = Path('..') / 'knowledge'
106
- thoughts = get_authorized_thoughts(auth(authorization))
107
  if filename in [e['filename'] for e in thoughts]:
108
  return FileResponse(knowledge_base_path / filename)
109
 
110
 
111
  @app.get('/microverse/create')
112
- async def microverse_create_handler(query: str, request: Request, authorization: str = Header(None)):
113
- return create_microverse('text', query, auth(authorization), text_encoder, text_image_encoder)
114
 
115
 
116
  @app.post('/microverse/create')
117
- async def microverse_create_handler(query: UploadFile = File(...), request: Request = None, authorization: str = Header(None)):
118
  query = await query.read()
119
- return create_microverse('image', query, auth(authorization), text_encoder, text_image_encoder)
120
 
121
 
122
  @app.get('/microverse/remove')
123
- async def microverse_remove_handler(microverse: str, request: Request, authorization: str = Header(None)):
124
- return remove_microverse(auth(authorization), microverse)
125
 
126
 
127
  @app.get('/microverse/list')
128
- async def microverse_list_handler(request: Request, authorization: str = Header(None)):
129
- return list_microverses(auth(authorization))
130
 
131
 
132
  @app.get('/custodian/check')
133
- async def check_custodian(request: Request, authorization: str = Header(None)):
134
- return auth(authorization)
1
+ from fastapi import Depends, FastAPI, Request, Header
2
  from security import auth
3
  from util import find, rank, save, get_authorized_thoughts, remove, dump
4
  from sentence_transformers import SentenceTransformer
5
  from fastapi.datastructures import UploadFile
6
  from fastapi import FastAPI, File, Form
7
  from fastapi.responses import FileResponse, ORJSONResponse
8
+ from fastapi.security import HTTPBearer, HTTPBasicCredentials
9
  from pathlib import Path
10
  from microverses import create_microverse, remove_microverse, list_microverses
11
  from slowapi import Limiter, _rate_limit_exceeded_handler
14
  from slowapi.errors import RateLimitExceeded
15
 
16
 
17
+ security = HTTPBearer()
18
  limiter = Limiter(key_func=get_remote_address, default_limits=['30/minute'])
19
  app = FastAPI()
20
  app.state.limiter = limiter
35
  return_embeddings: bool = False,
36
  silent: bool = False,
37
  request: Request = None,
38
+ authorization: HTTPBasicCredentials = Depends(security)
39
  ):
40
  return find(
41
  'text',
44
  activation,
45
  noise,
46
  return_embeddings,
47
+ auth(authorization.credentials),
48
  text_encoder,
49
  text_image_encoder,
50
  silent
60
  return_embeddings: bool = Form(False),
61
  silent: bool = Form(False),
62
  request: Request = None,
63
+ authorization: HTTPBasicCredentials = Depends(security)
64
  ):
65
  query = await query.read()
66
  return find(
70
  activation,
71
  noise,
72
  return_embeddings,
73
+ auth(authorization.credentials),
74
  text_encoder,
75
  text_image_encoder,
76
  silent
78
 
79
 
80
  @app.get('/save')
81
+ async def save_text_handler(query: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
82
+ return save('text', query, auth(authorization.credentials),
83
  text_encoder, text_image_encoder)
84
 
85
 
86
  @app.post('/save')
87
+ async def save_image_handler(query: UploadFile = File(...), request: Request = None, authorization: HTTPBasicCredentials = Depends(security)):
88
  query = await query.read()
89
+ results = save('image', query, auth(authorization.credentials),
90
  text_encoder, text_image_encoder)
91
  return results
92
 
93
 
94
  @app.get('/remove')
95
+ async def remove_handler(filename: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
96
+ return remove(auth(authorization.credentials), filename)
97
 
98
 
99
  @app.get('/dump')
100
+ async def save_text_handler(request: Request, authorization: HTTPBasicCredentials = Depends(security)):
101
+ return dump(auth(authorization.credentials))
102
 
103
 
104
  @app.get('/static')
105
  @limiter.limit("200/minute")
106
+ async def static_handler(filename: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
107
  knowledge_base_path = Path('..') / 'knowledge'
108
+ thoughts = get_authorized_thoughts(auth(authorization.credentials))
109
  if filename in [e['filename'] for e in thoughts]:
110
  return FileResponse(knowledge_base_path / filename)
111
 
112
 
113
  @app.get('/microverse/create')
114
+ async def microverse_create_handler(query: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
115
+ return create_microverse('text', query, auth(authorization.credentials), text_encoder, text_image_encoder)
116
 
117
 
118
  @app.post('/microverse/create')
119
+ async def microverse_create_handler(query: UploadFile = File(...), request: Request = None, authorization: HTTPBasicCredentials = Depends(security)):
120
  query = await query.read()
121
+ return create_microverse('image', query, auth(authorization.credentials), text_encoder, text_image_encoder)
122
 
123
 
124
  @app.get('/microverse/remove')
125
+ async def microverse_remove_handler(microverse: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
126
+ return remove_microverse(auth(authorization.credentials), microverse)
127
 
128
 
129
  @app.get('/microverse/list')
130
+ async def microverse_list_handler(request: Request, authorization: HTTPBasicCredentials = Depends(security)):
131
+ return list_microverses(auth(authorization.credentials))
132
 
133
 
134
  @app.get('/custodian/check')
135
+ async def check_custodian(request: Request, authorization: HTTPBasicCredentials = Depends(security)):
136
+ return auth(authorization.credentials)
backend/security.py CHANGED
@@ -3,12 +3,11 @@ import json
3
 
4
 
5
  def auth(token):
6
- if token == None or not token.startswith('Bearer '):
7
  return {
8
  'custodian': False
9
  }
10
 
11
- token = token.replace('Bearer ', '')
12
  path = Path('..') / 'knowledge' / 'records.json'
13
 
14
  if not path.exists():
3
 
4
 
5
  def auth(token):
6
+ if not token:
7
  return {
8
  'custodian': False
9
  }
10
 
 
11
  path = Path('..') / 'knowledge' / 'records.json'
12
 
13
  if not path.exists():