File size: 3,610 Bytes
6057964
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
############################
# Vicuna 13b v1.5
############################

import os
import utils
import torch
import json
import argparse
import pandas as pd
from tqdm import tqdm

# Pin this process to GPUs 1 and 2. This runs at import time, before any model
# is created; the __main__ block may replace it based on --cuda.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"


# Marker placed on both sides of the input sentence (see user_message). The
# prompts below tell the model that the sentence is delimited by it.
delimiter = '####'
# Prompt asking the model for a one-sentence argument that the sentence IS
# sexist under Moral Foundations Theory. Each trailing backslash is a line
# continuation inside the triple-quoted f-string, so everything up to the
# closing ``` is one line (e.g. "morality:" runs straight into "1)Care"),
# followed by two newlines before the closing quotes.
# NOTE(review): the text ends with a stray `"` after the closing ```, which is
# sent to the model verbatim. Removing it changes the prompt relative to any
# outputs already generated, so do it deliberately.
system_message_for = f'''The Moral Foundations Theory introduces six foundations of morality:\
1)Care, 2)Equality, 3) Proportionality, 4)Loyalty, 5)Authority, 6)Purity. \
You are given a sentence delimited with {delimiter} characters. \
Your task is to generate a one-sentence-long persuasive argument for why the given sentence is sexist based on moral foundation theory. \
Make sure to justify your argument. \
Start the argument with ```This sentence is sexist because it violates the moral foundations of```"

'''
# Counterpart of system_message_for asking for an argument that the sentence
# is NOT sexist; built the same way (one long line plus two trailing newlines).
# NOTE(review): "Make sure the justify" looks like a typo for "Make sure to
# justify", and this text also ends with a stray `"`; both are sent verbatim.
# Same caveat as above about changing prompts mid-experiment.
system_message_against = f'''The Moral Foundations Theory introduces six foundations of morality:\
1)Care, 2)Equality, 3) Proportionality, 4)Loyalty, 5)Authority, 6)Purity. \
You are given a sentence delimited with {delimiter} characters. \
Your task is to generate a one-sentence-long persuasive argument for why the given sentence is not sexist based on moral foundation theory. \
Make sure the justify your argument. \
Start the argument with ```This sentence is not sexist because it aligns with moral values of```"

'''

def user_message(text):
    """Return the user turn: ``text`` enclosed on both sides by ``delimiter``.

    ``text`` is rendered with ``format(text, '')``, exactly as an f-string
    replacement field would render it, so non-str values are accepted.
    """
    return '{0}{1}{0}'.format(delimiter, text)



def index_range(idx, max_id, chunk_size=750):
    """Return the half-open ``(start, end)`` range of row indices to process.

    Args:
        idx: First row of the chunk to process, or a negative value (the
            ``--idx`` default of -1) meaning "process everything".
        max_id: Exclusive upper bound used when processing everything.
        chunk_size: Number of rows in one chunk when ``idx`` is given.

    Returns:
        ``(0, max_id)`` when ``idx`` is negative, otherwise
        ``(idx, idx + chunk_size)``.
    """
    # ``--idx`` is parsed with type=int, so the sentinel must be compared
    # numerically; comparing against the string "-1" would never match.
    if idx < 0:
        return 0, max_id
    return idx, idx + chunk_size


def main():
    """Generate "sexist" / "not sexist" arguments per sentence and cache them.

    For every row index ``i`` in the selected range, writes
    ``./generations/vicuna/<i>.json`` holding the row id and both generated
    arguments. Rows whose output file already exists are skipped, so an
    interrupted run can simply be restarted.
    """
    parser = argparse.ArgumentParser(description="vicuna")
    parser.add_argument('--cuda', type=str, default="-1")
    parser.add_argument('--idx', type=int, default=-1)
    parser.add_argument('--chunk_size', type=int, default=750)
    args = parser.parse_args()
    print(args)

    start_index, end_index = index_range(args.idx, utils.MAX_ID, args.chunk_size)

    # "-1" is the "not given" sentinel: only replace the import-time GPU
    # selection when a device list was passed explicitly (otherwise every GPU
    # would be hidden). Must happen before the model is loaded.
    if args.cuda != "-1":
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

    folder_path = './generations/vicuna/'
    os.makedirs(folder_path, exist_ok=True)

    data = utils.read_implicit_edos()
    model = utils.HFModel(model_path='lmsys/vicuna-13b-v1.5')

    # Only visit rows inside the requested range (and inside the dataset).
    for i in range(start_index, min(end_index, len(data))):
        out_path = os.path.join(folder_path, str(i) + '.json')
        if os.path.exists(out_path):
            continue  # produced by an earlier run
        text = data['text'][i]
        response_for = model.generate(system_message_for + user_message(text), max_new_tokens=400)
        response_against = model.generate(system_message_against + user_message(text), max_new_tokens=400)
        # Keep only what follows the last delimiter.
        # NOTE(review): this assumes model.generate returns the prompt echoed
        # ahead of the completion, so the last '####' closes the input
        # sentence — confirm against utils.HFModel.
        response_for = response_for.split(delimiter)[-1].strip()
        response_against = response_against.split(delimiter)[-1].strip()
        print(i)
        # Write to a temporary file and rename it into place, so a crash
        # mid-write cannot leave a truncated JSON that a resumed run would
        # then skip as already done.
        tmp_path = out_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump({'id': str(data['id'][i]), 'for': response_for, 'against': response_against}, f)
        os.replace(tmp_path, out_path)


if __name__ == "__main__":
    main()