"""Streamlit front end for a fine-tuned T5 chatbot.

The script loads a T5 checkpoint wrapped in a LightningModule, generates a
reply for each submitted message, keeps the conversation in
``st.session_state`` and can save / download conversations through MongoDB.
"""
import datetime
import logging
import os
import uuid

import pandas as pd
import pytorch_lightning as pl
import streamlit as st
import torch
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from pytorch_lightning.callbacks import ModelCheckpoint
from streamlit_chat import message
from torch.nn.utils.rnn import pad_sequence
from transformers import AdamW
from transformers import T5ForConditionalGeneration, T5Tokenizer

# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler

logger = logging.getLogger(__name__)

pl.seed_everything(100)

MODEL_NAME = 't5-base'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128


def _cache_resource(func):
    """Cache ``func``'s result across Streamlit reruns when possible.

    Streamlit re-executes the whole script on every widget interaction, so
    expensive objects must be cached explicitly. Streamlit versions without
    ``st.cache_resource`` fall back to calling ``func`` on every rerun,
    which is what the script did before.
    """
    decorator = getattr(st, "cache_resource", None)
    return decorator(func) if decorator is not None else func


@_cache_resource
def _load_tokenizer():
    """Return the pretrained tokenizer for ``MODEL_NAME``."""
    return T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)


tokenizer = _load_tokenizer()


class T5Model(pl.LightningModule):
    """LightningModule wrapper around ``T5ForConditionalGeneration``."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return ``(loss, logits)``."""
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        """Return the optimizer used for fine-tuning."""
        return AdamW(self.parameters(), lr=0.0001)


@_cache_resource
def _load_model():
    """Load the fine-tuned checkpoint once and freeze it for inference."""
    model = T5Model.load_from_checkpoint('best-model.ckpt', map_location=DEVICE)
    model.freeze()
    return model


train_model = _load_model()


def generate_response(question):
    """Generate the chatbot's reply to ``question``.

    Args:
        question: The user's message as plain text.

    Returns:
        The decoded beam-search output as a single string.
    """
    inputs_encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt"
    )

    # The tokenizer returns CPU tensors; the model may live on another device
    # depending on how the checkpoint was loaded, so follow its parameters.
    model_device = next(train_model.model.parameters()).device

    generate_ids = train_model.model.generate(
        input_ids=inputs_encoding["input_ids"].to(model_device),
        attention_mask=inputs_encoding["attention_mask"].to(model_device),
        # Bound the *generated* sequence by the output limit, not the input one.
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generate_ids
    ]

    return "".join(preds)


@_cache_resource
def _connect_mongo(mongo_uri):
    """Create a MongoDB client for ``mongo_uri`` (reused across reruns when cached)."""
    return MongoClient(mongo_uri, server_api=ServerApi('1'))


st.title(":red[_Sarcastic_] Chatbot")

password = os.getenv("mongo_pass")
if not password:
    # Without this guard the URI concatenation below raises TypeError on None.
    st.error("MongoDB password is not configured: set the 'mongo_pass' environment variable.")
    st.stop()

# NOTE: the password is inserted verbatim, so it must already be percent-escaped
# if it contains characters that are reserved in a URI.
uri = "mongodb+srv://rohank587:"+password+"@rkcluster.e3fpzja.mongodb.net/?retryWrites=true&w=majority"
# Create a new client and connect to the server
client = _connect_mongo(uri)

if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]

# container for chat history
response_container = st.container()
# container for text box
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_input("You:", key='input', placeholder="Disclaimer: Be careful with punctuations like , ? . ! \"")
        submit_button = st.form_submit_button(label='Send', use_container_width=True)

    col1, col2 = st.columns(2)
    with col1:
        clear_button = st.button("Clear Conversation", key="clear", use_container_width=True)
    with col2:
        save_button = st.button("Save Conversation", key="save", use_container_width=True)

    down_id = st.text_input('Enter ID to download chat', placeholder="Message ID")
    if down_id:
        # Pasted IDs often carry stray whitespace; find_one returns None when
        # nothing matches, which must not be subscripted.
        data = client['rohank']['table1'].find_one({'message_id': down_id.strip()})
        if data is None:
            st.warning("No saved chat found for this ID.")
        else:
            st.download_button('Download chat', "\n".join(data['message']), file_name="sar_chat.txt")

# reset everything
if clear_button:
    st.session_state['generated'] = []
    st.session_state['past'] = []
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]

if save_button and st.session_state['generated'] and st.session_state['past']:
    try:
        # Send a ping to confirm a successful connection
        client.admin.command('ping')
        chats = []
        for user_text, bot_text in zip(st.session_state['past'], st.session_state['generated']):
            chats.append("You: "+user_text)
            chats.append("Bot: "+bot_text)
        message_id = str(uuid.uuid4())
        saved_at = datetime.datetime.now()
        client['rohank']['table1'].insert_one(
            {"time of saving": saved_at.strftime("%c"), "message_id": message_id, "message": chats}
        )
    except Exception:
        # UI boundary: keep the app alive, but leave a trace of what failed.
        logger.exception("Saving the conversation to MongoDB failed")
        st.error("Can't connect to MongoDB. Save Failed.")
    else:
        # Report success only after the insert has actually gone through.
        st.success("Pinged your deployment. You successfully connected to MongoDB! Saved Successfully.")
        st.success("Copy this id "+message_id+" for downloading saved chat anytime anywhere and then paste it down below!")

if submit_button and user_input:
    output = generate_response(user_input)
    st.session_state['past'].append(user_input)
    st.session_state['generated'].append(output)

if st.session_state['generated']:
    with response_container:
        turns = zip(st.session_state["past"], st.session_state["generated"])
        for i, (user_text, bot_text) in enumerate(turns):
            message(user_text, is_user=True, key=str(i) + '_user')
            message(bot_text, key=str(i))