---
language:
- en
widget:
- text: "Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic"
tags:
- seq2seq
license: cc-by-nc-sa-4.0
---
# REBEL: Relation Extraction By End-to-end Language generation

This is the model card for the Findings of EMNLP 2021 paper REBEL: Relation Extraction By End-to-end Language generation. We present a new linearization approach and a reframing of Relation Extraction as a seq2seq task. The paper can be found [here](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). If you use the code, please reference this work in your paper:

    @inproceedings{huguet-cabot-navigli-2021-rebel,
        title = "REBEL: Relation Extraction By End-to-end Language generation",
        author = "Huguet Cabot, Pere-Llu{\'\i}s and Navigli, Roberto",
        booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
        month = nov,
        year = "2021",
        address = "Online and in the Barceló Bávaro Convention Centre, Punta Cana, Dominican Republic",
        publisher = "Association for Computational Linguistics",
        url = "https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf",
    }

The original repository for the paper can be found [here](https://github.com/Babelscape/rebel).

## Pipeline usage

```python3
from transformers import pipeline

triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')
# We need to use the tokenizer manually since we need special tokens.
extracted_text = triplet_extractor.tokenizer.decode(
    triplet_extractor(
        "Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic",
        return_tensors=True,
        return_text=False,
    )[0]["generated_token_ids"]
)
print(extracted_text)

# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    triplets = []
    subject, relation, object_ = '', '', ''
    current = 'x'  # no triplet marker seen yet
    # Drop sequence-level special tokens so only the triplet markers remain
    text = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text.split():
        if token == "<triplet>":
            # A new subject starts; flush the previous triplet if complete
            current = 't'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            # A new object starts for the current subject
            current = 's'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            # The relation label follows
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the last triplet only if all three parts are present
    if subject != '' and relation != '' and object_ != '':
        triplets.append((subject.strip(), relation.strip(), object_.strip()))
    return triplets

extracted_triplets = extract_triplets(extracted_text)
print(extracted_triplets)
```

## Model and Tokenizer using transformers

```python3
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def extract_triplets(text):
    triplets = []
    subject, relation, object_ = '', '', ''
    current = 'x'  # no triplet marker seen yet
    # Drop sequence-level special tokens so only the triplet markers remain
    text = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the last triplet only if all three parts are present
    if subject != '' and relation != '' and object_ != '':
        triplets.append((subject.strip(), relation.strip(), object_.strip()))
    return triplets

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
gen_kwargs = {
    "max_length": 256,
    "length_penalty": 0,
    "num_beams": 3,
    "num_return_sequences": 3,
}

# Text to extract triplets from
text = 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.'

# Tokenize text
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors='pt')

# Generate
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    **gen_kwargs,
)

# Extract text
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

# Extract triplets
for idx, sentence in enumerate(decoded_preds):
    print(f'Prediction triplets sentence {idx}')
    print(extract_triplets(sentence))
```
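
With `num_return_sequences` set to 3, each input yields three decoded sequences, and the triplets parsed from them often overlap. The sketch below shows one possible way to merge the per-sequence results into a single deduplicated list. `merge_triplets` is a hypothetical helper (not part of REBEL or `transformers`), it assumes each triplet is a `(subject, relation, object)` tuple as returned by `extract_triplets`, and the sample `beams` data is hand-written for illustration rather than actual model output.

```python
from typing import Iterable, List, Tuple

Triplet = Tuple[str, str, str]


def merge_triplets(per_sequence: Iterable[Iterable[Triplet]]) -> List[Triplet]:
    """Merge triplets from several generated sequences, dropping duplicates.

    Hypothetical helper for illustration; keeps first-seen order.
    """
    seen = set()
    merged = []
    for triplets in per_sequence:
        for subject, relation, object_ in triplets:
            # Normalise whitespace so ' Punta Cana' and 'Punta Cana' compare equal
            key = (
                " ".join(subject.split()),
                " ".join(relation.split()),
                " ".join(object_.split()),
            )
            if key not in seen:
                seen.add(key)
                merged.append(key)
    return merged


# Hand-written example input (illustrative only, not actual model output)
beams = [
    [("Punta Cana", "located in the administrative territorial entity", "La Altagracia Province")],
    [
        (" Punta Cana", "located in the administrative territorial entity", "La Altagracia Province "),
        ("Punta Cana", "country", "Dominican Republic"),
    ],
]
merged = merge_triplets(beams)
print(merged)
```

In the example above, the same idea would apply by passing `[extract_triplets(s) for s in decoded_preds]` to the helper instead of printing each sequence separately. Treating every beam as equally trustworthy is a simplification; lower-ranked beams may contain noisier triplets.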