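"""Streamlit app for dialogue text summarization.

The sidebar collects a Replicate API token and decoding parameters; a local
BART tokenizer builds the prompt, and the summary itself is streamed from
Replicate's snowflake-arctic-instruct model.
"""
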
import streamlit as st
import replicate
import os
from transformers import AutoTokenizer, GenerationConfig, AutoModelForSeq2SeqLM
import torch

# Set Replicate API token
with st.sidebar:
    st.title('Dialogue Text Summarization')
    if 'REPLICATE_API_TOKEN' in st.secrets:
        replicate_api = st.secrets['REPLICATE_API_TOKEN']
    else:
        replicate_api = st.text_input('Enter Replicate API token:', type='password')
        if not (replicate_api.startswith('r8_') and len(replicate_api) == 40):
            st.warning('Please enter a valid Replicate API token.', icon='⚠️')
            st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")

    if replicate_api:
        os.environ['REPLICATE_API_TOKEN'] = replicate_api
    st.subheader("Adjust model parameters")
    min_new_tokens = st.slider('Min new tokens', min_value=1, max_value=256, step=1, value=10)
    temperature = st.slider('Temperature', min_value=0.01, max_value=1.00, step=0.01, value=1.0)
    top_k = st.slider('Top_k', min_value=1, max_value=50, step=1, value=20)
    top_p = st.slider('Top_p', min_value=0.01, max_value=1.00, step=0.01, value=1.0)

# Initialize model and tokenizer, cached so Streamlit reruns don't reload them
checkpoint = "dtruong46me/train-bart-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@st.cache_resource
def load_resources():
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return tokenizer, AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

tokenizer, model = load_resources()
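# Note: the local checkpoint is only used to build and normalize the prompt;
# the summary itself is generated remotely via Replicate (see stream_summary).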

st.title("Dialogue Text Summarization")
st.caption("Natural Language Processing Project 20232")
st.write("---")

input_text = st.text_area("Dialogue", height=200)

generation_config = GenerationConfig(
    min_new_tokens=min_new_tokens,
    max_new_tokens=320,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k
)
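
# generation_config (and the min_new_tokens / top_k sliders) configures local
# decoding only; the Replicate call below forwards just temperature and top_p.
# A purely local fallback with the BART checkpoint could look like this sketch:
#
#   inputs = tokenizer(input_text, return_tensors="pt").to(device)
#   output_ids = model.generate(**inputs, generation_config=generation_config)
#   summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)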

def build_prompt(input_text, tokenizer):
    """Wrap the dialogue in the summarization instruction template, then
    round-trip it through the tokenizer so special tokens are stripped and
    the prompt matches the tokenizer's normalization."""
    prefix = "Summarize the following conversation: \n\n###"
    suffix = "\n\nSummary:"
    input_ids = tokenizer.encode(prefix + input_text + suffix, return_tensors="pt").to(device)
    prompt_str = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return prompt_str

def stream_summary(prompt_str, temperature, top_p):
    """Stream summary tokens from Replicate's snowflake-arctic-instruct model."""
    for event in replicate.stream(
        "snowflake/snowflake-arctic-instruct",
        input={"prompt": prompt_str,
               "prompt_template": r"{prompt}",
               "temperature": temperature,
               "top_p": top_p}):
        # replicate.stream yields ServerSentEvent objects; str(event) is the text chunk
        yield str(event)
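# On Streamlit >= 1.31, the accumulation loop below could be replaced with a
# single call to st.write_stream(stream_summary(prompt_str, temperature, top_p)).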

if st.button("Submit"):
    st.write("---")
    st.write("## Summary")

    if not replicate_api:
        st.error("Please enter your Replicate API token!")
    elif not input_text:
        st.error("Please enter a dialogue!")
    else:
        prompt_str = build_prompt(input_text, tokenizer)
        summary_container = st.empty()

        summary_text = ""
        for output in stream_summary(prompt_str, temperature, top_p):
            summary_text += output
            summary_container.text(summary_text)
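
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py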