A Text-to-Triple Model

Base Model: Flan-T5-Large by Google

Base Dataset: WikiOFGraph (5.85M high-quality text-triple pairs)

Trained by Patrick Jiang @ UIUC

Wandb Training Report (Dec 5, 2024)

Example Input:

"William Gerald Standridge (November 27, 1953 – April 12, 2014) was an American stock car racing driver. He was a competitor in the NASCAR Winston Cup Series and Busch Series."

Output:

(S> William gerald standridge| P> Nationality| O> American),
(S> William gerald standridge| P> Occupation| O> Stock car racing driver),
(S> William gerald standridge| P> Competitor| O> Busch series),
(S> William gerald standridge| P> Competitor| O> Nascar winston cup series),
(S> William gerald standridge| P> Birth date| O> November 27, 1953),
(S> William gerald standridge| P> Death date| O> April 12, 2014)
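
Each triple is serialized as (S> subject| P> predicate| O> object), with triples separated by commas. Below is a minimal parsing sketch (illustrative, not part of the model's released tooling) that recovers (subject, predicate, object) tuples with a regular expression:

import re

def parse_triples(output: str):
    # Match each '(S> ...| P> ...| O> ...)' span and capture the three fields
    pattern = r"\(S>\s*(.*?)\s*\|\s*P>\s*(.*?)\s*\|\s*O>\s*(.*?)\s*\)"
    return re.findall(pattern, output)

parse_triples("(S> William gerald standridge| P> Nationality| O> American)")
# [('William gerald standridge', 'Nationality', 'American')]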

How to Run?

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

def generate_triples(input_text: str, model_path: str = "pat-jj/text2triple-flan-t5"):
    # Initialize tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16  # Use bfloat16 for efficiency
    )
    
    # Tokenize input with proper padding and attention mask
    inputs = tokenizer(
        input_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    
    # Move inputs to the same device as model
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # Generate with beam search and a mild length penalty
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=512,
            num_beams=4,  # Use beam search
            early_stopping=True,
            length_penalty=0.6,  # Penalize very long outputs
            use_cache=True  # Use KV cache for faster generation
        )
    
    # Decode and return the generated triples
    generated_triples = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_triples

Example Usage:

input_text = """Albert Einstein was born in Ulm, Germany in 1879. He developed the theory of relativity and won the Nobel Prize in Physics in 1921.
Einstein worked as a professor at Princeton University until his death in 1955."""

generated_triples = generate_triples(input_text)
print("Generated triples:", generated_triples)

Output:

Generated triples: (S> Albert einstein| P> Birth place| O> Ulm, germany), (S> Albert einstein| P> Birth year| O> 1879), (S> Albert einstein| P> Award| O> Nobel prize in physics), (S> Albert einstein| P> Death year| O> 1955), (S> Albert einstein| P> Occupation| O> Professor), (S> Albert einstein| P> Workplace| O> Princeton university)
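
Note that generate_triples reloads the tokenizer and model on every call. For repeated or bulk use, it is cheaper to load once and batch inputs. A minimal sketch under the same model path (generate_triples_batch is illustrative, not part of the released code):

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

model_path = "pat-jj/text2triple-flan-t5"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(
    model_path, device_map="auto", torch_dtype=torch.bfloat16
)

def generate_triples_batch(texts: list[str]) -> list[str]:
    # Pad to the longest sequence in the batch rather than a fixed 512 tokens
    inputs = tokenizer(
        texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_length=512, num_beams=4, early_stopping=True
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)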

Paper for the WikiOFGraph dataset:

Daehee Kim et al., "Ontology-Free General-Domain Knowledge Graph-to-Text Generation Dataset Synthesis using Large Language Model", 2024.

Cite This Model

@misc{patrick_jiang_2024,
    author    = {Patrick Jiang},
    title     = {text2triple-flan-t5 (Revision df1323c)},
    year      = 2024,
    url       = {https://huggingface.co/pat-jj/text2triple-flan-t5},
    doi       = {10.57967/hf/3722},
    publisher = {Hugging Face}
}

Model size: 783M parameters (F32, Safetensors)
