|
import torch |
|
from torch import nn |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
|
|
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
# The value is a plain string ("cuda" or "cpu"), which `.to(...)` accepts.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
def create_flan_T5_model(device=device, *, model_name="google/flan-t5-small"):
    """Create a FLAN-T5 seq2seq model and its tokenizer.

    Args:
        device: Target passed to ``model.to(...)`` (e.g. ``"cuda"``, ``"cpu"``
            or a ``torch.device``). The default is the module-level ``device``
            value as it was when this function was defined (Python evaluates
            default arguments once, at ``def`` time).
        model_name: Checkpoint identifier passed to both ``from_pretrained``
            calls. Defaults to ``"google/flan-t5-small"``, which was
            previously hard-coded.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``T5ForConditionalGeneration``
        instance moved to ``device``, and the tokenizer loaded via
        ``AutoTokenizer`` from the same checkpoint.
    """
    # Load tokenizer and model from the same checkpoint name so they always match.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    return model, tokenizer
|
|
|
|
|
# Instantiated at import time so that importers can use `model` and
# `tokenizer` directly. The device is passed explicitly; it is the same
# value the function would use by default.
model, tokenizer = create_flan_T5_model(device=device)
|
|