import warnings

import torch
import gradio as gr
from transformers import AutoTokenizer

from summarizer import SummarizerModel

# Silence library warnings in the demo output.
warnings.simplefilter('ignore')

# Fine-tuned CodeT5 checkpoint used for code summarization.
MODEL_NAME = 'Salesforce/codet5-base-multi-sum'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Wrap the base model and load the fine-tuned weights; map_location keeps the
# demo runnable on CPU-only hosts.
model = SummarizerModel(MODEL_NAME)
model.load_state_dict(torch.load('codet5-base-1_epoch-val_loss-0.80.pth',
                                 map_location=torch.device('cpu')))
model.eval()

def summarize(text: str,
              tokenizer=tokenizer,
              trained_model=model):
    """
    Summarizes a piece of source code given as a string.

    Args:
        text: The code, as a string, to be summarized.
        tokenizer: The tokenizer used with the trained T5 model.
        trained_model: A fine-tuned SummarizerModel instance from the
            T5 model family.
    """
    # Tokenize the input code, padding/truncating to the model's 512-token limit.
    text_encoding = tokenizer.encode_plus(
        text,
        padding='max_length',
        max_length=512,
        add_special_tokens=True,
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt'
    )
    # Generate the summary with beam search; repetition_penalty discourages
    # repeated phrases and early_stopping ends decoding once all beams finish.
    generated_ids = trained_model.model.generate(
        input_ids=text_encoding['input_ids'],
        attention_mask=text_encoding['attention_mask'],
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )
    # Decode the generated token ids back into text.
    preds = [tokenizer.decode(gen_id, skip_special_tokens=True,
                              clean_up_tokenization_spaces=True)
             for gen_id in generated_ids]
    return "".join(preds)
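
# Illustrative example (not executed as part of the app): summarize() can also
# be called directly on a snippet, independent of the Gradio UI, e.g.
#   summarize("def add(a, b):\n    return a + b")
# The returned string is a short natural-language summary; the exact wording
# depends on the fine-tuned checkpoint loaded above.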

# Simple Gradio UI: paste Python code in, get a generated summary out.
iface = gr.Interface(fn=summarize,
                     inputs=gr.Textbox(lines=10, label="Python code"),
                     outputs=gr.Textbox(label="Summary"),
                     description="Demo for ForgeT5 | Input: Python code | Output: The generated code summary")
iface.launch(inline=False)