Spaces:
Sleeping
Sleeping
AP\VivekIsh
commited on
Commit
•
6fadbbc
1
Parent(s):
0630fa6
codegen: Stage the code
Browse files- .gitignore +28 -0
- README.md +37 -10
- controllers/__init__.py +6 -0
- controllers/generate_controller.py +53 -0
- index.py +21 -0
- requirements.txt +2 -0
- services/ibm_model/ibm_extract_code_block.py +37 -0
- services/ibm_model/ibm_text_generator.py +65 -0
- services/model_generator.py +47 -0
- services/model_visitor.py +26 -0
- utils/config/__init__.py +15 -0
- utils/config/properties.ini +12 -0
- utils/logger.py +38 -0
- utils/middleware/__init__.py +5 -0
- utils/middleware/log_incoming_request.py +32 -0
- utils/middleware/request_cancellation.py +39 -0
.gitignore
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*$py.class
|
4 |
+
|
5 |
+
# VS code files
|
6 |
+
.vscode
|
7 |
+
|
8 |
+
# PyInstaller
|
9 |
+
# Usually these files are written by a python script from a template
|
10 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
11 |
+
*.manifest
|
12 |
+
*.spec
|
13 |
+
|
14 |
+
# Logs
|
15 |
+
logs.log
|
16 |
+
|
17 |
+
# Jupyter Notebook
|
18 |
+
.ipynb_checkpoints
|
19 |
+
|
20 |
+
|
21 |
+
# Environments
|
22 |
+
.env
|
23 |
+
.venv
|
24 |
+
env/
|
25 |
+
venv/
|
26 |
+
ENV/
|
27 |
+
env.bak/
|
28 |
+
venv.bak/
|
README.md
CHANGED
@@ -1,10 +1,37 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# UNISYS Codegen
|
2 |
+
|
3 |
+
A FastAPI project use to generate the code snippets based on the input given.
|
4 |
+
|
5 |
+
We are using Granite-3B-Code-Base code model designed for code generative tasks.
|
6 |
+
|
7 |
+
|
8 |
+
## 🌞 How to start app
|
9 |
+
Make sure you have python 3.10 or higher installed.
|
10 |
+
First install the dependencies.
|
11 |
+
|
12 |
+
Installation of transformers:
|
13 |
+
```
|
14 |
+
git clone https://github.com/huggingface/transformers
|
15 |
+
cd transformers/
|
16 |
+
pip install ./
|
17 |
+
```
|
18 |
+
|
19 |
+
Installing required packages:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
pip install -r requirements.txt
|
23 |
+
```
|
24 |
+
|
25 |
+
### Manually
|
26 |
+
|
27 |
+
- Start with unicorn
|
28 |
+
```bash
|
29 |
+
pip install unicorn
|
30 |
+
uvicorn index:app --reload
|
31 |
+
```
|
32 |
+
|
33 |
+
### API documentation: (Change the doamin-port accordingly)
|
34 |
+
|
35 |
+
```
|
36 |
+
http://localhost:8000/docs/
|
37 |
+
```
|
controllers/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from fastapi_router_controller import ControllerLoader
|
3 |
+
|
4 |
+
this_dir = os.path.dirname(__file__)
|
5 |
+
|
6 |
+
ControllerLoader.load(this_dir, __package__)
|
controllers/generate_controller.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from fastapi_router_controller import Controller
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
4 |
+
from services.ibm_model.ibm_extract_code_block import IbmExtractCodeblock
|
5 |
+
from services.ibm_model.ibm_text_generator import IbmTextGenerator
|
6 |
+
from services.model_generator import ModelGenerator
|
7 |
+
|
8 |
+
from utils.logger import Logger
|
9 |
+
|
10 |
+
logger = Logger.get_logger(__name__)
|
11 |
+
|
12 |
+
router = APIRouter(prefix='/v1')
|
13 |
+
controller = Controller(router, openapi_tag={
|
14 |
+
'name': 'Generate the Code Snippets',
|
15 |
+
})
|
16 |
+
|
17 |
+
model = ModelGenerator()
|
18 |
+
|
19 |
+
|
20 |
+
@controller.use()
|
21 |
+
@controller.resource()
|
22 |
+
class GenerateController():
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
service: ModelGenerator = Depends()) -> None:
|
26 |
+
self.model_generator = service
|
27 |
+
|
28 |
+
@controller.route.get(
|
29 |
+
'/generate',
|
30 |
+
tags=['generate-code'],
|
31 |
+
summary='Generates the code for the given input')
|
32 |
+
async def generate_code(self, input: str):
|
33 |
+
try:
|
34 |
+
if not input:
|
35 |
+
logger.error('Input is required.')
|
36 |
+
raise HTTPException(
|
37 |
+
status_code=500, detail='Input is required.')
|
38 |
+
|
39 |
+
ibm_generate_text_visitor = IbmTextGenerator()
|
40 |
+
generated_text = await self.model_generator.acceptTextGenerator(ibm_generate_text_visitor, input)
|
41 |
+
|
42 |
+
ibm_extract_code_block_visitor = IbmExtractCodeblock()
|
43 |
+
code_block = self.model_generator.acceptExtractCodeBlock(
|
44 |
+
ibm_extract_code_block_visitor, generated_text)
|
45 |
+
|
46 |
+
logger.info('Output: {}'.format(generated_text))
|
47 |
+
|
48 |
+
return {"data": code_block}
|
49 |
+
except asyncio.CancelledError:
|
50 |
+
logger.error(
|
51 |
+
'Canceling network request due to disconnect in client.')
|
52 |
+
except Exception as error:
|
53 |
+
logger.error('Error {}'.format(error))
|
index.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import controllers
|
2 |
+
|
3 |
+
from fastapi import FastAPI
|
4 |
+
from fastapi_router_controller import Controller, ControllersTags
|
5 |
+
|
6 |
+
from utils.config import Config
|
7 |
+
from utils.middleware import LogIncomingRequest
|
8 |
+
from utils.middleware.request_cancellation import RequestCancellation
|
9 |
+
|
10 |
+
#########################################
|
11 |
+
#### Configure the main application #####
|
12 |
+
#########################################
|
13 |
+
app = FastAPI(
|
14 |
+
title='{}'.format(Config.read('app', 'name')),
|
15 |
+
openapi_tags=ControllersTags)
|
16 |
+
|
17 |
+
app.add_middleware(LogIncomingRequest)
|
18 |
+
app.add_middleware(RequestCancellation)
|
19 |
+
|
20 |
+
for router in Controller.routers():
|
21 |
+
app.include_router(router)
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
fastapi-router-controller
|
services/ibm_model/ibm_extract_code_block.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from services.model_visitor import ModelVisitor
|
3 |
+
|
4 |
+
|
5 |
+
class IbmExtractCodeblock(ModelVisitor):
|
6 |
+
|
7 |
+
def visit(self, _, data):
|
8 |
+
return self._get_code_block(data)
|
9 |
+
|
10 |
+
def _get_code_block(self, data):
|
11 |
+
r"""
|
12 |
+
Extracts text blocks from the input string based on a specific pattern.
|
13 |
+
Args:
|
14 |
+
data (str): The input string containing text blocks.
|
15 |
+
Returns:
|
16 |
+
str: A text block of output which contains code extracted from the input string.
|
17 |
+
Regex Pattern:
|
18 |
+
(?:### Output: ([\s\S]*?))(?:\<\|endoftext\|\>|\Z)|```(?:\w+)?\n(.*?)\n```
|
19 |
+
- (?:### Output: ([\s\S]*?)): This part matches patterns that start with '### Output:'
|
20 |
+
followed by any characters including newlines, capturing them within a group.
|
21 |
+
- (?:\<\|endoftext\|\>|\Z): This part matches either the string <|endoftext|>
|
22 |
+
or the end of the string (\Z).
|
23 |
+
- |: This is an OR operator, meaning the regex will match either the pattern
|
24 |
+
before or after it.
|
25 |
+
- ```(?:\w+)?\n(.*?)\n```: This part matches patterns enclosed within backticks (```),
|
26 |
+
possibly preceded by one or more word characters (\w+), capturing any characters
|
27 |
+
including newlines.
|
28 |
+
"""
|
29 |
+
pattern = r'(?:### Output: ([\s\S]*?))(?:\<\|endoftext\|\>|\Z)|```(?:\w+)?\n(.*?)\n```'
|
30 |
+
matches = re.findall(pattern, data, re.DOTALL)
|
31 |
+
code = []
|
32 |
+
for match in matches:
|
33 |
+
if match[0]:
|
34 |
+
code.append(match[0].strip())
|
35 |
+
elif match[1]:
|
36 |
+
code.append(match[1].strip())
|
37 |
+
return ''.join(code)
|
services/ibm_model/ibm_text_generator.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import torch
|
3 |
+
from datetime import datetime
|
4 |
+
from services.model_visitor import ModelVisitor
|
5 |
+
from utils.logger import Logger
|
6 |
+
|
7 |
+
logger = Logger.get_logger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
class IbmTextGenerator(ModelVisitor):
|
11 |
+
|
12 |
+
async def visit(self, model_generator, input_text, max_length_per_chunk=50):
|
13 |
+
return await self._generate_text(model_generator, input_text, max_length_per_chunk)
|
14 |
+
|
15 |
+
async def _generate_text_chunk(self, model_generator, input_ids, max_length_per_chunk):
|
16 |
+
with torch.no_grad():
|
17 |
+
outputs = await asyncio.to_thread(model_generator.model.generate, input_ids, max_new_tokens=max_length_per_chunk)
|
18 |
+
continuation = model_generator.tokenizer.decode(
|
19 |
+
outputs[0], skip_special_tokens=False)
|
20 |
+
|
21 |
+
logger.info('Chunk generated: {}'.format(continuation))
|
22 |
+
|
23 |
+
return continuation
|
24 |
+
|
25 |
+
async def _generate_text(self, model_generator, input_text, max_length_per_chunk):
|
26 |
+
"""
|
27 |
+
Generates the text based on input provided
|
28 |
+
Args:
|
29 |
+
input_text (str): The input string containing text blocks.
|
30 |
+
max_length_per_chunk: Max length per chunk (Default: 50 / Optional)
|
31 |
+
"""
|
32 |
+
try:
|
33 |
+
start_time = datetime.now()
|
34 |
+
|
35 |
+
logger.info('Started at: {}'.format(
|
36 |
+
start_time.strftime(model_generator._format_data_time)))
|
37 |
+
|
38 |
+
input_ids = model_generator.tokenizer.encode(
|
39 |
+
input_text, return_tensors='pt').to(model_generator.device)
|
40 |
+
|
41 |
+
output_text = input_text
|
42 |
+
while True:
|
43 |
+
continuation = await self._generate_text_chunk(
|
44 |
+
model_generator, input_ids, max_length_per_chunk)
|
45 |
+
|
46 |
+
new_text = continuation[len(model_generator.tokenizer.decode(
|
47 |
+
input_ids[0], skip_special_tokens=False)):]
|
48 |
+
output_text += new_text
|
49 |
+
input_ids = model_generator.tokenizer.encode(
|
50 |
+
output_text, return_tensors='pt').to(model_generator.device)
|
51 |
+
if "<|endoftext|>" in new_text or new_text.count('```') > 1:
|
52 |
+
break
|
53 |
+
|
54 |
+
end_time = datetime.now()
|
55 |
+
|
56 |
+
logger.info('Output generated at: {}'.format(
|
57 |
+
end_time.strftime(model_generator._format_data_time)))
|
58 |
+
|
59 |
+
logger.info('Time taken: {}'.format(end_time - start_time))
|
60 |
+
|
61 |
+
return output_text
|
62 |
+
except asyncio.CancelledError:
|
63 |
+
logger.error(
|
64 |
+
'Cancelling model generation due to disconnection in network.')
|
65 |
+
return ""
|
services/model_generator.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
from utils.config import Config
|
4 |
+
from utils.logger import Logger
|
5 |
+
|
6 |
+
logger = Logger.get_logger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
class ModelGenerator:
|
10 |
+
"""
|
11 |
+
Singleton class responsible for generating text using a specified language model.
|
12 |
+
|
13 |
+
This class initializes a language model and tokenizer, and provides methods
|
14 |
+
to generate text and extract code blocks from generated text.
|
15 |
+
|
16 |
+
Attributes:
|
17 |
+
device (torch.device): Device to run the model on (CPU or GPU).
|
18 |
+
model (AutoModelForCausalLM): Language model for text generation.
|
19 |
+
tokenizer (AutoTokenizer): Tokenizer corresponding to the language model.
|
20 |
+
|
21 |
+
Methods:
|
22 |
+
acceptTextGenerator(self, visitor, *args, **kwargs):
|
23 |
+
Accepts a visitor to generates text based on the input provided with the model generator.
|
24 |
+
acceptExtractCodeBlock(self, visitor, *args, **kwargs):
|
25 |
+
Accepts a visitor to extract code blocks from the output text.
|
26 |
+
"""
|
27 |
+
_instance = None
|
28 |
+
_format_data_time = "%Y-%m-%d %H:%M:%S"
|
29 |
+
|
30 |
+
def __new__(cls, model_name=Config.read('app', 'model')):
|
31 |
+
if cls._instance is None:
|
32 |
+
cls._instance = super(ModelGenerator, cls).__new__(cls)
|
33 |
+
cls._instance._initialize(model_name)
|
34 |
+
return cls._instance
|
35 |
+
|
36 |
+
def _initialize(self, model_name):
|
37 |
+
self.device = torch.device(
|
38 |
+
"cuda" if torch.cuda.is_available() else "cpu")
|
39 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
40 |
+
model_name).to(self.device)
|
41 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
42 |
+
|
43 |
+
def acceptTextGenerator(self, visitor, *args, **kwargs):
|
44 |
+
return visitor.visit(self, *args, **kwargs)
|
45 |
+
|
46 |
+
def acceptExtractCodeBlock(self, visitor, *args, **kwargs):
|
47 |
+
return visitor.visit(self, *args, **kwargs)
|
services/model_visitor.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class ModelVisitor(ABC):
|
5 |
+
"""
|
6 |
+
Abstract base class for model visitors.
|
7 |
+
|
8 |
+
This class defines the interface for visiting a model generator.
|
9 |
+
Subclasses must implement the visit method to define
|
10 |
+
specific behaviors for different types of model generators.
|
11 |
+
|
12 |
+
Methods:
|
13 |
+
visit(generator, *args, **kwargs):
|
14 |
+
Abstract method to visit the model generator. Subclasses
|
15 |
+
must override this method to provide specific functionality.
|
16 |
+
|
17 |
+
Example:
|
18 |
+
class IbmTextGenerator(ModelVisitor):
|
19 |
+
def visit(self, model_generator, *args, **kwargs):
|
20 |
+
# Implement specific behavior here
|
21 |
+
pass
|
22 |
+
"""
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def visit(self, generator, *args, **kwargs):
|
26 |
+
pass
|
utils/config/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from configparser import ConfigParser
|
4 |
+
|
5 |
+
this_dir = Path(__file__).parent
|
6 |
+
conf_dir = this_dir / 'properties.ini'
|
7 |
+
|
8 |
+
parser = ConfigParser(os.environ)
|
9 |
+
parser.read(conf_dir, encoding="utf8")
|
10 |
+
|
11 |
+
|
12 |
+
class Config():
|
13 |
+
@staticmethod
|
14 |
+
def read(section, property, default=None):
|
15 |
+
return parser.get(section, property) or default
|
utils/config/properties.ini
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[app]
|
2 |
+
name=Unisys Code Gen
|
3 |
+
env=%(ENV)
|
4 |
+
|
5 |
+
# we are currently using this model
|
6 |
+
model=ibm-granite/granite-3b-code-base
|
7 |
+
|
8 |
+
[log]
|
9 |
+
level=INFO
|
10 |
+
filename=./logs.log
|
11 |
+
dateformat=%%Y-%%m-%%dT%%H:%%M:%%S
|
12 |
+
format=%%(asctime)s.%%(msecs)03d %%(levelname)5s %%(name)s - %%(message)s
|
utils/logger.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from utils.config import Config
|
4 |
+
|
5 |
+
|
6 |
+
class Logger():
|
7 |
+
CONFIG_KEY = 'log'
|
8 |
+
|
9 |
+
@staticmethod
|
10 |
+
def get_level():
|
11 |
+
return Config.read(Logger.CONFIG_KEY, 'level')
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def get_filename():
|
15 |
+
return Config.read(Logger.CONFIG_KEY, 'filename')
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def get_format():
|
19 |
+
return Config.read(Logger.CONFIG_KEY, 'format')
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def get_date_format():
|
23 |
+
return Config.read(Logger.CONFIG_KEY, 'dateformat')
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def get_logger(name):
|
27 |
+
logger = logging.getLogger(name)
|
28 |
+
logger.setLevel(Logger.get_level()) # type: ignore
|
29 |
+
|
30 |
+
formatter = logging.Formatter(
|
31 |
+
Logger.get_format(),
|
32 |
+
Logger.get_date_format())
|
33 |
+
|
34 |
+
file_hdlr = logging.FileHandler(Logger.get_filename()) # type: ignore
|
35 |
+
file_hdlr.setFormatter(formatter)
|
36 |
+
logger.addHandler(hdlr=file_hdlr)
|
37 |
+
|
38 |
+
return logger
|
utils/middleware/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.middleware.log_incoming_request import LogIncomingRequest
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
'LogIncomingRequest'
|
5 |
+
]
|
utils/middleware/log_incoming_request.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import List
|
3 |
+
from fastapi import Request
|
4 |
+
from fastapi.routing import APIRoute
|
5 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
6 |
+
from utils.logger import Logger
|
7 |
+
|
8 |
+
logger = Logger.get_logger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
class LogIncomingRequest(BaseHTTPMiddleware):
|
12 |
+
def __get_request_handler(_, req: Request): # type: ignore
|
13 |
+
# get controller from request
|
14 |
+
routes: List[APIRoute] = req.app.routes
|
15 |
+
for route in routes:
|
16 |
+
if route.path_regex.match(req.url.path) and req.method in route.methods:
|
17 |
+
return route.endpoint.__name__ if hasattr(route.endpoint, '__name__') else 'fastapi_core'
|
18 |
+
|
19 |
+
async def dispatch(self, request: Request, call_next):
|
20 |
+
func_name = self.__get_request_handler(request)
|
21 |
+
request.state.func_name = func_name
|
22 |
+
|
23 |
+
logger.info('{} - start'.format(func_name))
|
24 |
+
start_time = time.time()
|
25 |
+
|
26 |
+
response = await call_next(request)
|
27 |
+
|
28 |
+
process_time = (time.time() - start_time) * 1000
|
29 |
+
formatted_process_time = '{0:.2f}'.format(process_time)
|
30 |
+
logger.info('{} - end in time (ms): {}'.format(func_name,
|
31 |
+
formatted_process_time))
|
32 |
+
return response
|
utils/middleware/request_cancellation.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from utils.logger import Logger
|
3 |
+
|
4 |
+
logger = Logger.get_logger(__name__)
|
5 |
+
|
6 |
+
|
7 |
+
class RequestCancellation:
|
8 |
+
"""
|
9 |
+
RequestCancellation middleware handles request canceling
|
10 |
+
* In case of API routes where very frequent/expensive requests are made.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, app):
|
14 |
+
self.app = app
|
15 |
+
|
16 |
+
async def __call__(self, scope, receive, send):
|
17 |
+
if scope["type"] != "http":
|
18 |
+
await self.app(scope, receive, send)
|
19 |
+
return
|
20 |
+
|
21 |
+
queue = asyncio.Queue()
|
22 |
+
|
23 |
+
async def message_poller(sentinel, handler_task):
|
24 |
+
nonlocal queue
|
25 |
+
while True:
|
26 |
+
message = await receive()
|
27 |
+
if message["type"] == "http.disconnect":
|
28 |
+
handler_task.cancel()
|
29 |
+
return sentinel
|
30 |
+
await queue.put(message)
|
31 |
+
|
32 |
+
sentinel = object()
|
33 |
+
handler_task = asyncio.create_task(self.app(scope, queue.get, send))
|
34 |
+
asyncio.create_task(message_poller(sentinel, handler_task))
|
35 |
+
|
36 |
+
try:
|
37 |
+
return await handler_task
|
38 |
+
except asyncio.CancelledError:
|
39 |
+
logger.info('Task Cancellatation Requested.')
|