craigchen committed on
Commit 85e839a · 1 Parent(s): f44235a

Update app.py

Files changed (1)
  1. app.py +2 -1
app.py CHANGED

@@ -1,6 +1,7 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import math
+import nltk
 import torch
 
 model_name = "craigchen/BART-139M-ecommerce-customer-service-anwser-to-query-generation"
@@ -82,7 +83,7 @@ def generate_title():
     # compute predictions
     outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    predicted_titles = decoded_outputs[0]
+    predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
 
     st.session_state.titles = predicted_titles
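For context, a minimal sketch of the post-processing that the new list comprehension performs: it keeps only the first sentence of each decoded model output. This assumes NLTK's punkt sentence tokenizer data is available locally, and the sample decoded outputs below are illustrative placeholders, not real model predictions.

import nltk

# nltk.sent_tokenize needs the punkt models; newer NLTK releases load "punkt_tab".
# Assumption: this data is downloaded before the app runs generation.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Illustrative decoded outputs (placeholders for tokenizer.batch_decode results).
decoded_outputs = [
    "How do I track my order? You can check the status in your account.",
    "Can I return this item? Returns are accepted within 30 days.",
]

# Keep only the first sentence of each decoded output, as in the updated app.py.
predicted_titles = [
    nltk.sent_tokenize(decoded_output.strip())[0]
    for decoded_output in decoded_outputs
]

print(predicted_titles)
# ['How do I track my order?', 'Can I return this item?']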