Nopphakorn fabiochiu commited on
Commit
27a2b18
·
0 Parent(s):

Duplicate from fabiochiu/title-generation

Browse files

Co-authored-by: Fabio Chiusano <fabiochiu@users.noreply.huggingface.co>

Files changed (4) hide show
  1. .gitattributes +27 -0
  2. README.md +14 -0
  3. app.py +101 -0
  4. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Title Generation
3
+ emoji: 📈
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: streamlit
7
+ sdk_version: 1.2.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: fabiochiu/title-generation
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import nltk
4
+ import math
5
+ import torch
6
+
7
+ model_name = "fabiochiu/t5-base-medium-title-generation"
8
+ max_input_length = 512
9
+
10
+ st.header("Generate candidate titles for articles")
11
+
12
+ st_model_load = st.text('Loading title generator model...')
13
+
14
+ @st.cache(allow_output_mutation=True)
15
+ def load_model():
16
+ print("Loading model...")
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
+ nltk.download('punkt')
20
+ print("Model loaded!")
21
+ return tokenizer, model
22
+
23
+ tokenizer, model = load_model()
24
+ st.success('Model loaded!')
25
+ st_model_load.text("")
26
+
27
+ with st.sidebar:
28
+ st.header("Model parameters")
29
+ if 'num_titles' not in st.session_state:
30
+ st.session_state.num_titles = 5
31
+ def on_change_num_titles():
32
+ st.session_state.num_titles = num_titles
33
+ num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
34
+ if 'temperature' not in st.session_state:
35
+ st.session_state.temperature = 0.7
36
+ def on_change_temperatures():
37
+ st.session_state.temperature = temperature
38
+ temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
39
+ st.markdown("_High temperature means that results are more random_")
40
+
41
+ if 'text' not in st.session_state:
42
+ st.session_state.text = ""
43
+ st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)
44
+
45
+ def generate_title():
46
+ st.session_state.text = st_text_area
47
+
48
+ # tokenize text
49
+ inputs = ["summarize: " + st_text_area]
50
+ inputs = tokenizer(inputs, return_tensors="pt")
51
+
52
+ # compute span boundaries
53
+ num_tokens = len(inputs["input_ids"][0])
54
+ print(f"Input has {num_tokens} tokens")
55
+ max_input_length = 500
56
+ num_spans = math.ceil(num_tokens / max_input_length)
57
+ print(f"Input has {num_spans} spans")
58
+ overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
59
+ spans_boundaries = []
60
+ start = 0
61
+ for i in range(num_spans):
62
+ spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
63
+ start -= overlap
64
+ print(f"Span boundaries are {spans_boundaries}")
65
+ spans_boundaries_selected = []
66
+ j = 0
67
+ for _ in range(num_titles):
68
+ spans_boundaries_selected.append(spans_boundaries[j])
69
+ j += 1
70
+ if j == len(spans_boundaries):
71
+ j = 0
72
+ print(f"Selected span boundaries are {spans_boundaries_selected}")
73
+
74
+ # transform input with spans
75
+ tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
76
+ tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
77
+
78
+ inputs = {
79
+ "input_ids": torch.stack(tensor_ids),
80
+ "attention_mask": torch.stack(tensor_masks)
81
+ }
82
+
83
+ # compute predictions
84
+ outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
85
+ decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
86
+ predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
87
+
88
+ st.session_state.titles = predicted_titles
89
+
90
+ # generate title button
91
+ st_generate_button = st.button('Generate title', on_click=generate_title)
92
+
93
+ # title generation labels
94
+ if 'titles' not in st.session_state:
95
+ st.session_state.titles = []
96
+
97
+ if len(st.session_state.titles) > 0:
98
+ with st.container():
99
+ st.subheader("Generated titles")
100
+ for title in st.session_state.titles:
101
+ st.markdown("__" + title + "__")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ nltk
2
+ torch
3
+ transformers