Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import re
|
2 |
import streamlit as st
|
3 |
-
from transformers import pipeline, AutoTokenizer,
|
4 |
|
5 |
# Initialize the chat history
|
6 |
history = []
|
@@ -8,8 +8,8 @@ history = []
|
|
8 |
def clean_text(text: str) -> str:
    """Return *text* with everything except ASCII letters and whitespace removed.

    Digits, punctuation and non-ASCII characters are dropped; leading and
    trailing whitespace is then stripped. Interior whitespace is kept as-is.
    """
    # Raw string: "\s" inside a normal string literal is an invalid escape
    # sequence and triggers a SyntaxWarning on recent Python versions.
    return re.sub(r"[^a-zA-Z\s]", "", text).strip()
|
10 |
|
11 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
12 |
-
model =
|
13 |
|
14 |
def generate_response(user_input):
|
15 |
history.append((user_input, ""))
|
@@ -18,27 +18,28 @@ def generate_response(user_input):
|
|
18 |
return ""
|
19 |
|
20 |
last_user_message = history[-1][0]
|
21 |
-
combined_messages = "
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
input_ids =
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
39 |
|
40 |
-
history[-1] = (last_user_message,
|
41 |
-
return f"AI: {
|
42 |
|
43 |
st.title("Simple Chat App using DistilBert Model (HuggingFace & Streamlit)")
|
44 |
|
|
|
1 |
import re

import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
4 |
|
5 |
# Initialize the chat history
|
6 |
history = []
|
|
|
8 |
def clean_text(text: str) -> str:
    """Return *text* with everything except ASCII letters and whitespace removed.

    Digits, punctuation and non-ASCII characters are dropped; leading and
    trailing whitespace is then stripped. Interior whitespace is kept as-is.
    """
    # Raw string: "\s" inside a normal string literal is an invalid escape
    # sequence and triggers a SyntaxWarning on recent Python versions.
    return re.sub(r"[^a-zA-Z\s]", "", text).strip()
|
10 |
|
11 |
+
# Load the tokenizer and model once at import time so every request reuses them.
# DialoGPT is a decoder-only (GPT-2 style) model, so it has to be loaded through
# the causal-LM auto class; the seq2seq auto class does not accept its config.
# NOTE: .half().cuda() converts the weights to fp16 and moves them to the GPU,
# so this line requires a CUDA-capable device.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").half().cuda()
|
13 |
|
14 |
def generate_response(user_input):
|
15 |
history.append((user_input, ""))
|
|
|
18 |
return ""
|
19 |
|
20 |
last_user_message = history[-1][0]
|
21 |
+
combined_messages = " ".join([msg for msg, _ in reversed(history[:-1])]) + " User: " + last_user_message
|
22 |
+
|
23 |
+
tokens = tokenizer.encode(combined_messages, add_special_tokens=True, max_length=4096, truncation=True)
|
24 |
+
tokens = tokens[:1024]
|
25 |
+
segment_ids = [0]*len(tokens)
|
26 |
+
input_ids = torch.tensor([tokens], dtype=torch.long).cuda()
|
27 |
+
|
28 |
+
with torch.no_grad():
|
29 |
+
outputs = model.generate(
|
30 |
+
input_ids,
|
31 |
+
max_length=1024,
|
32 |
+
min_length=20,
|
33 |
+
length_penalty=2.0,
|
34 |
+
early_stopping=True,
|
35 |
+
num_beams=4,
|
36 |
+
bad_words_callback=[lambda x: True if 'User:' in str(x) else False]
|
37 |
+
)
|
38 |
+
output = output[0].tolist()[len(tokens)-1:]
|
39 |
+
decoded_output = tokenizer.decode(output, skip_special_tokens=True)
|
40 |
|
41 |
+
history[-1] = (last_user_message, decoded_output)
|
42 |
+
return f"AI: {decoded_output}".capitalize()
|
43 |
|
44 |
st.title("Simple Chat App using DistilBert Model (HuggingFace & Streamlit)")
|
45 |
|