AP\VivekIsh commited on
Commit
6fadbbc
1 Parent(s): 0630fa6

codegen: Stage the code

Browse files
.gitignore ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *$py.class
4
+
5
+ # VS code files
6
+ .vscode
7
+
8
+ # PyInstaller
9
+ # Usually these files are written by a python script from a template
10
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
11
+ *.manifest
12
+ *.spec
13
+
14
+ # Logs
15
+ logs.log
16
+
17
+ # Jupyter Notebook
18
+ .ipynb_checkpoints
19
+
20
+
21
+ # Environments
22
+ .env
23
+ .venv
24
+ env/
25
+ venv/
26
+ ENV/
27
+ env.bak/
28
+ venv.bak/
README.md CHANGED
@@ -1,10 +1,37 @@
1
- ---
2
- title: Codegen
3
- emoji: 📊
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # UNISYS Codegen
2
+
3
+ A FastAPI project use to generate the code snippets based on the input given.
4
+
5
+ We are using Granite-3B-Code-Base code model designed for code generative tasks.
6
+
7
+
8
+ ## 🌞 How to start app
9
+ Make sure you have python 3.10 or higher installed.
10
+ First install the dependencies.
11
+
12
+ Installation of transformers:
13
+ ```
14
+ git clone https://github.com/huggingface/transformers
15
+ cd transformers/
16
+ pip install ./
17
+ ```
18
+
19
+ Installing required packages:
20
+
21
+ ```bash
22
+ pip install -r requirements.txt
23
+ ```
24
+
25
+ ### Manually
26
+
27
+ - Start with unicorn
28
+ ```bash
29
+ pip install unicorn
30
+ uvicorn index:app --reload
31
+ ```
32
+
33
+ ### API documentation: (Change the doamin-port accordingly)
34
+
35
+ ```
36
+ http://localhost:8000/docs/
37
+ ```
controllers/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi_router_controller import ControllerLoader
3
+
4
+ this_dir = os.path.dirname(__file__)
5
+
6
+ ControllerLoader.load(this_dir, __package__)
controllers/generate_controller.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from fastapi_router_controller import Controller
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+ from services.ibm_model.ibm_extract_code_block import IbmExtractCodeblock
5
+ from services.ibm_model.ibm_text_generator import IbmTextGenerator
6
+ from services.model_generator import ModelGenerator
7
+
8
+ from utils.logger import Logger
9
+
10
+ logger = Logger.get_logger(__name__)
11
+
12
+ router = APIRouter(prefix='/v1')
13
+ controller = Controller(router, openapi_tag={
14
+ 'name': 'Generate the Code Snippets',
15
+ })
16
+
17
+ model = ModelGenerator()
18
+
19
+
20
+ @controller.use()
21
+ @controller.resource()
22
+ class GenerateController():
23
+ def __init__(
24
+ self,
25
+ service: ModelGenerator = Depends()) -> None:
26
+ self.model_generator = service
27
+
28
+ @controller.route.get(
29
+ '/generate',
30
+ tags=['generate-code'],
31
+ summary='Generates the code for the given input')
32
+ async def generate_code(self, input: str):
33
+ try:
34
+ if not input:
35
+ logger.error('Input is required.')
36
+ raise HTTPException(
37
+ status_code=500, detail='Input is required.')
38
+
39
+ ibm_generate_text_visitor = IbmTextGenerator()
40
+ generated_text = await self.model_generator.acceptTextGenerator(ibm_generate_text_visitor, input)
41
+
42
+ ibm_extract_code_block_visitor = IbmExtractCodeblock()
43
+ code_block = self.model_generator.acceptExtractCodeBlock(
44
+ ibm_extract_code_block_visitor, generated_text)
45
+
46
+ logger.info('Output: {}'.format(generated_text))
47
+
48
+ return {"data": code_block}
49
+ except asyncio.CancelledError:
50
+ logger.error(
51
+ 'Canceling network request due to disconnect in client.')
52
+ except Exception as error:
53
+ logger.error('Error {}'.format(error))
index.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import controllers
2
+
3
+ from fastapi import FastAPI
4
+ from fastapi_router_controller import Controller, ControllersTags
5
+
6
+ from utils.config import Config
7
+ from utils.middleware import LogIncomingRequest
8
+ from utils.middleware.request_cancellation import RequestCancellation
9
+
10
+ #########################################
11
+ #### Configure the main application #####
12
+ #########################################
13
+ app = FastAPI(
14
+ title='{}'.format(Config.read('app', 'name')),
15
+ openapi_tags=ControllersTags)
16
+
17
+ app.add_middleware(LogIncomingRequest)
18
+ app.add_middleware(RequestCancellation)
19
+
20
+ for router in Controller.routers():
21
+ app.include_router(router)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ fastapi
2
+ fastapi-router-controller
services/ibm_model/ibm_extract_code_block.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from services.model_visitor import ModelVisitor
3
+
4
+
5
+ class IbmExtractCodeblock(ModelVisitor):
6
+
7
+ def visit(self, _, data):
8
+ return self._get_code_block(data)
9
+
10
+ def _get_code_block(self, data):
11
+ r"""
12
+ Extracts text blocks from the input string based on a specific pattern.
13
+ Args:
14
+ data (str): The input string containing text blocks.
15
+ Returns:
16
+ str: A text block of output which contains code extracted from the input string.
17
+ Regex Pattern:
18
+ (?:### Output: ([\s\S]*?))(?:\<\|endoftext\|\>|\Z)|```(?:\w+)?\n(.*?)\n```
19
+ - (?:### Output: ([\s\S]*?)): This part matches patterns that start with '### Output:'
20
+ followed by any characters including newlines, capturing them within a group.
21
+ - (?:\<\|endoftext\|\>|\Z): This part matches either the string <|endoftext|>
22
+ or the end of the string (\Z).
23
+ - |: This is an OR operator, meaning the regex will match either the pattern
24
+ before or after it.
25
+ - ```(?:\w+)?\n(.*?)\n```: This part matches patterns enclosed within backticks (```),
26
+ possibly preceded by one or more word characters (\w+), capturing any characters
27
+ including newlines.
28
+ """
29
+ pattern = r'(?:### Output: ([\s\S]*?))(?:\<\|endoftext\|\>|\Z)|```(?:\w+)?\n(.*?)\n```'
30
+ matches = re.findall(pattern, data, re.DOTALL)
31
+ code = []
32
+ for match in matches:
33
+ if match[0]:
34
+ code.append(match[0].strip())
35
+ elif match[1]:
36
+ code.append(match[1].strip())
37
+ return ''.join(code)
services/ibm_model/ibm_text_generator.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import torch
3
+ from datetime import datetime
4
+ from services.model_visitor import ModelVisitor
5
+ from utils.logger import Logger
6
+
7
+ logger = Logger.get_logger(__name__)
8
+
9
+
10
+ class IbmTextGenerator(ModelVisitor):
11
+
12
+ async def visit(self, model_generator, input_text, max_length_per_chunk=50):
13
+ return await self._generate_text(model_generator, input_text, max_length_per_chunk)
14
+
15
+ async def _generate_text_chunk(self, model_generator, input_ids, max_length_per_chunk):
16
+ with torch.no_grad():
17
+ outputs = await asyncio.to_thread(model_generator.model.generate, input_ids, max_new_tokens=max_length_per_chunk)
18
+ continuation = model_generator.tokenizer.decode(
19
+ outputs[0], skip_special_tokens=False)
20
+
21
+ logger.info('Chunk generated: {}'.format(continuation))
22
+
23
+ return continuation
24
+
25
+ async def _generate_text(self, model_generator, input_text, max_length_per_chunk):
26
+ """
27
+ Generates the text based on input provided
28
+ Args:
29
+ input_text (str): The input string containing text blocks.
30
+ max_length_per_chunk: Max length per chunk (Default: 50 / Optional)
31
+ """
32
+ try:
33
+ start_time = datetime.now()
34
+
35
+ logger.info('Started at: {}'.format(
36
+ start_time.strftime(model_generator._format_data_time)))
37
+
38
+ input_ids = model_generator.tokenizer.encode(
39
+ input_text, return_tensors='pt').to(model_generator.device)
40
+
41
+ output_text = input_text
42
+ while True:
43
+ continuation = await self._generate_text_chunk(
44
+ model_generator, input_ids, max_length_per_chunk)
45
+
46
+ new_text = continuation[len(model_generator.tokenizer.decode(
47
+ input_ids[0], skip_special_tokens=False)):]
48
+ output_text += new_text
49
+ input_ids = model_generator.tokenizer.encode(
50
+ output_text, return_tensors='pt').to(model_generator.device)
51
+ if "<|endoftext|>" in new_text or new_text.count('```') > 1:
52
+ break
53
+
54
+ end_time = datetime.now()
55
+
56
+ logger.info('Output generated at: {}'.format(
57
+ end_time.strftime(model_generator._format_data_time)))
58
+
59
+ logger.info('Time taken: {}'.format(end_time - start_time))
60
+
61
+ return output_text
62
+ except asyncio.CancelledError:
63
+ logger.error(
64
+ 'Cancelling model generation due to disconnection in network.')
65
+ return ""
services/model_generator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from utils.config import Config
4
+ from utils.logger import Logger
5
+
6
+ logger = Logger.get_logger(__name__)
7
+
8
+
9
+ class ModelGenerator:
10
+ """
11
+ Singleton class responsible for generating text using a specified language model.
12
+
13
+ This class initializes a language model and tokenizer, and provides methods
14
+ to generate text and extract code blocks from generated text.
15
+
16
+ Attributes:
17
+ device (torch.device): Device to run the model on (CPU or GPU).
18
+ model (AutoModelForCausalLM): Language model for text generation.
19
+ tokenizer (AutoTokenizer): Tokenizer corresponding to the language model.
20
+
21
+ Methods:
22
+ acceptTextGenerator(self, visitor, *args, **kwargs):
23
+ Accepts a visitor to generates text based on the input provided with the model generator.
24
+ acceptExtractCodeBlock(self, visitor, *args, **kwargs):
25
+ Accepts a visitor to extract code blocks from the output text.
26
+ """
27
+ _instance = None
28
+ _format_data_time = "%Y-%m-%d %H:%M:%S"
29
+
30
+ def __new__(cls, model_name=Config.read('app', 'model')):
31
+ if cls._instance is None:
32
+ cls._instance = super(ModelGenerator, cls).__new__(cls)
33
+ cls._instance._initialize(model_name)
34
+ return cls._instance
35
+
36
+ def _initialize(self, model_name):
37
+ self.device = torch.device(
38
+ "cuda" if torch.cuda.is_available() else "cpu")
39
+ self.model = AutoModelForCausalLM.from_pretrained(
40
+ model_name).to(self.device)
41
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
42
+
43
+ def acceptTextGenerator(self, visitor, *args, **kwargs):
44
+ return visitor.visit(self, *args, **kwargs)
45
+
46
+ def acceptExtractCodeBlock(self, visitor, *args, **kwargs):
47
+ return visitor.visit(self, *args, **kwargs)
services/model_visitor.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class ModelVisitor(ABC):
5
+ """
6
+ Abstract base class for model visitors.
7
+
8
+ This class defines the interface for visiting a model generator.
9
+ Subclasses must implement the visit method to define
10
+ specific behaviors for different types of model generators.
11
+
12
+ Methods:
13
+ visit(generator, *args, **kwargs):
14
+ Abstract method to visit the model generator. Subclasses
15
+ must override this method to provide specific functionality.
16
+
17
+ Example:
18
+ class IbmTextGenerator(ModelVisitor):
19
+ def visit(self, model_generator, *args, **kwargs):
20
+ # Implement specific behavior here
21
+ pass
22
+ """
23
+
24
+ @abstractmethod
25
+ def visit(self, generator, *args, **kwargs):
26
+ pass
utils/config/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from configparser import ConfigParser
4
+
5
+ this_dir = Path(__file__).parent
6
+ conf_dir = this_dir / 'properties.ini'
7
+
8
+ parser = ConfigParser(os.environ)
9
+ parser.read(conf_dir, encoding="utf8")
10
+
11
+
12
+ class Config():
13
+ @staticmethod
14
+ def read(section, property, default=None):
15
+ return parser.get(section, property) or default
utils/config/properties.ini ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [app]
2
+ name=Unisys Code Gen
3
+ env=%(ENV)
4
+
5
+ # we are currently using this model
6
+ model=ibm-granite/granite-3b-code-base
7
+
8
+ [log]
9
+ level=INFO
10
+ filename=./logs.log
11
+ dateformat=%%Y-%%m-%%dT%%H:%%M:%%S
12
+ format=%%(asctime)s.%%(msecs)03d %%(levelname)5s %%(name)s - %%(message)s
utils/logger.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from utils.config import Config
4
+
5
+
6
+ class Logger():
7
+ CONFIG_KEY = 'log'
8
+
9
+ @staticmethod
10
+ def get_level():
11
+ return Config.read(Logger.CONFIG_KEY, 'level')
12
+
13
+ @staticmethod
14
+ def get_filename():
15
+ return Config.read(Logger.CONFIG_KEY, 'filename')
16
+
17
+ @staticmethod
18
+ def get_format():
19
+ return Config.read(Logger.CONFIG_KEY, 'format')
20
+
21
+ @staticmethod
22
+ def get_date_format():
23
+ return Config.read(Logger.CONFIG_KEY, 'dateformat')
24
+
25
+ @staticmethod
26
+ def get_logger(name):
27
+ logger = logging.getLogger(name)
28
+ logger.setLevel(Logger.get_level()) # type: ignore
29
+
30
+ formatter = logging.Formatter(
31
+ Logger.get_format(),
32
+ Logger.get_date_format())
33
+
34
+ file_hdlr = logging.FileHandler(Logger.get_filename()) # type: ignore
35
+ file_hdlr.setFormatter(formatter)
36
+ logger.addHandler(hdlr=file_hdlr)
37
+
38
+ return logger
utils/middleware/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from utils.middleware.log_incoming_request import LogIncomingRequest
2
+
3
+ __all__ = [
4
+ 'LogIncomingRequest'
5
+ ]
utils/middleware/log_incoming_request.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import List
3
+ from fastapi import Request
4
+ from fastapi.routing import APIRoute
5
+ from starlette.middleware.base import BaseHTTPMiddleware
6
+ from utils.logger import Logger
7
+
8
+ logger = Logger.get_logger(__name__)
9
+
10
+
11
+ class LogIncomingRequest(BaseHTTPMiddleware):
12
+ def __get_request_handler(_, req: Request): # type: ignore
13
+ # get controller from request
14
+ routes: List[APIRoute] = req.app.routes
15
+ for route in routes:
16
+ if route.path_regex.match(req.url.path) and req.method in route.methods:
17
+ return route.endpoint.__name__ if hasattr(route.endpoint, '__name__') else 'fastapi_core'
18
+
19
+ async def dispatch(self, request: Request, call_next):
20
+ func_name = self.__get_request_handler(request)
21
+ request.state.func_name = func_name
22
+
23
+ logger.info('{} - start'.format(func_name))
24
+ start_time = time.time()
25
+
26
+ response = await call_next(request)
27
+
28
+ process_time = (time.time() - start_time) * 1000
29
+ formatted_process_time = '{0:.2f}'.format(process_time)
30
+ logger.info('{} - end in time (ms): {}'.format(func_name,
31
+ formatted_process_time))
32
+ return response
utils/middleware/request_cancellation.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from utils.logger import Logger
3
+
4
+ logger = Logger.get_logger(__name__)
5
+
6
+
7
+ class RequestCancellation:
8
+ """
9
+ RequestCancellation middleware handles request canceling
10
+ * In case of API routes where very frequent/expensive requests are made.
11
+ """
12
+
13
+ def __init__(self, app):
14
+ self.app = app
15
+
16
+ async def __call__(self, scope, receive, send):
17
+ if scope["type"] != "http":
18
+ await self.app(scope, receive, send)
19
+ return
20
+
21
+ queue = asyncio.Queue()
22
+
23
+ async def message_poller(sentinel, handler_task):
24
+ nonlocal queue
25
+ while True:
26
+ message = await receive()
27
+ if message["type"] == "http.disconnect":
28
+ handler_task.cancel()
29
+ return sentinel
30
+ await queue.put(message)
31
+
32
+ sentinel = object()
33
+ handler_task = asyncio.create_task(self.app(scope, queue.get, send))
34
+ asyncio.create_task(message_poller(sentinel, handler_task))
35
+
36
+ try:
37
+ return await handler_task
38
+ except asyncio.CancelledError:
39
+ logger.info('Task Cancellatation Requested.')