PereLluis13 committed on
Commit 79439e4
1 Parent(s): de3640d

Update README.md

Files changed (1)
  1. README.md +60 -38
README.md CHANGED
@@ -21,11 +21,17 @@ language:
 widget:
 - text: >-
     The Red Hot Chili Peppers were formed in Los Angeles by Kiedis, Flea, guitarist Hillel Slovak and drummer Jack Irons.
+  example_title: English
+inference:
+  parameters:
+    decoder_start_token_id: 250058
+    src_lang: en_XX
+    tgt_lang: <triplet>
 tags:
 - seq2seq
 - relation-extraction
-
 license: cc-by-nc-sa-4.0
+pipeline_tag: translation
 ---
 # RED<sup>FM</sup>: a Filtered and Multilingual Relation Extraction Dataset
 
@@ -53,31 +59,36 @@ Be aware that the inference widget at the right does not output special tokens,
 ```python
 from transformers import pipeline
 
-triplet_extractor = pipeline('text2text-generation', model='Babelscape/mrebel-large', tokenizer='Babelscape/mrebel-large')
+triplet_extractor = pipeline('translation_xx_to_yy', model='Babelscape/mrebel-base', tokenizer='Babelscape/mrebel-base')
 # We need to use the tokenizer manually since we need special tokens.
-extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor("The Red Hot Chili Peppers were formed in Los Angeles by Kiedis, Flea, guitarist Hillel Slovak and drummer Jack Irons.", return_tensors=True, return_text=False)[0]["generated_token_ids"]])
+extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor("The Red Hot Chili Peppers were formed in Los Angeles by Kiedis, Flea, guitarist Hillel Slovak and drummer Jack Irons.", decoder_start_token_id=250058, src_lang="en_XX", tgt_lang="<triplet>", return_tensors=True, return_text=False)[0]["translation_token_ids"]]) # change en_XX for the language of the source.
 print(extracted_text[0])
 # Function to parse the generated text and extract the triplets
-def extract_triplets(text):
+def extract_triplets_typed(text):
     triplets = []
-    relation, subject, relation, object_ = '', '', '', ''
+    relation = ''
     text = text.strip()
     current = 'x'
-    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
-        if token == "<triplet>":
+    subject, relation, object_, object_type, subject_type = '','','','',''
+
+    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").replace("tp_XX", "").replace("__en__", "").split():
+    if token == "<triplet>" or token == "<relation>":
             current = 't'
             if relation != '':
-                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+                triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
                 relation = ''
             subject = ''
-        elif token == "<subj>":
-            current = 's'
-            if relation != '':
-                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
-            object_ = ''
-        elif token == "<obj>":
-            current = 'o'
-            relation = ''
+        elif token.startswith("<") and token.endswith(">"):
+            if current == 't' or current == 'o':
+                current = 's'
+                if relation != '':
+                    triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
+                object_ = ''
+                subject_type = token[1:-1]
+            else:
+                current = 'o'
+                object_type = token[1:-1]
+                relation = ''
         else:
             if current == 't':
                 subject += ' ' + token
@@ -85,10 +96,10 @@ def extract_triplets(text):
                 object_ += ' ' + token
             elif current == 'o':
                 relation += ' ' + token
-    if subject != '' and relation != '' and object_ != '':
-        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+    if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
+        triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
     return triplets
-extracted_triplets = extract_triplets(extracted_text[0])
+extracted_triplets = extract_triplets_typed(extracted_text[0])
 print(extracted_triplets)
 ```
 
@@ -97,26 +108,31 @@ print(extracted_triplets)
 ```python
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
-def extract_triplets(text):
+def extract_triplets_typed(text):
     triplets = []
-    relation, subject, relation, object_ = '', '', '', ''
+    relation = ''
     text = text.strip()
     current = 'x'
-    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
-        if token == "<triplet>":
+    subject, relation, object_, object_type, subject_type = '','','','',''
+
+    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").replace("tp_XX", "").replace("__en__", "").split():
+    if token == "<triplet>" or token == "<relation>":
             current = 't'
             if relation != '':
-                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+                triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
                 relation = ''
             subject = ''
-        elif token == "<subj>":
-            current = 's'
-            if relation != '':
-                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
-            object_ = ''
-        elif token == "<obj>":
-            current = 'o'
-            relation = ''
+        elif token.startswith("<") and token.endswith(">"):
+            if current == 't' or current == 'o':
+                current = 's'
+                if relation != '':
+                    triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
+                object_ = ''
+                subject_type = token[1:-1]
+            else:
+                current = 'o'
+                object_type = token[1:-1]
+                relation = ''
         else:
             if current == 't':
                 subject += ' ' + token
@@ -124,18 +140,19 @@ def extract_triplets(text):
                 object_ += ' ' + token
             elif current == 'o':
                 relation += ' ' + token
-    if subject != '' and relation != '' and object_ != '':
-        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+    if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
+        triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
     return triplets
 
 # Load model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
-model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+tokenizer = AutoTokenizer.from_pretrained("Babelscape/mrebel-base", src_lang="en_XX", tgt_lang="tp_XX") # Here we set English as source language. To change the source language just change it here or swap the first token of the input for your desired language
+model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-base")
 gen_kwargs = {
     "max_length": 256,
     "length_penalty": 0,
     "num_beams": 3,
     "num_return_sequences": 3,
+    "forced_bos_token_id": None,
 }
 
 # Text to extract triplets from
@@ -148,6 +165,7 @@ model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, re
 generated_tokens = model.generate(
     model_inputs["input_ids"].to(model.device),
     attention_mask=model_inputs["attention_mask"].to(model.device),
+    decoder_start_token_id=tokenizer.convert_tokens_to_ids("tp_XX"),
     **gen_kwargs,
 )
 
@@ -157,5 +175,9 @@ decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=Fal
 # Extract triplets
 for idx, sentence in enumerate(decoded_preds):
     print(f'Prediction triplets sentence {idx}')
-    print(extract_triplets(sentence))
-```
+    print(extract_triplets_typed(sentence))
+```
+
+## License
+
+This model is licensed under the CC BY-NC-SA 4.0 license. The text of the license can be found [here](https://creativecommons.org/licenses/by-nc-sa/4.0/).
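
For a quick sanity check of the parsing function added in this commit, the sketch below runs `extract_triplets_typed` on a hand-written string in the typed linearization the parser expects (`<triplet> head <head-type> tail <tail-type> relation`). The sample string and the entity types `org` and `loc` are illustrative assumptions, not actual model output:

```python
# Sanity-check sketch: run the typed-triplet parser from the README on a
# hand-written string (illustrative only, not real mREBEL output).
def extract_triplets_typed(text):
    triplets = []
    text = text.strip()
    current = 'x'
    subject, relation, object_, object_type, subject_type = '', '', '', '', ''
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").replace("tp_XX", "").replace("__en__", "").split():
        if token == "<triplet>" or token == "<relation>":
            # Start of a new triplet: flush any completed one, collect the head.
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(), 'tail': object_.strip(), 'tail_type': object_type})
                relation = ''
            subject = ''
        elif token.startswith("<") and token.endswith(">"):
            # A type marker: the first one types the head, the second types the tail.
            if current == 't' or current == 'o':
                current = 's'
                if relation != '':
                    triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(), 'tail': object_.strip(), 'tail_type': object_type})
                object_ = ''
                subject_type = token[1:-1]
            else:
                current = 'o'
                object_type = token[1:-1]
                relation = ''
        else:
            # Plain tokens accumulate into head, tail, or relation surface form.
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
        triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(), 'tail': object_.strip(), 'tail_type': object_type})
    return triplets

sample = "<s><triplet> Red Hot Chili Peppers <org> Los Angeles <loc> location of formation</s>"
print(extract_triplets_typed(sample))
# [{'head': 'Red Hot Chili Peppers', 'head_type': 'org', 'type': 'location of formation', 'tail': 'Los Angeles', 'tail_type': 'loc'}]
```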
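The comments in the diff note that only the source-language code changes when the input is not English. A minimal sketch of that, assuming mBART-50-style language codes such as `es_XX` for Spanish; the sentence and the code are illustrative and not part of this commit:

```python
from transformers import pipeline

# Same pipeline call as in the README, but for a hypothetical Spanish input.
# es_XX is an assumed mBART-50-style code; verify against the model card.
triplet_extractor = pipeline('translation_xx_to_yy', model='Babelscape/mrebel-base', tokenizer='Babelscape/mrebel-base')
out = triplet_extractor(
    "Los Red Hot Chili Peppers se formaron en Los Ángeles.",  # illustrative sentence
    decoder_start_token_id=250058,
    src_lang="es_XX",        # only the source-language code changes
    tgt_lang="<triplet>",    # the target side stays the triplet linearization
    return_tensors=True,
    return_text=False,
)
decoded = triplet_extractor.tokenizer.batch_decode([out[0]["translation_token_ids"]])
print(decoded[0])  # parse with extract_triplets_typed as above
```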