import datetime

import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

import pytz

# Timestamps are printed in Indian Standard Time.
IST = pytz.timezone('Asia/Kolkata')
print(datetime.datetime.now(IST).strftime("%c"))

# Load the base GPT-2 tokenizer with the special tokens used during fine-tuning.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)

# Resize the embedding matrix to account for the added special tokens.
model.resize_token_embeddings(len(tokenizer))

# Use the GPU if available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

print('Model loaded to', device)
print(datetime.datetime.now(IST).strftime("%c"))

output_dir = '/media/data_dump/Ritwik/ggpt/model_save/pytorch_save_files/'

# Replace the base weights with the fine-tuned checkpoint saved during training.
print('Loading fine-tuned weights')
model = GPT2LMHeadModel.from_pretrained(output_dir).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(output_dir)

print('Model and tokenizer loaded!')
print(datetime.datetime.now(IST).strftime("%c"))

# Evaluation mode disables dropout for deterministic forward passes.
model.eval()

prompt_list = ['<|startoftext|> Regarding Kashmir I am very confident to say that']

for prompt in prompt_list:

    print(prompt)

    # Encode the prompt and move it to the same device as the model.
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)

    print(generated)

    # Sample three continuations with top-k / nucleus (top-p) sampling.
    with torch.no_grad():
        sample_outputs = model.generate(
            generated,
            do_sample=True,
            top_k=50,
            max_length=500,
            top_p=0.95,
            num_return_sequences=3,
            pad_token_id=tokenizer.pad_token_id,
        )

    for i, sample_output in enumerate(sample_outputs):
        print("{}: {}\n\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

    print(datetime.datetime.now(IST).strftime("%c"))
    print('\n')