Rohan Kumar Singh
included Send button
e81bdf9
raw
history blame contribute delete
No virus
3.96 kB
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AdamW
import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
# Hugging Face checkpoint name shared by the tokenizer and the base model.
MODEL_NAME = "t5-base"
# Token-length limits for the model input and output.
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128

# Seed the global RNGs through Lightning so runs are reproducible.
pl.seed_everything(100)

# Prefer CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared tokenizer; model_max_length caps the tokenizer's maximum sequence length setting.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)
class T5Model(pl.LightningModule):
    """Lightning wrapper around a pretrained ``T5ForConditionalGeneration``.

    Training/validation batches are dicts with ``input_ids``,
    ``attention_mask`` and ``target`` entries; ``target`` is passed to the
    underlying model as ``labels``.
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the seq2seq model.

        Returns:
            A ``(loss, logits)`` tuple taken from the model output; the
            Hugging Face model only computes ``loss`` when ``labels`` is given.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits

    def _compute_loss(self, batch):
        """Return the model loss for one batch (labels come from ``batch["target"]``)."""
        loss, _ = self(batch["input_ids"], batch["attention_mask"], batch["target"])
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (lr = 1e-4)."""
        # transformers.AdamW is deprecated in favour of torch.optim.AdamW.
        # torch's defaults differ (eps=1e-8, weight_decay=1e-2), so the old
        # transformers defaults (eps=1e-6, weight_decay=0.0) are passed explicitly.
        return torch.optim.AdamW(
            self.parameters(),
            lr=0.0001,
            eps=1e-6,
            weight_decay=0.0,
        )
# Restore the fine-tuned weights, mapping the checkpoint's tensors onto DEVICE.
train_model = T5Model.load_from_checkpoint(
    checkpoint_path="best-model.ckpt",
    map_location=DEVICE,
)
# Inference only: freeze() switches to eval mode and disables gradients.
train_model.freeze()
def generate_question(question):
    """Run beam-search generation on ``question`` and return the decoded text.

    Args:
        question: Input text, tokenized with the module-level ``tokenizer``
            (padded/truncated to ``INPUT_MAX_LEN`` tokens).

    Returns:
        The generated text with special tokens removed. If several sequences
        are generated, they are concatenated with no separator.
    """
    inputs_encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt",
    )
    # The checkpoint was loaded with map_location=DEVICE, so the model may sit on
    # a GPU while the tokenizer always returns CPU tensors; move inputs to match.
    device = train_model.model.device
    generate_ids = train_model.model.generate(
        input_ids=inputs_encoding["input_ids"].to(device),
        attention_mask=inputs_encoding["attention_mask"].to(device),
        # Cap the generated sequence with the *output* limit (previously INPUT_MAX_LEN).
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generate_ids
    ]
    return "".join(preds)
import gradio as gr
import random
import time
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chatbot.style(height=300)
    with gr.Row():
        with gr.Column(scale=0.98):
            msg = gr.Textbox(
                show_label=False,
                # pl.seed_everything(100) reseeds the global `random` module at
                # startup, so random.choice() would tend to pick the same hint on
                # every launch; an independently seeded generator lets it vary.
                placeholder=random.Random().choice([
                    "Disclaimer: IT WILL CUSS YOU.",
                    "Be careful with Punctuations like ? \" ! , \' .",
                    "Enter text and press enter",
                ]),
            ).style(container=False)
        with gr.Column(scale=0.1, min_width=0):
            sub = gr.Button("Send")
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append the user's turn with a pending (None) reply.

        Blank or whitespace-only messages are ignored: the history is returned
        unchanged, so no empty bubble is shown and no generation is triggered.
        """
        history = history or []
        if not (user_message or "").strip():
            return "", history
        return "", history + [[user_message, None]]

    def bot(history):
        """Stream a reply to the last user turn one character at a time.

        Only acts when the last turn is still pending (its reply is None);
        otherwise the history is yielded back unchanged.
        """
        if not history or history[-1][1] is not None:
            yield history
            return
        bot_message = generate_question(history[-1][0])
        history[-1][1] = ""
        if not bot_message:
            # Still push one update so the pending turn is resolved in the UI.
            yield history
            return
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)  # typewriter effect
            yield history

    # Enter in the textbox and the Send button trigger the same two-step chain.
    for trigger in (msg.submit, sub.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=True).then(
            bot, chatbot, chatbot
        )
    clear.click(lambda: None, None, chatbot, queue=True)

demo.queue(concurrency_count=1)
demo.launch()