efederici commited on
Commit
47cbc92
1 Parent(s): 290c1b4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +74 -0
README.md CHANGED
@@ -54,6 +54,80 @@ tags = tag(article)
54
  print(tags)
55
  ```
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  ### Overview
58
 
59
  - Model: T5 ([it5-small](https://huggingface.co/gsarti/it5-small))
 
54
  print(tags)
55
  ```
56
 
57
+ ## Longer documents
58
+
59
+ Assuming paragraphs are divided by: '\n\n'.
60
+
61
+ ```python
62
+ from transformers import T5ForConditionalGeneration,T5Tokenizer
63
+ import itertools
64
+ import re
65
+
66
+ model = T5ForConditionalGeneration.from_pretrained("efederici/text2tags")
67
+ tokenizer = T5Tokenizer.from_pretrained("efederici/text2tags")
68
+
69
+ article = '''
70
+ Da bambino era preoccupato che al mondo non ci fosse più nulla da scoprire. Ma i suoi stessi studi gli avrebbero dato torto: insieme a James Watson, nel 1953 Francis Crick strutturò il primo modello di DNA, la lunga sequenza di codici che identifica ogni essere vivente, rendendolo unico e diverso da tutti gli altri.
71
+ La scoperta gli valse il Nobel per la Medicina. È uscita in queste settimane per Codice la sua biografia, Francis Crick — Lo scopritore del DNA, scritta da Matt Ridley, che racconta vita e scienza dell'uomo che capì perché siamo fatti così.
72
+ '''
73
+
74
+ def words(text):
75
+ input_str = text
76
+ output_str = re.sub('[^A-Za-z0-9]+', ' ', input_str)
77
+ return output_str.split()
78
+
79
+ def is_subset(text1, text2):
80
+ return all(tag in words(text1.lower()) for tag in text2.split())
81
+
82
+ def cleaning(text, tags):
83
+ return [tag for tag in tags if is_subset(text, tag)]
84
+
85
+ def get_texts(self, text, max_len):
86
+ texts = list(filter(lambda x : x != '', text.split('\n\n')))
87
+ lengths = [len(tokenizer.encode(paragraph)) for paragraph in texts]
88
+ output = []
89
+ for i, par in enumerate(texts):
90
+ index = len(output)
91
+ if index > 0 and lengths[i] + len(tokenizer.encode(output[index-1])) <= max_len:
92
+ output[index-1] = "".join(output[index-1] + par)
93
+ else:
94
+ output.append(par)
95
+ return output
96
+
97
+ def get_tags(self, text, generate_kwargs):
98
+ input_text = 'summarize: ' + text.strip().replace('\n', ' ')
99
+ tokenized_text = tokenizer.encode(input_text, return_tensors="pt")
100
+ with torch.no_grad():
101
+ tags_ids = model.generate(tokenized_text, **generate_kwargs)
102
+
103
+ output = []
104
+ for tags in tags_ids:
105
+ cleaned = cleaning(
106
+ text,
107
+ list(set(tokenizer.decode(tags, skip_special_tokens=True).split(', ')))
108
+ )
109
+ output.append(cleaned)
110
+
111
+ return list(set(itertools.chain(*output)))
112
+
113
+ def tag(self, text, max_len, generate_kwargs):
114
+ texts = self.get_texts(text, max_len)
115
+ all_tags = [self.get_tags(text, generate_kwargs) for text in texts]
116
+ flatten_tags = itertools.chain(*all_tags)
117
+ return list(set(flatten_tags))
118
+
119
+ params = {
120
+ "min_length": 0,
121
+ "max_length": 30,
122
+ "no_repeat_ngram_size": 2,
123
+ "num_beams": 4,
124
+ "early_stopping": True,
125
+ "num_return_sequences": 4,
126
+ }
127
+ tags = tag(article, 512, params)
128
+ print(tags)
129
+ ```
130
+
131
  ### Overview
132
 
133
  - Model: T5 ([it5-small](https://huggingface.co/gsarti/it5-small))