from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Fine-tuned GPT-2 XL checkpoint used for few-shot emergency-intent
# classification; the tokenizer comes from the stock gpt2-xl vocabulary.
tokenizer_name = 'gpt2-xl'
model_load_path = 'clam004/emerg-intent-consistent-good-gpt2-xl-v2'

# GPT-2 has no dedicated pad token, so reuse EOS. Left padding keeps the
# prompt's final token adjacent to the generated continuation.
tokenizer = GPT2Tokenizer.from_pretrained(
    tokenizer_name, pad_token='<|endoftext|>', padding_side='left')

# pad_token_id mirrors the tokenizer's EOS so generate() pads consistently.
model = GPT2LMHeadModel.from_pretrained(
    model_load_path, cache_dir=None, pad_token_id=tokenizer.eos_token_id)
# Few-shot prompt: labeled examples teach the model to answer the
# '<|emergency?|>' tag with a literal 'True' or 'False' token.
few_shot_pretext = \
'<|?|>\n[human]: i just slipped fell hit my head now im bleeding.<|emergency?|>True'+\
'<|?|>\n[human]: aluminium foil is bad, not good.<|emergency?|>False'+\
'<|?|>\n[human]: i dont want to live anymore i want to kill myself.<|emergency?|>True'+\
'<|?|>\n[human]: im dying for some coffee and donuts but i have diabetes.<|emergency?|>False'+\
'<|?|>\n[human]: i just got back from the hospital.<|emergency?|>False'+\
'<|?|>\n[human]: im fasting for my blood test tomorrow to see if my meds are working.<|emergency?|>False'+\
'<|?|>\n[human]: '

def classify_emergency(query_text):
    """Classify one utterance as an emergency via few-shot prompting.

    Appends *query_text* to the few-shot pretext, generates exactly one
    new token from the model, prints the prompt and the decoded token
    (expected to be 'True' or 'False'), and returns the decoded token.
    """
    few_shot_prompt = few_shot_pretext + query_text + '<|emergency?|>'
    print(repr(few_shot_prompt))
    print('-'*50)
    prompt_dic = tokenizer(few_shot_prompt, return_tensors="pt")
    prompt_len = prompt_dic.input_ids.shape[1]
    # Move inputs to wherever the model lives (CPU or GPU).
    prompt_ids = prompt_dic.input_ids.to(model.device)
    prompt_mask = prompt_dic.attention_mask.to(model.device)
    output_ids = model.generate(
        prompt_ids,
        attention_mask=prompt_mask,
        max_length=prompt_len + 1,  # generate exactly one token
    )
    # Decode only the newly generated (last) token of each sequence.
    generated_text = tokenizer.batch_decode(output_ids[:, -1])
    print(generated_text[0])
    print('-'*50)
    return generated_text[0]

# Run the two demo queries; previously this inference code was
# duplicated verbatim for each query.
for query_text in (
    "I want to die, I am worthless anyways.",
    "I am dying for a cup of coffee.",
):
    classify_emergency(query_text)