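# Streamlit demo: long-document ("book") summarization with
# UNIST-Eunchan/bart-dnc-booksum (BART fine-tuned on BookSum).
# The input is split into semantically coherent chunks and each
# chunk is summarized independently.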
import json

import nltk
import numpy as np
import streamlit as st
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

nltk.download('punkt')  # tokenizer data for nltk.sent_tokenize

with open('testbook.json') as f:
    test_book = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("UNIST-Eunchan/bart-dnc-booksum")

@st.cache_resource
def load_model(model_name):
    # Cached so Streamlit reruns don't reload the models from disk/hub.
    sentence_transformer_model = SentenceTransformer("sentence-transformers/all-roberta-large-v1")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return sentence_transformer_model, model

sentence_transformer_model, model = load_model("UNIST-Eunchan/bart-dnc-booksum")

def infer(input_ids, attention_mask, max_length, temperature, top_k, top_p):
    # Beam-sample decoding driven by the sidebar-controlled parameters.
    output_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        num_beams=4,
        no_repeat_ngram_size=2,
    )

    return output_sequences


def cos_similarity(v1, v2):
    # Cosine similarity between two 1-D embedding vectors.
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
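
# Sanity check: identical directions score 1.0, orthogonal directions 0.0, e.g.
#   cos_similarity(np.array([1.0, 0.0]), np.array([2.0, 0.0]))  -> 1.0
#   cos_similarity(np.array([1.0, 0.0]), np.array([0.0, 3.0]))  -> 0.0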


@st.cache_data
def chunking(book_text):
    # Split the book into segments of roughly 512-768 tokens. Below 512
    # tokens we always keep accumulating; above 768 we always break; in
    # between, we break only if the current sentence is semantically closer
    # to the upcoming text than to the segment built so far.
    sentences = sent_tokenize(book_text)
    token_lens = [len(tokenizer.tokenize(sent)) for sent in sentences]

    segments = []
    current_segment = ""
    total_token_lens = 0
    for i in range(len(sentences)):

        if total_token_lens < 512:
            total_token_lens += token_lens[i]
            current_segment += (sentences[i] + " ")

        elif total_token_lens > 768:
            segments.append(current_segment)
            current_segment = sentences[i]
            total_token_lens = token_lens[i]

        else:
            # Build a lookahead pseudo-segment from up to the next 10
            # sentences, capped at 512 tokens.
            next_pseudo_segment = ""
            next_token_len = 0
            for t in range(10):
                if (i + t < len(sentences)) and (next_token_len + token_lens[i + t] < 512):
                    next_token_len += token_lens[i + t]
                    next_pseudo_segment += (sentences[i + t] + " ")

            # Compare the sentence to the current segment and the lookahead:
            # break here if it belongs with the upcoming text.
            embs = sentence_transformer_model.encode([current_segment, next_pseudo_segment, sentences[i]])  # current, next, sent
            if cos_similarity(embs[1], embs[2]) > cos_similarity(embs[0], embs[2]):
                segments.append(current_segment)
                current_segment = sentences[i]
                total_token_lens = token_lens[i]
            else:
                total_token_lens += token_lens[i]
                current_segment += (sentences[i] + " ")

    # Keep the trailing partial segment, which would otherwise be dropped.
    if current_segment.strip():
        segments.append(current_segment)

    return segments
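
# Example: text shorter than 512 tokens comes back as a single segment, e.g.
#   chunking("First sentence. Second sentence.")
#   -> ["First sentence. Second sentence. "]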


book_index = 0
_book = test_book[book_index]['book']

#prompts
st.title("Book Summarization 📚")
st.write("The almighty king of text generation, GPT-2 comes in four available sizes, only three of which have been publicly made available. Feared for its fake news generation capabilities, it currently stands as the most syntactically coherent model. A direct successor to the original GPT, it reinforces the already established pre-training/fine-tuning killer duo. From the paper: Language Models are Unsupervised Multitask Learners by Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever.")

#book_index = st.sidebar.slider("Select Book Example", value = 0,min_value = 0, max_value=4)
sent = st.text_area("Text", _book[:512], height=550)
max_length = st.sidebar.slider("Max Length", value=512, min_value=10, max_value=1024)
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)  # must stay > 0 when sampling
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)  # 0 disables top-k filtering
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.92)


chunked_segments = chunking(_book)


def generate_output(test_samples):
    inputs = tokenizer(
        test_samples,
        padding="max_length",  # fixed: was `padding=max_length`, passing an int
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )

    # Decode with the sidebar-controlled sampling parameters.
    outputs = infer(
        inputs.input_ids,
        inputs.attention_mask,
        max_length,
        temperature,
        top_k,
        top_p,
    )
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


# Summarize each chunk and display the decoded summary.
for segment in chunked_segments:

    _, output_str = generate_output(segment)
    st.write(output_str[0])