LukeOLuck's picture
init commit
23e2f58
raw
history blame
585 Bytes
import torch
from torch import nn
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Pick the compute device once, at import time: use the GPU when PyTorch can
# see a CUDA device, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def create_flan_T5_model(device=device, model_name="google/flan-t5-small"):
    """Create a HuggingFace FLAN-T5 model together with its tokenizer.

    Args:
        device: A torch.device (or device string such as "cuda" / "cpu")
            that the model is moved onto. The default is the module-level
            ``device`` value as it was when this function was defined.
        model_name: HuggingFace Hub identifier or local path of the
            checkpoint to load. Defaults to "google/flan-t5-small", the
            checkpoint this helper has always loaded.

    Returns:
        A tuple ``(model, tokenizer)``: a ``T5ForConditionalGeneration``
        moved to ``device``, and the ``AutoTokenizer`` loaded from the same
        checkpoint.
    """
    # Load both pieces from one identifier so tokenizer and weights can
    # never drift apart.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    return model, tokenizer
# Example usage: build the default FLAN-T5 model and tokenizer on the
# module-level device and expose them as module globals.
# NOTE(review): this executes at import time, so importing this module loads
# (and on first use downloads) the model weights; importers that call
# create_flan_T5_model themselves will load it a second time.
model, tokenizer = create_flan_T5_model(device=device)