Rohan Kumar Singh
no chat no save
a5078fe
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AdamW
import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import pad_sequence
# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
# Seed every RNG Lightning knows about so runs are reproducible.
pl.seed_everything(100)

# Hugging Face identifier shared by the tokenizer and the model wrapper.
MODEL_NAME = 't5-base'

# Token budgets for the encoder input and for the generated reply.
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)
class T5Model(pl.LightningModule):
    """Lightning wrapper around a pretrained ``T5ForConditionalGeneration``.

    Args:
        lr: Learning rate for the AdamW optimizer returned by
            ``configure_optimizers`` (defaults to the previously hard-coded
            1e-4, so checkpoints saved without hyper-parameters still load).
    """

    def __init__(self, lr=1e-4):
        super().__init__()
        # Record ``lr`` in the checkpoint so resumed training uses the same value.
        self.save_hyperparameters()
        self.lr = lr
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the seq2seq model and return ``(loss, logits)``.

        ``loss`` is only meaningful when ``labels`` are supplied.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits

    def _batch_loss(self, batch):
        """Compute the LM loss for a batch with keys input_ids/attention_mask/target."""
        loss, _ = self(batch["input_ids"], batch["attention_mask"], batch["target"])
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._batch_loss(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self._batch_loss(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        # transformers.AdamW is deprecated in favour of torch.optim.AdamW.
        # eps/weight_decay are set to transformers' AdamW defaults because
        # torch's defaults differ (eps=1e-8, weight_decay=0.01).
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)
import streamlit as st


@st.cache_resource
def _load_trained_model(checkpoint_path='best-model.ckpt'):
    """Load the fine-tuned T5Model checkpoint once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the checkpoint would be re-read from disk on each rerun.
    The returned model is shared across sessions and used only for inference.
    """
    model = T5Model.load_from_checkpoint(checkpoint_path, map_location=DEVICE)
    # Lightning's freeze(): switch to eval mode and disable gradients.
    model.freeze()
    return model


train_model = _load_trained_model()
def generate_response(question):
    """Generate the chatbot's reply to ``question`` using beam search.

    Args:
        question: The user's message. It is passed straight to the tokenizer;
            if several sequences are returned, their decoded texts are joined.

    Returns:
        The decoded reply with special tokens stripped.
    """
    inputs_encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt",
    )
    # The tokenizer builds CPU tensors, but the model may live on the GPU
    # (it was loaded with map_location=DEVICE); inputs must be on the same
    # device as the weights or generate() raises a device-mismatch error.
    model_device = train_model.model.device
    generate_ids = train_model.model.generate(
        input_ids=inputs_encoding["input_ids"].to(model_device),
        attention_mask=inputs_encoding["attention_mask"].to(model_device),
        # Cap the length of the *generated* sequence by the output budget.
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generate_ids
    ]
    return "".join(preds)
import uuid
import datetime
import os
import streamlit as st
from streamlit_chat import message
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from urllib.parse import quote_plus

# MongoDB Atlas password, supplied through the environment.
password = os.getenv("mongo_pass")
if not password:
    # Without this check the string concatenation below fails with an
    # unhelpful TypeError when the variable is unset.
    st.error("Environment variable 'mongo_pass' is not set; cannot connect to MongoDB.")
    st.stop()

# Credentials inside a MongoDB URI must be percent-escaped; a password
# containing '@', ':', '/' or '%' would otherwise corrupt the URI.
uri = (
    "mongodb+srv://rohank587:" + quote_plus(password)
    + "@rkcluster.e3fpzja.mongodb.net/?retryWrites=true&w=majority"
)


@st.cache_resource
def _mongo_client(connection_uri):
    """Create one shared MongoClient per server process.

    Streamlit re-executes the script on every interaction; caching avoids
    building a new client (with its own connection pool) on each rerun.
    """
    return MongoClient(connection_uri, server_api=ServerApi('1'))


client = _mongo_client(uri)
st.title(":red[_Sarcastic_] Chatbot")

# Streamlit re-executes this script on every interaction, so the conversation
# lives in session_state: 'past' holds the user's turns and 'generated' the
# bot's replies, index-aligned (past[i] was answered by generated[i]).
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]


def _transcript():
    """Return the conversation as alternating "You: ..." / "Bot: ..." lines."""
    lines = []
    for said, replied in zip(st.session_state['past'], st.session_state['generated']):
        lines.append("You: " + said)
        lines.append("Bot: " + replied)
    return lines


def _save_conversation():
    """Store the transcript in MongoDB and show the id it can be fetched by."""
    try:
        # Ping first so an unreachable cluster is reported before inserting.
        client.admin.command('ping')
        message_id = str(uuid.uuid4())
        saved_at = datetime.datetime.now()
        client['rohank']['table1'].insert_one({
            "time of saving": saved_at.strftime("%c"),
            "message_id": message_id,
            "message": _transcript(),
        })
    except Exception:  # UI boundary: any driver/network error means "not saved".
        st.error("Can't connect to MongoDB. Save Failed.")
        return
    # Only report success once the insert has actually gone through.
    st.success("Pinged your deployment. You successfully connected to MongoDB! Saved Successfully.")
    st.success("Copy this id " + message_id + " for downloading saved chat anytime anywhere and then paste it down below!")


def _offer_download(message_id):
    """Render a download button for the chat saved under ``message_id``.

    Returns the button's value, or None when the lookup fails or no chat
    with that id exists (a message is shown instead of crashing the page).
    """
    try:
        record = client['rohank']['table1'].find_one({'message_id': message_id})
    except Exception:  # UI boundary: report instead of crashing the page.
        st.error("Can't connect to MongoDB. Download Failed.")
        return None
    if record is None:
        st.warning("No saved chat found for ID " + message_id)
        return None
    return st.download_button('Download chat', "\n".join(record['message']), file_name="sar_chat.txt")


# Chat history is drawn into this container at the end of the script, so it
# appears above the input form even though it is filled last.
response_container = st.container()
# container for text box and controls
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_input("You:", key='input', placeholder="Disclaimer: Be careful with punctuations like , ? . ! \"")
        submit_button = st.form_submit_button(label='Send', use_container_width=True)
    col1, col2 = st.columns(2)
    with col1:
        clear_button = st.button("Clear Conversation", key="clear", use_container_width=True)
    with col2:
        save_button = st.button("Save Conversation", key="save", use_container_width=True)
    down_id = st.text_input('Enter ID to download chat', placeholder="Message ID")
    # Pasted ids often carry stray whitespace; strip before looking them up.
    if down_id.strip():
        down_button = _offer_download(down_id.strip())

# reset everything
if clear_button:
    st.session_state['generated'] = []
    st.session_state['past'] = []
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]

if save_button:
    if st.session_state['generated'] and st.session_state['past']:
        _save_conversation()
    else:
        st.info("Nothing to save yet. Send a message first.")

# Ignore whitespace-only submissions instead of generating a reply to nothing.
if submit_button and user_input.strip():
    output = generate_response(user_input)
    st.session_state['past'].append(user_input)
    st.session_state['generated'].append(output)

if st.session_state['generated']:
    with response_container:
        turns = zip(st.session_state['past'], st.session_state['generated'])
        for i, (said, replied) in enumerate(turns):
            message(said, is_user=True, key=str(i) + '_user')
            message(replied, key=str(i))