PereLluis13 committed on
Commit
88fdd1b
1 Parent(s): 2f235b9
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from datasets import load_dataset
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ from time import time
5
+ import torch
6
+
7
@st.cache(
    suppress_st_warning=True,
    allow_output_mutation=True,
    hash_funcs={
        # Both classes are given a constant hash (None) for st.cache.
        AutoTokenizer: lambda _: None,
        AutoModelForSeq2SeqLM: lambda _: None,
    },
)
def load_models():
    """Load the REBEL tokenizer, the REBEL model and a dataset slice.

    The result is memoised through ``st.cache``. The model is moved to
    ``cuda:0`` when CUDA is available and is always put in eval mode.
    Elapsed load times are printed to stdout.

    Returns:
        A ``(tokenizer, model, dataset)`` tuple, where ``dataset`` is the
        ``train[:1%]`` split of ``Babelscape/rebel-dataset``.
    """
    started = time()

    rebel_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
    print("+++++ loading Model", time() - started)

    rebel_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
    if torch.cuda.is_available():
        # Only move to the GPU when one is actually present.
        rebel_model.to("cuda:0")
    rebel_model.eval()
    print("+++++ loaded model", time() - started)

    samples = load_dataset('Babelscape/rebel-dataset', split="train[:1%]")
    return rebel_tokenizer, rebel_model, samples
26
+
27
def extract_triplets(text):
    """Parse a REBEL-style linearised string into (subject, relation, object) tuples.

    The expected layout is
    ``<triplet> SUBJECT <subj> OBJECT <obj> RELATION [<subj> OBJECT <obj> RELATION ...]``,
    repeated for each new subject. Tokens are split on whitespace and every
    collected field keeps a leading space (e.g. ``' Punta Cana'``).

    Args:
        text: The linearised triplet string.

    Returns:
        A list of ``(subject, relation, object_)`` string tuples. Text that
        yields no field content (empty string, no markers, or markers with
        nothing between them) returns an empty list.
    """
    triplets = []
    # Initialise all parser state up front so that empty text, text without
    # markers, or stray tokens before the first marker cannot hit an unbound
    # local variable.
    subject = relation = object_ = ''
    current = ''  # field receiving plain tokens: 't', 's', 'o', or '' before any marker

    for token in text.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                # A new subject starts: flush the triplet that was in progress.
                triplets.append((subject, relation, object_))
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                # Another object for the same subject: flush the previous one.
                triplets.append((subject, relation, object_))
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
        # Tokens seen before any marker (current == '') are ignored.

    # Flush the trailing triplet, but do not emit an all-empty tuple when
    # nothing at all was parsed.
    if subject or relation or object_:
        triplets.append((subject, relation, object_))
    return triplets
54
+
55
+
56
# ---------------------------------------------------------------------------
# Streamlit page: pick an input text, run REBEL generation on it and display
# the raw predictions together with the triplets parsed from them.
# ---------------------------------------------------------------------------
tokenizer, model, dataset = load_models()

# Input selection: free text typed by the user, or an example from the dataset.
agree = st.checkbox('Free input', False)
if agree:
    text = st.text_input('Input text', 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.')
    print(text)
else:
    dataset_example = st.slider('dataset id', 0, 1000, 0)
    text = dataset[dataset_example]['context']

# Generation hyper-parameters exposed as sliders.
length_penalty = st.slider('length_penalty', 0, 10, 0)
num_beams = st.slider('num_beams', 1, 20, 3)
if num_beams > 1:
    # Clamp the default so it always lies inside the 1..num_beams range.
    num_return_sequences = st.slider(
        'num_return_sequences', 1, num_beams, min(2, num_beams)
    )
else:
    # With a single beam the only valid value is 1; a slider whose default (2)
    # is above its maximum, and whose min equals its max, is rejected by
    # st.slider, so no slider is shown here.
    num_return_sequences = 1

gen_kwargs = {
    "max_length": 256,
    "length_penalty": length_penalty,
    "num_beams": num_beams,
    "num_return_sequences": num_return_sequences,
}

# Tokenise the input and generate on whichever device the model lives on.
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors='pt')
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    **gen_kwargs,
)

# Special tokens are kept because the <triplet>/<subj>/<obj> markers are
# needed by extract_triplets; <s>, </s> and <pad> are stripped by hand below.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

st.title('Input text')
st.write(text)

if not agree:
    # Show the dataset's own triplet annotation next to the prediction.
    st.title('Silver output')
    st.write(dataset[dataset_example]['triplets'])
    st.write(extract_triplets(dataset[dataset_example]['triplets']))

st.title('Prediction text')
decoded_preds = [
    pred.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
    for pred in decoded_preds
]
st.write(decoded_preds)

for idx, sentence in enumerate(decoded_preds):
    st.title(f'Prediction triplets sentence {idx}')
    st.write(extract_triplets(sentence))