Tony Shepherd committed on
Commit
a5a6006
1 Parent(s): aa2cf7d

added http errors

Browse files
Files changed (4) hide show
  1. api_contract.py +54 -0
  2. error_handling.py +8 -0
  3. main.py +24 -9
  4. payload.py +1 -3
api_contract.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from error_handling import ErrorCodes
2
+
3
def get_error_codes_description(error_codes=None):
    """Build a Markdown table describing error codes.

    Args:
        error_codes: Optional iterable of members whose ``value`` is a
            ``(code, description)`` pair. When omitted, ``ErrorCodes`` is
            used.

    Returns:
        A string holding the table header, the separator row and one row
        per member, ordered by ascending code.
    """
    if error_codes is None:
        error_codes = ErrorCodes

    # Order by the code itself, not by the rendered row text: comparing the
    # rows as strings would place "| 1000 ..." ahead of "| 400 ...".
    ordered = sorted(error_codes, key=lambda member: member.value[0])
    rows = "".join(
        f"| {member.value[0]} | {member.value[1]} |\n" for member in ordered
    )
    return f"| code | description |\n| - | - |\n{rows}"
13
+
14
+
15
# Shape of one entry inside the "errors" array: a numeric code and a message.
_error_item_schema = {
    "type": "object",
    "required": ["code", "message"],
    "properties": {
        "code": {"type": "integer", "example": 400},
        "message": {
            "type": "string",
            "example": "'input_text' is a required property",
            "maxLength": 256,
        },
    },
}

# "ErrorMessage" schema; its description is the generated error-code table.
_error_message_schema = {
    "required": ["errors"],
    "description": get_error_codes_description(),
    "properties": {
        "errors": {
            "type": "array",
            "title": "Error array",
            "items": _error_item_schema,
        }
    },
}

# NOTE(review): the path string below contains no "{petId}" placeholder, so
# this path parameter looks like a leftover from an OpenAPI sample - confirm.
_generate_path_parameters = [
    {
        "name": "petId",
        "in": "path",
        "schema": {
            "type": "integer",
            "format": "int64",
        },
    }
]

# Top-level dictionary exported by this module: schemas first, then paths.
components_dict = {
    "schemas": {"ErrorMessage": _error_message_schema},
    "paths": {
        "/api/generatel-language": {"parameters": _generate_path_parameters},
    },
}
error_handling.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
from enum import Enum, Flag
2
+
3
class ErrorCodes(Enum):
    """Error codes reported by the API, each paired with a description.

    Every member's value is a ``(code, description)`` tuple: index 0 holds
    the numeric code and index 1 the human-readable text.

    The base class is ``Enum`` rather than ``Flag`` because the values are
    tuples, not integer bit masks that can be combined.
    """

    INTERNAL_SERVER_ERROR = (500, "internal server error")
    REQUEST_VALIDATION_ERROR = (400, "request validation error")
    JWT_VALIDATION_ERROR = (401, "jwt validation error")
    TEXT_VALIDATION_ERROR = (490, "text validation error")
    OTHER_ERROR = (999, "other error")
main.py CHANGED
@@ -1,22 +1,37 @@
1
  from transformers import pipeline
2
- from fastapi import FastAPI
3
  from payload import SomeText
4
-
5
-
6
 
7
  app = FastAPI()
8
 
 
 
 
 
 
 
 
 
 
9
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
10
 
11
  @app.get("/")
12
  def read_root():
13
- return {'detail': 'API seems to be working'}
14
 
15
  @app.get("/api/check-heartbeat")
16
  def get_heartbeat():
17
- return {"Status": "Running. Try out the endpoints in swagger"}
 
 
 
 
 
 
 
 
18
 
19
- @app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=SomeText)
20
- def inference(input_prompt: SomeText):
21
- output = pipe_flan(input_prompt)
22
- return {"output": output[0]["generated_text"]}
 
1
  from transformers import pipeline
2
+ from fastapi import FastAPI, Request, HTTPException
3
  from payload import SomeText
4
+ from api_contract import components_dict as api_components
5
+ from error_handling import ErrorCodes
6
 
7
  app = FastAPI()
8
 
9
+ app=FastAPI(title="Huggingface Gen LLM gest",
10
+ version="1.0",
11
+ debug=True,
12
+ components=api_components,
13
+ swagger_ui_bundle_js= "//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js",
14
+ swagger_ui_standalone_preset_js= "//unpkg.com/swagger-ui-dist@3/swagger-ui-standalone-preset.js",
15
+ summary="API to perform generative prompt completion using small LLM (without GPU).",
16
+ )
17
+
18
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
19
 
20
  @app.get("/")
21
  def read_root():
22
+ return {'detail': 'API running. Try out the endpoints in swagger'}
23
 
24
  @app.get("/api/check-heartbeat")
25
  def get_heartbeat():
26
+ return {"detail": "seems to be working"}
27
+
28
@app.post("/api/generatel-language", summary="Generate text from prompt", tags=["Generate"])
def inference(request: Request, input_prompt: SomeText):
    """Generate text for the submitted prompt.

    Args:
        request: The incoming request object (not read by this handler).
        input_prompt: Payload whose ``text`` attribute is passed to the
            pipeline.

    Returns:
        A dict with the first generated text under the ``"output"`` key.

    Raises:
        HTTPException: With status 400 when the prompt text is empty.
    """
    # Guard clause: reject an empty prompt before invoking the pipeline.
    if len(input_prompt.text) == 0:
        message = ErrorCodes.REQUEST_VALIDATION_ERROR.value[1]
        raise HTTPException(status_code=400, detail=message)

    generated = pipe_flan(input_prompt.text)
    return {"output": generated[0]["generated_text"]}
 
 
payload.py CHANGED
@@ -12,6 +12,4 @@ class SomeText(BaseModel):
12
  }
13
  ]
14
  }
15
- }
16
-
17
-
 
12
  }
13
  ]
14
  }
15
+ }