############################
# LLaMA 2 7b chat
############################
import os
import utils
import torch
import json
import argparse
from tqdm import tqdm

# Default GPUs; overridden below when --cuda is passed.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"

delimiter = '####'

system_message_for = f'''The Moral Foundations Theory introduces six foundations of morality: \
1) Care, 2) Equality, 3) Proportionality, 4) Loyalty, 5) Authority, 6) Purity. \
You are given a sentence delimited with {delimiter} characters. Your task is to generate a one-sentence-long persuasive argument for \
why the given sentence is sexist based on Moral Foundations Theory.
Important: Start the argument with ```This sentence is sexist because it violates the moral foundations of```
'''

system_message_against = f'''The Moral Foundations Theory introduces six foundations of morality: \
1) Care, 2) Equality, 3) Proportionality, 4) Loyalty, 5) Authority, 6) Purity. \
You are given a sentence delimited with {delimiter} characters. Your task is to generate a one-sentence-long persuasive argument for \
why the given sentence is not sexist based on Moral Foundations Theory.
Important: Start the argument with ```This sentence is not sexist because it aligns with moral values of```
'''


def user_message(text):
    return f'''{delimiter}{text}{delimiter}
'''


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="llama2")
    parser.add_argument('--cuda', type=str, default="-1")
    parser.add_argument('--idx', type=int, default=-1)
    args = parser.parse_args()
    print(args)

    # --idx is parsed as an int, so compare against the int default.
    if args.idx == -1:
        start_index = 0
    else:
        start_index = args.idx
    end_index = utils.MAX_ID

    if args.cuda != "-1":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

    folder_path = './generations/llama2/'
    os.makedirs(folder_path, exist_ok=True)

    data = utils.read_implicit_edos()
    model = utils.HFModel(model_path='meta-llama/Llama-2-7b-chat-hf')

    for i in tqdm(range(len(data))):
        # Skip examples that were already generated or fall outside the requested range.
        if os.path.exists(folder_path + str(i) + '.json'):
            continue
        if i < start_index or i >= end_index:
            continue

        t = data['text'][i]
        response_for = model.generate(system_message_for + user_message(t),
                                      temperature=0.7, max_new_tokens=200, pad_token=True)
        response_against = model.generate(system_message_against + user_message(t),
                                          temperature=0.7, max_new_tokens=200, pad_token=True)

        # Keep only the model's continuation after the last delimiter.
        response_for = response_for.split(delimiter)[-1].strip()
        response_against = response_against.split(delimiter)[-1].strip()
        print(i)

        # Re-sample up to 5 times if the output does not start with the required prefix.
        for _ in range(5):
            if response_for.startswith('This sentence is sexist'):
                break
            print("re-generating...")
            response_for = model.generate(system_message_for + user_message(t),
                                          temperature=0.7, max_new_tokens=200, pad_token=True)
            response_for = response_for.split(delimiter)[-1].strip()

        for _ in range(5):
            if response_against.startswith('This sentence is not sexist'):
                break
            print("re-generating...")
            response_against = model.generate(system_message_against + user_message(t),
                                              temperature=0.7, max_new_tokens=200, pad_token=True)
            response_against = response_against.split(delimiter)[-1].strip()

        # Write one JSON file per example so an interrupted run can be resumed.
        with open(folder_path + str(i) + '.json', 'w') as f:
            json.dump({'id': str(data['id'][i]),
                       'for': response_for,
                       'against': response_against}, f)
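
# ---------------------------------------------------------------------------
# The `utils` module used above (read_implicit_edos, MAX_ID, HFModel) is not
# part of this file. The commented sketch below only illustrates the interface
# the script assumes -- a plain Hugging Face wrapper whose generate() takes a
# prompt string plus temperature/max_new_tokens/pad_token. The class body, the
# pad_token handling, and the absence of the Llama-2 [INST] chat template are
# assumptions, not the actual implementation.
# ---------------------------------------------------------------------------
# from transformers import AutoModelForCausalLM, AutoTokenizer
#
# class HFModel:
#     def __init__(self, model_path):
#         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
#         self.model = AutoModelForCausalLM.from_pretrained(
#             model_path, torch_dtype=torch.float16, device_map="auto")
#
#     def generate(self, prompt, temperature=0.7, max_new_tokens=200, pad_token=False):
#         # Llama-2 has no pad token by default; reuse EOS when padding is requested.
#         if pad_token and self.tokenizer.pad_token is None:
#             self.tokenizer.pad_token = self.tokenizer.eos_token
#         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
#         output = self.model.generate(**inputs, do_sample=True,
#                                      temperature=temperature,
#                                      max_new_tokens=max_new_tokens,
#                                      pad_token_id=self.tokenizer.eos_token_id)
#         return self.tokenizer.decode(output[0], skip_special_tokens=True)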