from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed in recent versions
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import gradio as gr
import random
import time

pl.seed_everything(100)

MODEL_NAME = 't5-base'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)


class T5Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-4)


# Load the fine-tuned checkpoint and freeze it for inference only.
train_model = T5Model.load_from_checkpoint('best-model.ckpt', map_location=DEVICE)
train_model.freeze()
train_model.to(DEVICE)


def generate_question(question):
    # Generate the chatbot's reply for a single user message.
    inputs_encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt"
    )

    generate_ids = train_model.model.generate(
        input_ids=inputs_encoding["input_ids"].to(DEVICE),
        attention_mask=inputs_encoding["attention_mask"].to(DEVICE),
        max_length=OUTPUT_MAX_LEN,  # caps the generated sequence, not the input
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generate_ids
    ]

    return "".join(preds)


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=300)
    with gr.Row():
        with gr.Column(scale=10):
            msg = gr.Textbox(
                show_label=False,
                container=False,
                placeholder=random.choice([
                    "Disclaimer: IT WILL CUSS YOU.",
                    "Be careful with Punctuations like ? \" ! , ' .",
                    "Enter text and press enter"
                ])
            )
        with gr.Column(scale=1, min_width=0):
            sub = gr.Button("Send")
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Append the user's turn with an empty slot for the bot's reply.
        return "", history + [[user_message, None]]

    def bot(history):
        # Stream the model's reply into the chat window one character at a time.
        bot_message = generate_question(history[-1][0])
        history[-1][1] = ""
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot, chatbot, chatbot
    )
    sub.click(user, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=True)

demo.queue()
demo.launch()
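
# --- Producing 'best-model.ckpt' (hedged sketch, not the author's code) --------
# This script loads 'best-model.ckpt' but never creates it, and ModelCheckpoint
# is imported above without being used, which suggests the training step was
# elided. Assuming `train_loader` and `val_loader` are DataLoaders yielding
# dicts with the "input_ids", "attention_mask", and "target" keys consumed by
# training_step above (both loader names are hypothetical), something like the
# following would write the checkpoint:
#
#     checkpoint_cb = ModelCheckpoint(
#         dirpath=".", filename="best-model",
#         monitor="val_loss", mode="min", save_top_k=1,
#     )
#     trainer = pl.Trainer(max_epochs=3, accelerator="auto", callbacks=[checkpoint_cb])
#     trainer.fit(T5Model(), train_loader, val_loader)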
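
# --- Quick sanity check without the UI (assumes the checkpoint above loaded) ---
# Calling the generation function directly is a simple way to verify the model
# before launching the Gradio app:
#
#     print(generate_question("How are you?"))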