# demo / llama2.py
# mft-moral's picture
# Upload 17 files
# 6057964 verified
############################
# LLaMA 2 7b chat
############################
import os
import utils
import torch
import json
import argparse
import pandas as pd
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
delimiter = '####'
system_message_for = f'''The Moral Foundations Theory introduces six foundations of morality:\
1)Care, 2)Equality, 3) Proportionality, 4)Loyalty, 5)Authority, 6)Purity. \
You are given a sentence delimited with {delimiter} characters.
Your task is to generate a one-sentence-long persuasive argument for \
why the given sentence is sexist based on moral foundation theory.
Important: Start the argument with ```This sentence is sexist because it violates the moral foundations of```"
'''
system_message_against = f'''The Moral Foundations Theory introduces six foundations of morality:\
1)Care, 2)Equality, 3) Proportionality, 4)Loyalty, 5)Authority, 6)Purity. \
You are given a sentence delimited with {delimiter} characters.
Your task is to generate a one-sentence-long persuasive argument for \
why the given sentence is not sexist based on moral foundation theory.
Important: Start the argument with ```This sentence is not sexist because it aligns with moral values of```
'''
def user_message(text):
return f'''{delimiter}{text}{delimiter}
'''
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="llama2")
parser.add_argument('--cuda', type=str, default="-1")
parser.add_argument('--idx', type=int, default=-1)
args = parser.parse_args()
print(args)
if args.idx == "-1":
start_index = 0
end_index = utils.MAX_ID
else:
start_index = args.idx
end_index = utils.MAX_ID
if args.cuda != "-1":
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)
folder_path = './generations/llama2/'
if not os.path.exists('./generations'):
os.makedirs('./generations')
if not os.path.exists('./generations/llama2'):
os.makedirs('./generations/llama2')
data = utils.read_implicit_edos()
model = utils.HFModel(
model_path='meta-llama/Llama-2-7b-chat-hf')
# arguments_for = []
# arguments_against = []
for i in range(len(data)):
if os.path.exists(folder_path + str(i) + '.json'):
continue
if i < start_index or i >= end_index:
continue
t = data['text'][i]
response_for = model.generate(system_message_for + user_message(t), temperature=0.7, max_new_tokens=200, pad_token=True)
response_against = model.generate(system_message_against + user_message(t), temperature=0.7, max_new_tokens=200, pad_token=True)
# print(response_for)
# print("**********")
# print(response_against)
# print("**********")
response_for = response_for.split(delimiter)[-1].strip()
response_against = response_against.split(delimiter)[-1].strip()
print(i)
for _ in range(5):
if response_for.startswith('This sentence is sexist'):
break
print("re-generating...")
response_for = model.generate(system_message_for + user_message(t), temperature=0.7, max_new_tokens=200, pad_token=True)
response_for = response_for.split(delimiter)[-1].strip()
for _ in range(5):
if response_against.startswith('This sentence is not sexist'):
break
print("re-generating...")
response_against = model.generate(system_message_against + user_message(t), temperature=0.7, max_new_tokens=200, pad_token=True)
response_against = response_against.split(delimiter)[-1].strip()
# print(response_for)
# print("**********")
# print(response_against)
# print("**********")
# arguments_for.append(response_for)
# arguments_against.append(response_against)
with open(folder_path + str(i) + '.json', 'w') as f:
json.dump({'id': str(data['id'][i]), 'for': response_for, 'against': response_against}, f)
# if i > 7:
# break
# print(len(arguments_for), len(arguments_against))
# arguments_csv = pd.DataFrame(columns=['for', 'against'])
# arguments_csv['for'] = arguments_for
# arguments_csv['against'] = arguments_against
# arguments_csv.to_csv('./generations/llama2.csv')