dtruong46me's picture
Update app.py
8fcd344 verified
history blame
4.59 kB
import streamlit as st
import pandas as pd
from transformers import GenerationConfig, BartModel, BartTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
import torch
import time
import sys, os
path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)
from gen_summary import generate_summary
# Page header shown at the top of the Streamlit app.
PAGE_TITLE = "Dialogue Text Summarization"
PAGE_CAPTION = "Natural Language Processing Project 20232"

st.title(PAGE_TITLE)
st.caption(PAGE_CAPTION)
class StreamlitTextStreamer(TextStreamer):
    """Stream generated text into Streamlit elements and report timing stats.

    Every chunk of finalized text is appended to a running transcript that is
    re-rendered through ``st_container.markdown``. When the stream ends, a
    one-row table of timing statistics is rendered in ``st_info_container``.

    Args:
        tokenizer: Tokenizer forwarded to ``TextStreamer``.
        st_container: Streamlit element whose ``markdown`` method receives the
            text generated so far.
        st_info_container: Streamlit element that receives the statistics
            table once streaming has finished.
        skip_prompt: Forwarded to ``TextStreamer``.
        **decode_kwargs: Forwarded to ``TextStreamer``.
    """

    def __init__(self, tokenizer, st_container, st_info_container, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.st_container = st_container
        self.st_info_container = st_info_container
        self.text = ""                # everything streamed so far
        self.start_time = None        # time of the first callback
        self.first_token_time = None  # time of the first non-blank chunk
        self.total_tokens = 0         # whitespace-separated words, see below

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append ``text`` to the transcript and refresh the Streamlit output.

        Args:
            text: Newly finalized chunk of decoded text (may be empty).
            stream_end: True on the last call of a generation run; triggers
                the statistics table.
        """
        now = time.time()
        if self.start_time is None:
            self.start_time = now
        if self.first_token_time is None and text.strip():
            self.first_token_time = now

        self.text += text
        # "Tokens" here are whitespace-separated words of the decoded text,
        # not tokenizer tokens.
        self.total_tokens += len(text.split())
        self.st_container.markdown("###### " + self.text)

        if not stream_end:
            return

        total_time = time.time() - self.start_time
        first_token_wait_time = (
            self.first_token_time - self.start_time
            if self.first_token_time is not None
            else None
        )
        tokens_per_second = self.total_tokens / total_time if total_time > 0 else None

        df = pd.DataFrame(data={
            "First token": [first_token_wait_time],
            "Total tokens": [self.total_tokens],
            "Time taken": [total_time],
            "Token per second": [tokens_per_second],
        })
        # NOTE(review): the original text is cut off inside the DataFrame
        # literal. Closing it and rendering the table in st_info_container
        # (which was otherwise never used) is a reconstruction -- confirm the
        # intended widget.
        self.st_info_container.table(df)
def generate_summary(model, input_text, generation_config, tokenizer, st_container, st_info_container) -> str:
    """Generate a summary of ``input_text`` and stream it into Streamlit.

    The dialogue is wrapped in a fixed instruction prompt that asks for a
    summary of roughly 15% of the input's word count (at least one word).
    Output is streamed through a ``StreamlitTextStreamer``.

    NOTE: this definition shadows the ``generate_summary`` imported from
    ``gen_summary`` at the top of the file.

    Args:
        model: Object exposing ``generate(input_ids, streamer=..., ...)``.
        input_text: Dialogue text to summarize.
        generation_config: Passed to ``model.generate`` as ``generation_config``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")``; also
            handed to the streamer.
        st_container: Streamlit element that receives the streamed text.
        st_info_container: Streamlit element handed to the streamer.

    Returns:
        The text accumulated by the streamer during generation.

    Raises:
        Any exception from the tokenizer or the model propagates unchanged.
    """
    prefix = "Summarize the following conversation: \n###\n"
    suffix = "\n### Summary:"
    target_length = max(1, int(0.15 * len(input_text.split())))
    prompt = prefix + input_text + f"The generated summary should be around {target_length} words." + suffix
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    streamer = StreamlitTextStreamer(
        tokenizer,
        st_container,
        st_info_container,
        skip_special_tokens=True,
        decoder_start_token_id=3,
    )
    # The original wrapped this call in a handler that only re-raised the same
    # exception (its `try:` line was missing); errors now propagate directly.
    model.generate(input_ids, streamer=streamer, do_sample=True, generation_config=generation_config)
    return streamer.text
# Sidebar: model selection and generation hyper-parameters.
# NOTE(review): indentation was lost in the original text; every control up to
# "Top_p" is assumed to belong to the sidebar -- confirm.
with st.sidebar:
    checkpoint = st.selectbox("Model", options=[
        "Choose model",
        # TODO(review): the model checkpoint identifiers that followed this
        # placeholder are missing from the original text and must be restored.
    ])
    st.button("Model detail", use_container_width=True)

    st.write("**Generate Options:**")
    min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
    max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
    # Lower bound raised from 0.0: generate_summary always samples
    # (do_sample=True), and transformers rejects a non-positive temperature.
    temperature = st.number_input("Temperature", min_value=0.05, max_value=1.0, value=0.9, step=0.05)
    top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
    top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)
# Main panel: dialogue input box.
height = 200
input_text = st.text_area("Dialogue", height=height)

# NOTE(review): the arguments of this call are missing from the original text.
# They are rebuilt from the sidebar values, which are otherwise unused --
# confirm against the intended configuration.
generation_config = GenerationConfig(
    min_new_tokens=min_new_tokens,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the tokenizer and model for the selected checkpoint; while the
# placeholder entry is selected there is nothing to load.
if checkpoint != "Choose model":
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
else:
    tokenizer = None
    model = None
# Submit: validate the inputs, then stream the summary into two placeholders.
if st.button("Submit"):
    st.write("## Summary")

    # Report every problem, but only generate when all checks pass; otherwise
    # generation would run with model=None or an empty dialogue.
    inputs_ok = True
    if checkpoint == "Choose model":
        st.error("Please select a model!")
        inputs_ok = False
    if not input_text.strip():  # also rejects whitespace-only input
        st.error("Please enter a dialogue!")
        inputs_ok = False

    if inputs_ok:
        st_container = st.empty()
        st_info_container = st.empty()
        # Collapse all runs of whitespace (including newlines) to single
        # spaces before summarizing.
        generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer, st_container, st_info_container)