Hackavist commited on
Commit
9d79da1
1 Parent(s): ba2ad99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import re
2
  import streamlit as st
3
- from transformers import pipeline, AutoTokenizer, TFAutoModelForMaskedLM
4
 
5
  # Initialize the chat history
6
  history = []
@@ -8,8 +8,8 @@ history = []
8
  def clean_text(text):
9
  return re.sub('[^a-zA-Z\s]', '', text).strip()
10
 
11
- tokenizer = AutoTokenizer.from_pretrained("t5-small")
12
- model = TFAutoModelForMaskedLM.from_pretrained("t5-small").half().cuda()
13
 
14
  def generate_response(user_input):
15
  history.append((user_input, ""))
@@ -18,27 +18,28 @@ def generate_response(user_input):
18
  return ""
19
 
20
  last_user_message = history[-1][0]
21
- combined_messages = " Human: " + " . ".join([msg for msg, _ in reversed(history[:-1])]) + " . Human: " + last_user_message
22
- input_str = "summarize: " + combined_messages
23
- source_encodings = tokenizer.batch_encode_plus([input_str], pad_to_max_length=False, padding='max_length', return_attention_mask=True, return_tensors="tf")
24
- input_ids = source_encodings["input_ids"][0]
25
- attention_mask = source_encodings["attention_mask"][0]
26
- input_ids = tf.constant(input_ids)[None, :]
27
- attention_mask = tf.constant(attention_mask)[None, :]
28
-
29
- with tf.device('/GPU:0'):
30
- output = model.generate(
31
- input_ids,
32
- attention_mask=attention_mask,
33
- max_length=256,
34
- num_beams=4,
35
- early_stopping=True
36
- )
37
-
38
- predicted_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
 
39
 
40
- history[-1] = (last_user_message, predicted_sentence)
41
- return f"AI: {predicted_sentence}".capitalize()
42
 
43
  st.title("Simple Chat App using DistilBert Model (HuggingFace & Streamlit)")
44
 
 
1
  import re
2
  import streamlit as st
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  # Initialize the chat history
6
  history = []
 
8
  def clean_text(text):
9
  return re.sub('[^a-zA-Z\s]', '', text).strip()
10
 
11
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
12
+ model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/DialoGPT-small").half().cuda()
13
 
14
  def generate_response(user_input):
15
  history.append((user_input, ""))
 
18
  return ""
19
 
20
  last_user_message = history[-1][0]
21
+ combined_messages = " ".join([msg for msg, _ in reversed(history[:-1])]) + " User: " + last_user_message
22
+
23
+ tokens = tokenizer.encode(combined_messages, add_special_tokens=True, max_length=4096, truncation=True)
24
+ tokens = tokens[:1024]
25
+ segment_ids = [0]*len(tokens)
26
+ input_ids = torch.tensor([tokens], dtype=torch.long).cuda()
27
+
28
+ with torch.no_grad():
29
+ outputs = model.generate(
30
+ input_ids,
31
+ max_length=1024,
32
+ min_length=20,
33
+ length_penalty=2.0,
34
+ early_stopping=True,
35
+ num_beams=4,
36
+ bad_words_callback=[lambda x: True if 'User:' in str(x) else False]
37
+ )
38
+ output = output[0].tolist()[len(tokens)-1:]
39
+ decoded_output = tokenizer.decode(output, skip_special_tokens=True)
40
 
41
+ history[-1] = (last_user_message, decoded_output)
42
+ return f"AI: {decoded_output}".capitalize()
43
 
44
  st.title("Simple Chat App using DistilBert Model (HuggingFace & Streamlit)")
45