|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import utils |
|
|
import torch |
|
|
import json |
|
|
import argparse |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Default GPU selection, applied as soon as this module is loaded.
# The __main__ section overwrites it when --cuda is given a value other than "-1".
# NOTE(review): this assignment runs after `import torch`; presumably CUDA is not
# initialised until the model is loaded, so the setting still takes effect — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
|
|
|
|
|
|
|
|
# Marker placed on both sides of the input sentence (see user_message) and named
# in both instruction prompts so the model can tell where the sentence starts and ends.
delimiter = '####'

# Instruction prompt asking for a one-sentence argument that the sentence IS sexist.
# It is an f-string, so {delimiter} is substituted once, when the module is loaded.
# A backslash at the end of a line inside the literal joins the next line without
# inserting a newline.
# NOTE(review): because of that, "morality:" is followed directly by "1)Care" with no
# space in between — confirm this is intended.
# NOTE(review): the last instruction line ends with a stray `"` after the closing
# backticks (the "against" prompt has none). It looks unintentional, but it is part of
# the text sent to the model, so it is left untouched here.
system_message_for = f'''The Moral Foundations Theory introduces six foundations of morality:\
1)Care, 2)Equality, 3) Proportionality, 4)Loyalty, 5)Authority, 6)Purity. \
You are given a sentence delimited with {delimiter} characters.
Your task is to generate a one-sentence-long persuasive argument for \
why the given sentence is sexist based on moral foundation theory.
Important: Start the argument with ```This sentence is sexist because it violates the moral foundations of```"

'''

# Instruction prompt asking for a one-sentence argument that the sentence is NOT sexist.
# Same layout as the prompt above; only the task wording and the required opening
# phrase differ.
system_message_against = f'''The Moral Foundations Theory introduces six foundations of morality:\
1)Care, 2)Equality, 3) Proportionality, 4)Loyalty, 5)Authority, 6)Purity. \
You are given a sentence delimited with {delimiter} characters.
Your task is to generate a one-sentence-long persuasive argument for \
why the given sentence is not sexist based on moral foundation theory.
Important: Start the argument with ```This sentence is not sexist because it aligns with moral values of```

'''
|
|
|
|
|
def user_message(text):
    """Return ``text`` wrapped in the module-level ``delimiter`` plus a newline.

    Args:
        text: The sentence to present to the model; it is formatted into the
            string the same way an f-string placeholder would format it.

    Returns:
        A string of the form ``<delimiter><text><delimiter>`` followed by a
        single newline character.
    """
    wrapped = f"{delimiter}{text}{delimiter}"
    return wrapped + "\n"
|
|
|
|
|
|
|
|
def _generate_argument(model, system_message, text):
    """Query the model once and return the text after the last delimiter.

    Args:
        model: Object exposing ``generate(prompt, temperature=...,
            max_new_tokens=..., pad_token=...)`` that returns a string.
        system_message: Instruction prompt placed before the delimited sentence.
        text: The sentence the argument should be about.

    Returns:
        The whitespace-stripped part of the model output that follows the last
        occurrence of ``delimiter``.
    """
    raw = model.generate(
        system_message + user_message(text),
        temperature=0.7,
        max_new_tokens=200,
        pad_token=True,
    )
    # presumably the output echoes the prompt, so the answer is whatever follows
    # the final delimiter — confirm against utils.HFModel.generate
    return raw.split(delimiter)[-1].strip()


def _regenerate_until_prefix(model, system_message, text, response, prefix, max_retries=5):
    """Re-sample ``response`` until it starts with ``prefix`` or retries run out.

    Args:
        model: Model object passed through to ``_generate_argument``.
        system_message: Instruction prompt used for every retry.
        text: The sentence the argument should be about.
        response: The first, already generated response to validate.
        prefix: Opening phrase the response is required to start with.
        max_retries: Maximum number of additional generations.

    Returns:
        The first response that starts with ``prefix``; if none does, the result
        of the last retry, which is returned without a further check.
    """
    for _ in range(max_retries):
        if response.startswith(prefix):
            break
        print("re-generating...")
        response = _generate_argument(model, system_message, text)
    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="llama2")
    parser.add_argument('--cuda', type=str, default="-1")
    parser.add_argument('--idx', type=int, default=-1)
    args = parser.parse_args()
    print(args)

    # --idx is parsed as an int, so the "not given" sentinel must be compared as
    # the int -1. Comparing against the string "-1" was always False and left
    # start_index at -1 instead of 0.
    start_index = 0 if args.idx == -1 else args.idx
    end_index = utils.MAX_ID

    # An explicit --cuda value replaces the module-level default device list.
    if args.cuda != "-1":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

    folder_path = './generations/llama2/'
    # Creates ./generations and ./generations/llama2 in one call if missing.
    os.makedirs(folder_path, exist_ok=True)

    data = utils.read_implicit_edos()

    model = utils.HFModel(
        model_path='meta-llama/Llama-2-7b-chat-hf')

    for i in range(len(data)):
        if i < start_index or i >= end_index:
            continue
        output_path = folder_path + str(i) + '.json'
        # One JSON file per row: an existing file means the row was already
        # processed, which makes an interrupted run resumable.
        if os.path.exists(output_path):
            continue

        t = data['text'][i]
        response_for = _generate_argument(model, system_message_for, t)
        response_against = _generate_argument(model, system_message_against, t)
        print(i)

        # The prompts demand a fixed opening phrase; re-sample when it is missing.
        response_for = _regenerate_until_prefix(
            model, system_message_for, t, response_for,
            'This sentence is sexist')
        response_against = _regenerate_until_prefix(
            model, system_message_against, t, response_against,
            'This sentence is not sexist')

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump({'id': str(data['id'][i]), 'for': response_for, 'against': response_against}, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|