# snli-5-adapter / run_inference.py
import json

import torch
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def generate_response(model, tokenizer, setup_prompt):
    # Tokenize the prompt and move the tensors to the model's device; the
    # tokenizer itself does not accept a device_map argument.
    model_inputs = tokenizer(setup_prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **model_inputs,
        max_length=1024,  # total budget, prompt tokens included
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    question_to_claims = tokenizer.decode(output[0], skip_special_tokens=True)
    # Strip the prompt by whitespace-word count and return only the first word
    # of the continuation, i.e. the predicted label.
    prompt_tokens = len(setup_prompt.split())
    return question_to_claims.split()[prompt_tokens]
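
# A sketch of an alternative (not part of the original script): slicing the
# generated ids at the prompt's token length avoids the whitespace-counting
# heuristic above, which can drift when the tokenizer merges or splits words.
# The function name generate_label_token_sliced is our own.
def generate_label_token_sliced(model, tokenizer, setup_prompt):
    model_inputs = tokenizer(setup_prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **model_inputs,
        max_new_tokens=8,  # a label needs only a few tokens
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    prompt_len = model_inputs["input_ids"].shape[1]
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return response.strip().split()[0] if response.strip() else ""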
def naive_prompt(context, hypothesis):
    # Zero-shot Llama-3 chat prompt; the <|eot_id|> closing the system turn was
    # missing and is restored here to match the other turns.
    prompt = f'''
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant to do Natural Language Inference.\nYou are given a CONTEXT and HYPOTHESIS and you will predict a label ONLY from the set ENTAILMENT,CONTRADICTION,NEUTRAL.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
CONTEXT: {context}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
HYPOTHESIS: {hypothesis}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>#Answer:
'''
    return prompt
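
# A minimal sketch (our addition, not the original pipeline): letting the
# tokenizer render the Llama-3 chat template avoids hand-writing the special
# tokens above. Assumes the local tokenizer ships a chat_template.
def naive_prompt_templated(tokenizer, context, hypothesis):
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant to do Natural Language Inference. You are given a CONTEXT and HYPOTHESIS and you will predict a label ONLY from the set ENTAILMENT, CONTRADICTION, NEUTRAL."},
        {"role": "user", "content": f"CONTEXT: {context}\nHYPOTHESIS: {hypothesis}"},
    ]
    # add_generation_prompt appends the assistant header so generation starts there.
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)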
def naive_prompt_more_hc_prime(context, hypothesis):
    # Few-shot variant: two worked "neutral" examples precede the query pair.
    # Missing <|eot_id|> / user headers are restored so every turn is delimited
    # the same way.
    prompt = f'''
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant to do Natural Language Inference.\nYou are given a CONTEXT and HYPOTHESIS and you will predict a label ONLY from the set ENTAILMENT,CONTRADICTION,NEUTRAL.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
CONTEXT: The girl is wearing shoes
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
HYPOTHESIS: A girl asleep on a hard wood floor cuddling her baby doll
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>#Answer:
"neutral"
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
CONTEXT: The girl is watching TV
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
HYPOTHESIS: A girl sleeping on the floor with her dolls
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>#Answer:
"neutral"
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
CONTEXT: {context}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
HYPOTHESIS: {hypothesis}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>#Answer:
'''
    return prompt
def read_data(file_path):
    # Read SNLI JSONL: each line holds a premise (sentence1), a hypothesis
    # (sentence2), and a gold label.
    file_content = []
    with open(file_path, 'r') as file:
        for line in file:
            line = json.loads(line)
            file_content.append({
                'sentence1': line['sentence1'],
                'sentence2': line['sentence2'],
                'label': line['gold_label'],
            })
    return file_content
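
# A hedged aside (not in the original script): SNLI marks items without
# annotator consensus with gold_label == '-', and evaluation code usually
# drops them. A minimal filtering wrapper around read_data above:
def read_data_labeled(file_path):
    return [item for item in read_data(file_path) if item['label'] != '-']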
def llama3_snli():
    # device_map = "mps"
    device_map = "auto"
    # Load the base Llama-3-8B-Instruct checkpoint from the local path.
    model = AutoModelForCausalLM.from_pretrained(
        "/Users/sbhar/Riju/PhDCode/RAG_LLama/nebula-rag-code/llama3-8B-Instruct-hf",
        return_dict=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    print("Baseline model loaded.")
    print("-------------------------------------")
    # Attach the SNLI LoRA adapter and fold its weights into the base model.
    model = PeftModel.from_pretrained(
        model,
        "/Users/sbhar/Riju/PhDCode/SNLI-FT/model/snli-adapter",
        device_map=device_map,
    )
    model = model.merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(
        "/Users/sbhar/Riju/PhDCode/RAG_LLama/nebula-rag-code/llama3-8B-Instruct-hf",
        use_fast=True,
        trust_remote_code=True,
    )
    tokenizer.pad_token_id = 18610  # pad token id fixed at fine-tuning time
    tokenizer.padding_side = "right"
    print("Fine-tuned model and tokenizer loaded locally.")
    file_content = read_data('/Users/sbhar/Riju/PhDCode/SNLI-FT/data/snli_1.0/snli_1.0_test.jsonl')
    pred_label_file = '/Users/sbhar/Riju/PhDCode/SNLI-FT/output/pred_outputs.json'
    pred_outputs = {}
    for i, item in enumerate(tqdm(file_content, desc="Predicting Labels")):
        prompt = naive_prompt(context=item['sentence1'], hypothesis=item['sentence2'])
        pred_outputs[i] = generate_response(model=model, tokenizer=tokenizer, setup_prompt=prompt)
    with open(pred_label_file, 'w') as file:
        json.dump(pred_outputs, file)
    print("All predictions dumped.")
"""
context = file_content[500]['sentence1']
hypothesis = file_content[500]['sentence2']
print(context)
print(hypothesis)
print(file_content[500]['label'])
prompt = naive_prompt(context,hypothesis)
answer = generate_response(model=model,tokenizer=tokenizer,setup_prompt=prompt)
print(answer)
"""
if __name__ == "__main__":
    llama3_snli()
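
# A sketch of a follow-up evaluation step (our addition; paths are assumed to
# match the ones used above): compare pred_outputs.json against the gold
# labels. Labels are lower-cased on both sides since the prompt asks for upper
# case while the few-shot examples answer in lower case; json.dump writes the
# integer keys as strings, hence the str(i) lookup.
def evaluate_predictions(pred_label_file, test_file):
    with open(pred_label_file, 'r') as file:
        pred_outputs = json.load(file)
    gold = read_data(test_file)
    correct = total = 0
    for i, item in enumerate(gold):
        if item['label'] == '-':  # skip items without annotator consensus
            continue
        pred = pred_outputs.get(str(i), '').strip().strip('"').lower()
        correct += int(pred == item['label'].lower())
        total += 1
    print(f"Accuracy: {correct / total:.4f} ({correct}/{total})")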