kennyhelsens commited on
Commit
55b5a53
1 Parent(s): 9bcbd92

Initial commit of feedback demo

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. app.py +36 -0
  3. makefile +7 -0
  4. requirements.txt +3 -0
  5. src/__init__.py +0 -0
  6. src/utils.py +98 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/*
2
+ *.pyc
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+
4
+
5
+ from src.utils import summarize_en
6
+
7
+
8
+ st.write("## 👥 Getting User Feedback with Tally forms")
9
+
10
+
11
+ st.write("It's one thing to build performant models, it's another thing to prototype and test their performance in front of real users. By mixing Streamlit & Tally, this has never been easier.")
12
+
13
+ st.write("The Streamlit app below uses google/pegasus-xsum to summarize an article. Run a summary, and share your feedback. 💬")
14
+
15
+ placeholder_text = """Contemporary climate change includes both global warming and its impacts on Earth's weather patterns. There have been previous periods of climate change, but the current changes are distinctly more rapid and not due to natural causes. Instead, they are caused by the emission of greenhouse gases, mostly carbon dioxide (CO2) and methane. Burning fossil fuels for energy use creates most of these emissions. Certain agricultural practices, industrial processes, and forest loss are additional sources. Greenhouse gases are transparent to sunlight, allowing it through to heat the Earth's surface. When the Earth emits that heat as infrared radiation the gases absorb it, trapping the heat near the Earth's surface. As the planet heats up it causes changes like the loss of sunlight-reflecting snow cover, amplifying global warming. From: https://en.wikipedia.org/wiki/Climate_change"""
16
+
17
+ feedback_url = """https://tally.so/embed/w2EX7A?alignLeft=1&hideTitle=1&transparentBackground=1"""
18
+
19
+ with st.form("my_form"):
20
+ st.write("#### Your long text goes here 🇬🇧↘️")
21
+ txt_input = st.text_area("", placeholder_text,height=300)
22
+
23
+ # Every form must have a submit button.
24
+ submitted = st.form_submit_button("🤖 Summarize ")
25
+ if submitted:
26
+ with st.spinner(text="Summarizing text, this could take 10 seconds"):
27
+
28
+
29
+ summary = summarize_en(txt_input)
30
+ st.markdown("### Summary by Pegasus")
31
+ st.markdown(">{}".format(summary))
32
+ # embed streamlit docs in a streamlit app
33
+ components.iframe(feedback_url, height=800, scrolling=True)
34
+
35
+ # st.write("Outside the form")
36
+
makefile ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ setup:
2
+ ( \
3
+ python -m venv .venv; \
4
+ . .venv/bin/activate; \
5
+ pip install --upgrade pip; \
6
+ pip install -r requirements.txt; \
7
+ )
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers[torch]
2
+ streamlit
3
+ sentencepiece
src/__init__.py ADDED
File without changes
src/utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared utility methods for this module.
3
+ """
4
+
5
import datetime  # NOTE(review): appears unused in this file — confirm before removing
import re  # NOTE(review): appears unused in this file — confirm before removing
from ctypes import Array  # NOTE(review): appears unused; possibly a mistaken auto-import
from typing import Optional

from transformers import AutoModelForSequenceClassification, AutoTokenizer, MBartForConditionalGeneration, MBartTokenizer, pipeline
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
12
+
13
+
14
+ def lowercase_string(string: str) -> str:
15
+ """Returns a lowercased string
16
+ Args:
17
+ string: String to lowercase
18
+ Returns:
19
+ String in lowercase
20
+ """
21
+ if isinstance(string, str):
22
+ return string.lower()
23
+ return None
24
+
25
+
26
+ from functools import lru_cache
27
+
28
+ @lru_cache
29
+ def get_sentiment_pipeline():
30
+ model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
31
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
32
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
33
+ sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
34
+ return sentiment_pipeline
35
+
36
+ def score_sentiment(input: str):
37
+ """Score sentiment of an input string with a pretrained Transformers Pipeline
38
+
39
+ Args:
40
+ input (str): Text to be scored
41
+
42
+ Returns:
43
+ tuple: (label, score)
44
+ """
45
+ sentiment_pipeline = get_sentiment_pipeline()
46
+ result = sentiment_pipeline(input.lower())[0]
47
+ # print("label:{0} input:{1}".format(result['label'], input))
48
+ return result['label'], result['score']
49
+
50
+
51
+ @lru_cache
52
+ def get_summarization_pipeline_nl():
53
+
54
+ undisputed_best_model = MBartForConditionalGeneration.from_pretrained(
55
+ "ml6team/mbart-large-cc25-cnn-dailymail-nl"
56
+ )
57
+ tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
58
+ summarization_pipeline = pipeline(
59
+ task="summarization",
60
+ model=undisputed_best_model,
61
+ tokenizer=tokenizer,
62
+ )
63
+ summarization_pipeline.model.config.decoder_start_token_id = tokenizer.lang_code_to_id[
64
+ "nl_XX"
65
+ ]
66
+ return summarization_pipeline
67
+
68
+ def summarize_nl(input: str) -> str:
69
+ summarization_pipeline = get_summarization_pipeline_nl()
70
+ summary = summarization_pipeline(
71
+ input,
72
+ do_sample=True,
73
+ top_p=0.75,
74
+ top_k=50,
75
+ # num_beams=4,
76
+ min_length=50,
77
+ early_stopping=True,
78
+ truncation=True,
79
+ )[0]["summary_text"]
80
+ return summary
81
+
82
+
83
+
84
+ @lru_cache
85
+ def get_pegasus():
86
+ model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
87
+ tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
88
+ return model, tokenizer
89
+
90
+ def summarize_en(input: str) -> str:
91
+
92
+ model, tokenizer = get_pegasus()
93
+ inputs = tokenizer(input, max_length=1024, return_tensors="pt")
94
+
95
+ # Generate Summary
96
+ summary_ids = model.generate(inputs["input_ids"])
97
+ result = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
98
+ return result