karthickg12 commited on
Commit
341bd5d
1 Parent(s): bef6044

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -18
app.py CHANGED
@@ -87,25 +87,65 @@
87
  # text = tokenizer.batch_decode(outputs)[0]
88
  # print(text)
89
 
90
- import torch
91
- import streamlit as st
92
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
93
 
94
- model_name="facebook/blenderbot-400M-distill"
95
 
96
- model=AutoModelForSeq2SeqLM.from_pretrained(model_name)
97
- tokenizer = AutoTokenizer.from_pretrained(model_name)
98
- ch=[]
99
- def chat():
100
 
101
- h_s="\n".join(ch)
102
- i=st.text_input("enter")
103
- i_s=tokenizer.encode_plus(h_s,i,return_tensors="pt")
104
- outputs=model.generate(**i_s,max_length=60)
105
- response=tokenizer.decode(outputs[0],skip_special_tokens=True).strip()
106
- ch.append(i)
107
- ch.append(response)
108
- return response
109
- if __name__ == "__main__":
110
- chat()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
 
 
 
87
  # text = tokenizer.batch_decode(outputs)[0]
88
  # print(text)
89
 
90
+ # import torch
91
+ # import streamlit as st
92
+ # from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
93
 
94
+ # model_name="facebook/blenderbot-400M-distill"
95
 
96
+ # model=AutoModelForSeq2SeqLM.from_pretrained(model_name)
97
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
98
+ # ch=[]
99
+ # def chat():
100
 
101
+ # h_s="\n".join(ch)
102
+ # i=st.text_input("enter")
103
+ # i_s=tokenizer.encode_plus(h_s,i,return_tensors="pt")
104
+ # outputs=model.generate(**i_s,max_length=60)
105
+ # response=tokenizer.decode(outputs[0],skip_special_tokens=True).strip()
106
+ # ch.append(i)
107
+ # ch.append(response)
108
+ # return response
109
+ # if __name__ == "__main__":
110
+ # chat()
111
+
112
+ import streamlit as st
113
+ from streamlit_chat import message as st_message
114
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
115
+
116
+
117
# Cache the heavyweight model/tokenizer load across Streamlit reruns.
# `st.experimental_singleton` was deprecated (and later removed) in favour of
# `st.cache_resource`; resolve whichever exists so the app runs on both old
# and new Streamlit releases without changing behaviour.
_resource_cache = getattr(st, "cache_resource", None) or st.experimental_singleton

# Single source of truth for the checkpoint name (was repeated twice).
_MODEL_ID = "facebook/blenderbot_small-90M"


@_resource_cache
def get_models():
    """Load and return the BlenderBot-small tokenizer and model, once per process.

    Returns:
        tuple: ``(tokenizer, model)`` — an ``AutoTokenizer`` and an
        ``AutoModelForSeq2SeqLM`` loaded from the Hugging Face Hub.
    """
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)
    return tokenizer, model
126
+
127
+
128
# Create the per-session chat history the first time this session runs;
# subsequent reruns keep the accumulated messages.
if "history" not in st.session_state:
    st.session_state["history"] = []

st.title("Blenderbot")
132
+
133
+
134
def generate_answer():
    """Callback for the chat text box.

    Runs the cached BlenderBot model on the text currently in
    ``st.session_state.input_text`` and appends two entries to the session
    chat history: the user's message and the model's decoded reply.
    """
    tokenizer, model = get_models()
    typed = st.session_state.input_text
    encoded = tokenizer(typed, return_tensors="pt")
    generated = model.generate(**encoded)
    # decode the result to a string
    reply = tokenizer.decode(generated[0], skip_special_tokens=True)

    history = st.session_state.history
    history.append({"message": typed, "is_user": True})
    history.append({"message": reply, "is_user": False})
145
+
146
+
147
# The text box fires generate_answer() when the user submits; the updated
# history is then re-rendered below as alternating user/bot chat bubbles
# (each dict matches st_message's keyword arguments).
st.text_input("Tap to chat with the bot",
              key="input_text", on_change=generate_answer)

for entry in st.session_state.history:
    st_message(**entry)