Rohan Kumar Singh committed
Commit 7191a40
Parent: db3a13a

initial commit

Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +138 -0
  3. best-model.ckpt +3 -0
  4. requirements.txt +7 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+best-model.ckpt filter=lfs diff=lfs merge=lfs -text
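This is the attribute line that `git lfs track "best-model.ckpt"` writes: it keeps the ~2.7 GB checkpoint out of regular Git storage, committing only the small LFS pointer file shown further down.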
app.py ADDED
@@ -0,0 +1,138 @@
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import pytorch_lightning as pl
# transformers.AdamW is deprecated; torch's AdamW is the drop-in replacement
from torch.optim import AdamW

pl.seed_everything(100)

MODEL_NAME = 't5-base'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)


class T5Model(pl.LightningModule):
    """Lightning wrapper around T5; the training/validation hooks are kept
    so the fine-tuned checkpoint can be restored with load_from_checkpoint."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["target"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=0.0001)


# Restore the fine-tuned weights and freeze them: this app only does inference.
train_model = T5Model.load_from_checkpoint('best-model.ckpt', map_location=DEVICE)
train_model.freeze()


def generate_response(question):
    """Tokenize the user's question and decode a reply with beam search."""
    inputs_encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt"
    )

    generate_ids = train_model.model.generate(
        input_ids=inputs_encoding["input_ids"],
        attention_mask=inputs_encoding["attention_mask"],
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(gen_id,
                         skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for gen_id in generate_ids
    ]

    return "".join(preds)


import streamlit as st
from streamlit_chat import message

# Per-session chat history.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]

# container for chat history
response_container = st.container()
# container for text box
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_input("You:", key='input')
        submit_button = st.form_submit_button(label='Send')
    # st.button is not allowed inside a form, so the clear button sits after it
    clear_button = st.button("Clear Conversation", key="clear")
    # reset everything
    if clear_button:
        st.session_state['generated'] = []
        st.session_state['past'] = []
        st.session_state['messages'] = [
            {"role": "system", "content": "You are a helpful assistant."}
        ]

    if submit_button and user_input:
        output = generate_response(user_input)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)

if st.session_state['generated']:
    with response_container:
        for i in range(len(st.session_state['generated'])):
            message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
            message(st.session_state["generated"][i], key=str(i))
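Generation uses 4-beam search with `no_repeat_ngram_size=2` (no bigram may appear twice in the output) and `early_stopping=True`, returning the single best sequence of at most 128 tokens. Note that `load_from_checkpoint` runs at module import time, so the LFS checkpoint must already be present locally before the app starts.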
best-model.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9efe3a4fe521ae871e5c1329b9c0a954e11b1b4c9cde89631b4addd3a6418942
size 2675123319
requirements.txt ADDED
@@ -0,0 +1,7 @@
transformers==4.27.4
pandas==1.5.3
torch==2.0.0
pytorch-lightning==2.0.2
sentencepiece==0.1.98
streamlit==1.20.0
streamlit-chat==0.0.2.2
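To run the app locally, one workflow (assuming Git LFS is installed): clone the repo, fetch the checkpoint with `git lfs pull`, install the pinned dependencies with `pip install -r requirements.txt`, then launch with `streamlit run app.py`.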