sa_T5_Table_to_text / README.md
Sachinkelenjaguri's picture
readme.md
3ef6882
|
raw
history blame
888 Bytes

import os
import time
import warnings

import pandas as pd
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.optimization import Adafactor

warnings.filterwarnings('ignore')

# Load the tokenizer and the seq2seq model published under the
# 'Sachinkelenjaguri/sa_T5_Table_to_text' identifier.
tokenizer = T5Tokenizer.from_pretrained('Sachinkelenjaguri/sa_T5_Table_to_text')
model = T5ForConditionalGeneration.from_pretrained(
    'Sachinkelenjaguri/sa_T5_Table_to_text', return_dict=True
)

# Moving the model to device (GPU/CPU): use CUDA when it is available,
# otherwise stay on the CPU. `dev` is the module-level device handle.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(dev)

def generate(text):
    """Generate text for one input record with the module-level T5 model.

    Args:
        text: Input string; it is inserted into the "WebNLG:{} " prompt
            before tokenization.

    Returns:
        The decoded first generated sequence, with special tokens removed.

    Side effects:
        Puts the module-level ``model`` into eval mode and prints the
        elapsed time of generation plus decoding.
    """
    model.eval()
    input_ids = tokenizer.encode("WebNLG:{} ".format(text), return_tensors="pt")  # Batch size 1

    # Tensor.to() returns a new tensor instead of moving in place, so the
    # result has to be re-bound. Targeting the model's own device keeps the
    # inputs and the weights together wherever the model was placed.
    input_ids = input_ids.to(model.device)

    s = time.time()
    outputs = model.generate(input_ids)
    # The earlier .replace('', '') calls replaced the empty string with
    # itself and therefore stripped nothing; let the tokenizer drop its
    # special tokens instead.
    gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.time() - s
    print('Generated in {} seconds'.format(str(elapsed)[:4]))

    return gen_text

# Demo call on one pipe-separated record. The return value is printed
# because a bare expression shows nothing when this runs as a script.
print(generate(' Russia | leader | Putin'))