File size: 1,816 Bytes
6fadbbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import asyncio
from fastapi_router_controller import Controller
from fastapi import APIRouter, Depends, HTTPException
from services.ibm_model.ibm_extract_code_block import IbmExtractCodeblock
from services.ibm_model.ibm_text_generator import IbmTextGenerator
from services.model_generator import ModelGenerator

from utils.logger import Logger

# Module-level logger, named after this module.
logger = Logger.get_logger(__name__)

# Router created with the '/v1' path prefix.
router = APIRouter(prefix='/v1')
# Controller wrapper built around the router; the dict below is passed as
# its OpenAPI tag.
controller = Controller(router, openapi_tag={
    'name': 'Generate the Code Snippets',
})

# NOTE(review): this module-level instance is not referenced anywhere in the
# visible code (GenerateController is handed a ModelGenerator through
# Depends()). Confirm no other module imports it before removing.
model = ModelGenerator()


@controller.use()
@controller.resource()
class GenerateController():
    """Controller exposing the code-generation endpoint.

    A ``ModelGenerator`` is injected through ``Depends()`` and used to run
    the text-generation and code-block-extraction visitors.
    """

    def __init__(
            self,
            service: ModelGenerator = Depends()) -> None:
        """Store the injected ``ModelGenerator`` for use by the routes."""
        self.model_generator = service

    @controller.route.get(
        '/generate',
        tags=['generate-code'],
        summary='Generates the code for the given input')
    async def generate_code(self, input: str):
        """Generate text for ``input`` and return the extracted code block.

        Args:
            input: Prompt text. The name shadows the ``input`` builtin but
                is kept because it is the public query-parameter name.

        Returns:
            A dict of the form ``{"data": <extracted code block>}``.

        Raises:
            HTTPException: 400 when ``input`` is empty; 500 when generation
                or extraction fails.
            asyncio.CancelledError: Re-raised when the request task is
                cancelled, so that cancellation keeps propagating.
        """
        # Validate before entering the try block: previously this
        # HTTPException was raised inside it, caught by the broad
        # `except Exception` below and silently turned into a `null` body.
        # An empty prompt is a client error, hence 400 rather than 500.
        if not input:
            logger.error('Input is required.')
            raise HTTPException(
                status_code=400, detail='Input is required.')

        try:
            ibm_generate_text_visitor = IbmTextGenerator()
            generated_text = await self.model_generator.acceptTextGenerator(
                ibm_generate_text_visitor, input)

            ibm_extract_code_block_visitor = IbmExtractCodeblock()
            code_block = self.model_generator.acceptExtractCodeBlock(
                ibm_extract_code_block_visitor, generated_text)

            logger.info('Output: {}'.format(generated_text))

            return {"data": code_block}
        except asyncio.CancelledError:
            logger.error(
                'Canceling network request due to disconnect in client.')
            # Cancellation must not be swallowed; let it reach the caller.
            raise
        except HTTPException:
            # Let HTTP errors raised by the services pass through unchanged.
            raise
        except Exception as error:
            logger.error('Error {}'.format(error))
            # Report the failure instead of implicitly returning None (which
            # would be sent as a successful `null` response). The detail is
            # kept generic so internal error text is not exposed to clients.
            raise HTTPException(
                status_code=500,
                detail='Failed to generate code.') from error