---
language:
  - en
widget:
  - text: >-
      Punta Cana is a resort town in the municipality of Higuey, in La
      Altagracia Province, the eastern most province of the Dominican Republic
tags:
  - seq2seq
license: cc-by-nc-sa-4.0
---

# REBEL: Relation Extraction By End-to-end Language generation

This is the model card for the Findings of EMNLP 2021 paper REBEL: Relation Extraction By End-to-end Language generation. We present a new linearization approach and a reframing of Relation Extraction as a seq2seq task. The paper can be found [here](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). If you use the code, please reference this work in your paper:

```bibtex
@inproceedings{huguet-cabot-navigli-2021-rebel,
    title = "REBEL: Relation Extraction By End-to-end Language generation",
    author = "Huguet Cabot, Pere-Llu{\'\i}s  and
      Navigli, Roberto",
    booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
    month = nov,
    year = "2021",
    address = "Online and in the Barceló Bávaro Convention Centre, Punta Cana, Dominican Republic",
    publisher = "Association for Computational Linguistics",
    url = "https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf",
}
```

The original repository for the paper can be found [here](https://github.com/Babelscape/rebel).
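
The model emits relations as a single linearized string in which each triplet is introduced by the special token `<triplet>` followed by the subject, `<subj>` followed by the object, and `<obj>` followed by the relation label (this is the same ordering the `extract_triplets` helper below relies on). As a rough orientation, here is a hand-written string in that format; it is an assumed illustration, not actual model output:

```python
# Assumed, hand-written illustration of the linearized format
# (not a captured model prediction):
#   <triplet> subject <subj> object <obj> relation
example = "<triplet> Punta Cana <subj> Dominican Republic <obj> country"
```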

## Pipeline usage

```python
from transformers import pipeline

triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')
# We need to use the tokenizer manually since we need special tokens.
extracted_text = triplet_extractor.tokenizer.decode(triplet_extractor("Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic", return_tensors=True, return_text=False)[0]["generated_token_ids"])
print(extracted_text)

# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    triplets = []
    relation, subject, object_ = '', '', ''
    current = 'x'
    # Remove the special tokens added by the tokenizer before parsing.
    text = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").strip()
    for token in text.split():
        if token == "<triplet>":
            # A new triplet starts: flush the previous one, if complete.
            current = 't'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            # The subject is finished; what follows is the object.
            current = 's'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            # The object is finished; what follows is the relation label.
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the last triplet, if complete.
    if subject != '' and relation != '' and object_ != '':
        triplets.append((subject.strip(), relation.strip(), object_.strip()))
    return triplets

extracted_triplets = extract_triplets(extracted_text)
print(extracted_triplets)
```
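
As a quick sanity check of `extract_triplets` (independent of the model), the snippet below parses a hand-written linearized string with two relations sharing the same subject; the string and its relation labels are assumed examples, not captured predictions:

```python
# Parse an assumed, hand-written linearized string (not real model output).
sample = ("<triplet> Punta Cana <subj> La Altagracia Province <obj> located in the administrative territorial entity "
          "<subj> Dominican Republic <obj> country")
print(extract_triplets(sample))
# [('Punta Cana', 'located in the administrative territorial entity', 'La Altagracia Province'),
#  ('Punta Cana', 'country', 'Dominican Republic')]
```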

## Model and Tokenizer using transformers


```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    triplets = []
    relation, subject, object_ = '', '', ''
    current = 'x'
    # Remove the special tokens added by the tokenizer before parsing.
    text = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").strip()
    for token in text.split():
        if token == "<triplet>":
            # A new triplet starts: flush the previous one, if complete.
            current = 't'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            # The subject is finished; what follows is the object.
            current = 's'
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            # The object is finished; what follows is the relation label.
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the last triplet, if complete.
    if subject != '' and relation != '' and object_ != '':
        triplets.append((subject.strip(), relation.strip(), object_.strip()))
    return triplets

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
gen_kwargs = {
    "max_length": 256,
    "length_penalty": 0,
    "num_beams": 3,
    "num_return_sequences": 3,
}

# Text to extract triplets from
text = 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.'

# Tokenize text
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors='pt')

# Generate
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    **gen_kwargs,
)

# Decode the generated token ids, keeping the special tokens needed for parsing
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

# Extract triplets from each returned sequence
for idx, sentence in enumerate(decoded_preds):
    print(f'Prediction triplets sentence {idx}')
    print(extract_triplets(sentence))
```
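
Because `num_return_sequences` is set to 3, the loop above prints the triplets of three decoded sequences for the same input, which usually overlap. A minimal sketch for merging them into one deduplicated set, assuming the `decoded_preds` and `extract_triplets` defined above:

```python
# Merge the triplets from all returned beams, dropping exact duplicates.
unique_triplets = set()
for sentence in decoded_preds:
    unique_triplets.update(extract_triplets(sentence))

for subj, rel, obj in sorted(unique_triplets):
    print(f'{subj} --[{rel}]--> {obj}')
```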