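# citekit example: a plan-then-write citation pipeline. A planner LLM decomposes the
# question into subquestions, an answer LLM answers them with citations from the
# retrieved documents, and a reviser LLM polishes the combined answer before evaluation.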
from citekit.cite_modules.LLM import LLM
from citekit.cite_modules.augment_model import CitationSimplyfier
from citekit.pipeline.pipeline import Pipeline, PIPELINE_OUTPUT,PIPELINE_DOC_CACHE
from citekit.prompt.prompt import Prompt, ALCEDocPrompt
from citekit.Dataset.Dataset import PromptDataset
from citekit.evaluator.evaluator import DefaultEvaluator
from parser import *
import argparse
import json
from citekit.utils.utils import cut_and_make_as,one_paragraph,make_as
from nltk import sent_tokenize


def sentences_as(datakey):
    """Return a function that splits a passage into sentences and wraps each as {datakey: sentence}."""
    def f(passage):
        return [{datakey: one_paragraph(s)} for s in sent_tokenize(passage)]
    return f

PARA_SEP = '\n\n'
if __name__ == '__main__':

    # SETTING ARGS
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", type=str, default='res.json', help="path to save the results")
    parser.add_argument("--model", type=str, default='gpt-3.5-turbo', help="model name or path")
    parser.add_argument("--shots", type=int, default=1, help="number of shots")
    parser.add_argument("--ndoc", type=int, default=3, help="number of docs")
    parser.add_argument("--pr", action='store_true', help="use cite PR")
    parser.add_argument("--rouge", action='store_true', help="use rouge")
    parser.add_argument("--temp", type=float, default=0.5, help="temperature")
    parser.add_argument("--qa", action='store_true', help="eval qa")
    parser.add_argument("--str_em", action='store_true', help="eval str_em")
    parser.add_argument("--mauve",  action='store_true', help="eval mauve")
    parser.add_argument("--length", action='store_true', default=True, help="eval length")
    parser.add_argument("--claims", action='store_true', help="eval claims")
    parser.add_argument("--qampari", action='store_true', help="eval qampari")
    parser.add_argument("--turns", type=int, default=1, help="max number of revision turns")
    parser.add_argument("--use_fast_pr", action='store_true', help="use fast citation precision/recall")
    parser.add_argument("--dataset", type=str, default='data/asqa_eval_gtr_top100.json', help="dataset")
    parser.add_argument("--demo", type=str, default='prompts/AnG.json', help="demo")
    parser.add_argument("--mode", type=str, default='plan', help="mode: AnG or plan")
    parser.add_argument("--data_num", type=int, default=200, help="number of examples to run")
    args = parser.parse_args()

    # DATA LOADING
    file_path = args.dataset
    demo_path = args.demo
    with open(file_path,'r',encoding='utf-8') as file:
        dataset = json.load(file)
    with open(demo_path,'r',encoding='utf-8') as file:
        demo = json.load(file)[args.mode]

    answer_inst = 'Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly, by answering the subquestion. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
    revise_instruction = '''You will be provided with a question, some documents, and an answer. Your task is to correct possible errors in the answer and make it more coherent and fluent. Only make changes that are necessary. Keep the citations.'''
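    # Wrap the raw data as a PromptDataset, keeping the top --ndoc documents (without titles) per example and the first --data_num examples.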
    dataset = PromptDataset(dataset,'question','answer','answers','qa_pairs','claims', docs = lambda data: ALCEDocPrompt().default_load_data_wo_title(data['docs'][:args.ndoc]))[:args.data_num]
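    # Build mode-specific prompt components (instruction, few-shot demos, and answer prefix) from the demo file.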
    if args.mode == 'AnG':
        gen_shot = demo['gen_instruction'] + PARA_SEP + demo['gen_shot'] + PARA_SEP
        answer_ppt = {'INST':demo['gen_instruction'],'shot':gen_shot, 'add':'The next sentence is:'}
    elif args.mode == 'plan':
        shot = demo['shot1'] + demo['shot2']
        self_ppt = {'INST':demo['INST'],'shot':shot, 'add':'subquestions: \n'}
        answer_shot = demo['answer_shot_1'] + demo['answer_shot_2']
        answer_ppt = {'INST':answer_inst,'shot':answer_shot,'add':''}
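    # Prompt templates: 'prompt' drives the answering and revising steps; 'plan_prompt' drives subquestion planning.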

    prompt = Prompt(template='<shot><INST><question><docs><prefix><sub><span><add><new_ans>',
                    components={'INST':'{INST}\n\n', 
                                'shot':'{shot}',
                                'question':'Question:{question}\n\n',
                                'docs':'{docs}\n',
                                'span':'The highlighted spans are: \n{span}\n\n',
                                'prefix':'Prefix: {prefix}\n\n',
                                'sub':'subquestions: \n{sub}\n\n',
                                'add':'Answer: \n{add}',
                                'new_ans': "\nNew Answer: {new_ans}"
                                })
    
    plan_prompt = Prompt(template='<shot><INST><question><docs><sub><add>',
                    components={'INST':'{INST}\n\n', 
                                'shot':'{shot}',
                                'question':'Question:{question}\n\n',
                                'docs':'{docs}\n',
                                'sub':'subquestions: \n{sub}\n\n',
                                'add':'{add}'})

    # PIPELINE
    evaluator = DefaultEvaluator(args)
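    # Modules: 'attribute' plans subquestions, 'answer' answers them (sharing the planner's underlying model), and 'reviser' revises the combined answer for up to --turns turns.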

    attribute = LLM(model = args.model, prompt_maker = plan_prompt, self_prompt=self_ppt, post_processing=sentences_as('sub'))
    reviser = LLM(model=args.model, prompt_maker=prompt, self_prompt={'INST':revise_instruction, 'new_ans': ''}, max_turn=args.turns)
    answer = LLM(model = args.model, prompt_maker = prompt, self_prompt=answer_ppt, share_model_with=attribute.get_first_module(), merge=True, parallel=True)

    # Wire the modules: generated answers go to the reviser; planned subquestions go to the answer module.
    # (A CitationSimplyfier could be attached here to simply combine uncited sentences, but it is not used.)
    answer.set_target(reviser, post_processing=lambda output: {'add': output})
    attribute.set_target(answer,post_processing=sentences_as('sub'))
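    # Assemble the pipeline and start each example at the planning module.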
    pipeline = Pipeline(save_path=args.save_path, llm = answer, module = [attribute, reviser], evaluator = evaluator, dataset = dataset)
    pipeline.set_initial_module(module=attribute)
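    # Join the generated answer pieces into a single paragraph as the final output.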
    answer.set_output(post_processing=lambda ls: ''.join(map(one_paragraph,ls)))




    # Run the pipeline on the dataset, feeding each example's question and retrieved docs.
    pipeline.run_on_dataset(datakeys=['question','docs'], init_docs='docs')