# gandhi-gpt/code/gpt-run.py
# Added by ritwikm in commit b7c468b ("added code files"); 3.27 kB.
import os
import time
import datetime
import pandas as pd
import seaborn as sns
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup
import sys
import pytz
# Every progress timestamp this script prints is expressed in Indian Standard Time.
IST = pytz.timezone('Asia/Kolkata')
# Record when the run started ("%c" = locale's full date/time representation).
_run_started_at = datetime.datetime.now(tz=IST)
print(_run_started_at.strftime("%c"))
# Directory holding the fine-tuned GPT-2 checkpoint (config, weights and
# tokenizer files) loaded below.
output_dir = '/media/data_dump/Ritwik/ggpt/model_save/pytorch_save_files/'

# Prefer the GPU, but fall back to the CPU so the script still runs (slowly)
# on machines without CUDA instead of failing on the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model and tokenizer directly from the checkpoint.
# Building the base 'gpt2' model/tokenizer first is unnecessary:
# from_pretrained(output_dir) restores the saved config (including the
# resized vocabulary) and weights, and replaces any previously built objects.
# NOTE(review): assumes the tokenizer in output_dir was saved with the
# <|startoftext|> / <|endoftext|> / <|pad|> special tokens added during
# fine-tuning — confirm against the training script.
print('Loading fine-tuned weights')
model = GPT2LMHeadModel.from_pretrained(output_dir).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(output_dir)
print(f'Model loaded to {device}')
print('Model and tokenizer loaded!')
print(datetime.datetime.now(IST).strftime("%c"))

# Inference only: put layers such as dropout into evaluation mode.
model.eval()
# Prompts to sample continuations for. Each begins with the '<|startoftext|>'
# marker, presumably the bos_token the model was fine-tuned with.
# Alternative prompts used in earlier runs:
# prompt_list = ['<|startoftext|> Regarding Kashmir I am very confident to say that','<|startoftext|> I wanted to save bhagat singh but','<|startoftext|> I wanted to save bhagat singh but fortunately','<|startoftext|> I wanted to save bhagat singh but unfortunately','<|startoftext|> Reporter: What is your biggest fear? Gandhi:','<|startoftext|> Question) What is your biggest fear?','<|startoftext|> Regarding Muslims and Islam I strongly believe that','<|startoftext|> I wish to say that','<|startoftext|> I chose Nehru over Patel for Prime Minister because','<|startoftext|> During my experiments with truth I observed that','<|startoftext|> My opinion on the negroes of Africa is that']
prompt_list = ['<|startoftext|> Regarding Kashmir I am very confident to say that']

for prompt in prompt_list:
    print(prompt)

    # Calling the tokenizer yields the same token ids as tokenizer.encode()
    # plus an explicit attention mask, already batched as a (1, seq_len)
    # tensor, so generate() does not have to infer the mask itself.
    encoded = tokenizer(prompt, return_tensors='pt').to(device)
    generated = encoded['input_ids']
    print(generated)

    # Top-k / nucleus sampling of three independent continuations.
    # max_length counts the prompt tokens as well as the generated ones.
    sample_outputs = model.generate(
        generated,
        attention_mask=encoded['attention_mask'],
        # Pad sequences that finish early with the tokenizer's own pad token
        # instead of letting generate() fall back to a default (with a warning).
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_k=50,
        max_length=500,
        top_p=0.95,
        num_return_sequences=3,
    )

    for i, sample_output in enumerate(sample_outputs):
        text = tokenizer.decode(sample_output, skip_special_tokens=True)
        print(f"{i}: {text}\n\n")

    # Timestamp after each prompt to track per-prompt generation time.
    print(datetime.datetime.now(IST).strftime("%c"))
    print('\n')