Edit model card

You need to agree to share your contact information to access this model

This repository is publicly accessible, but you have to accept the conditions to access its files and content.

Log in or Sign Up to review the conditions and access this model content.

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Fine-tuned GPT-2 XL checkpoint for few-shot emergency-intent classification,
# paired with the stock gpt2-xl tokenizer.
model_load_path = 'clam004/emerg-intent-consistent-good-gpt2-xl-v2'
tokenizer_name = 'gpt2-xl'

# GPT-2 ships without a pad token, so reuse <|endoftext|>; left padding keeps
# the generated label token immediately after the end of the prompt.
tokenizer = GPT2Tokenizer.from_pretrained(
    tokenizer_name, pad_token='<|endoftext|>', padding_side='left',
)

# Align the model's pad id with the tokenizer's EOS id to silence the
# "pad_token_id not set" warning during generate().
model = GPT2LMHeadModel.from_pretrained(
    model_load_path, pad_token_id=tokenizer.eos_token_id, cache_dir=None,
)

# Few-shot pretext: each example is rendered as
#   '<|?|>\n[human]: <utterance><|emergency?|><True|False>'
# and the pretext ends with an open '[human]: ' slot for the incoming query.
_examples = [
    ('i just slipped fell hit my head now im bleeding.', 'True'),
    ('aluminium foil is bad, not good.', 'False'),
    ('i dont want to live anymore i want to kill myself.', 'True'),
    ('im dying for some coffee and donuts but i have diabetes.', 'False'),
    ('i just got back from the hospital.', 'False'),
    ('im fasting for my blood test tomorrow to see if my meds are working.', 'False'),
]
few_shot_pretext = ''.join(
    f'<|?|>\n[human]: {utterance}<|emergency?|>{label}'
    for utterance, label in _examples
) + '<|?|>\n[human]: '

###############################################################
# Example 1: a self-harm statement — the model should emit 'True'.

query_text = "I want to die, I am worthless anyways."

# Format the query exactly like the few-shot examples, ending with the
# answer cue so the model's very next token is the True/False label.
few_shot_prompt = few_shot_pretext + query_text + '<|emergency?|>'

print(repr(few_shot_prompt))
print('-'*50)

prompt_dic = tokenizer(few_shot_prompt, return_tensors="pt")
prompt_ids = prompt_dic.input_ids.to(model.device)
prompt_mask = prompt_dic.attention_mask.to(model.device)
prompt_len = prompt_ids.shape[1]

# Greedy-decode exactly one token. max_new_tokens=1 states the intent
# directly, instead of the fragile max_length = prompt_len + 1 form.
output_ids = model.generate(
    prompt_ids,
    attention_mask=prompt_mask,
    max_new_tokens=1,
)

# Decode only the newly generated (last) token of each sequence.
generated_text = tokenizer.batch_decode(output_ids[:, -1])

print(generated_text[0])  # 'True'
print('-'*50)

###############################################################
# Example 2: a figurative "dying for coffee" — the model should emit 'False'.

query_text = "I am dying for a cup of coffee."

# Format the query exactly like the few-shot examples, ending with the
# answer cue so the model's very next token is the True/False label.
few_shot_prompt = few_shot_pretext + query_text + '<|emergency?|>'

print(repr(few_shot_prompt))
print('-'*50)

prompt_dic = tokenizer(few_shot_prompt, return_tensors="pt")
prompt_ids = prompt_dic.input_ids.to(model.device)
prompt_mask = prompt_dic.attention_mask.to(model.device)
prompt_len = prompt_ids.shape[1]

# Greedy-decode exactly one token. max_new_tokens=1 states the intent
# directly, instead of the fragile max_length = prompt_len + 1 form.
output_ids = model.generate(
    prompt_ids,
    attention_mask=prompt_mask,
    max_new_tokens=1,
)

# Decode only the newly generated (last) token of each sequence.
generated_text = tokenizer.batch_decode(output_ids[:, -1])

print(generated_text[0])  # 'False'
print('-'*50)
Downloads last month
1