|
import torch |
|
import os |
|
import numpy as np |
|
import gradio as gr |
|
import pytorch_lightning as pl |
|
from torch.utils.data import Dataset, DataLoader |
|
from datasets import load_dataset |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
from datasets.dataset_dict import DatasetDict |
|
from transformers import AdamW, T5ForConditionalGeneration, T5TokenizerFast |
|
import warnings |
|
warnings.simplefilter('ignore') |
|
|
|
from summarizer import SummarizerModel |
|
from transformers import AutoTokenizer |
|
MODEL_NAME = 'Salesforce/codet5-base-multi-sum' |
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
|
model = SummarizerModel(MODEL_NAME) |
|
model.load_state_dict(torch.load('codet5-base-1_epoch-val_loss-0.80.pth')) |
|
|
|
def summarize(text: str, |
|
tokenizer = tokenizer, |
|
trained_model = model): |
|
""" |
|
Summarizes a given code in text format. |
|
Args: |
|
text: The code in string format that needs to be summarized. |
|
tokenizer: The tokenizer used in the trained T5 model. |
|
trained_model: A SummarizerModel fine-tuned instance of |
|
T5 model family. |
|
""" |
|
text_encoding = tokenizer.encode_plus( |
|
text, |
|
padding = 'max_length', |
|
max_length = 512, |
|
add_special_tokens = True, |
|
return_attention_mask = True, |
|
truncation = True, |
|
return_tensors = 'pt' |
|
) |
|
generated_ids = trained_model.model.generate( |
|
input_ids = text_encoding['input_ids'], |
|
attention_mask = text_encoding['attention_mask'], |
|
max_length = 150, |
|
num_beams = 2, |
|
repetition_penalty = 2.5, |
|
length_penalty = 1.0, |
|
early_stopping = True |
|
) |
|
preds = [tokenizer.decode(gen_id, skip_special_tokens = True, |
|
clean_up_tokenization_spaces=True) |
|
for gen_id in generated_ids] |
|
return "".join(preds) |
|
|
|
outputs = gr.outputs.Textbox() |
|
iface = gr.Interface(fn=summarize, |
|
inputs=['text'], |
|
outputs=outputs, |
|
description="Demo for ForgeT5 | Input: A python code | Output: The code summarization") |
|
iface.launch(inline = False) |