import os
from copy import deepcopy
from typing import Any, Dict, List, Optional

import hydra

from flows.base_flows import AtomicFlow
from flows.prompt_template import JinjaPrompt
from flows.utils import general_helpers, logging

log = logging.get_logger(__name__)

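class DemonstrationsAtomicFlow(AtomicFlow):
    """Atomic flow that returns demonstrations as formatted query/response pairs.

    Demonstrations are either passed in directly via ``data`` or loaded from
    ``<params["data_dir"]>/<params["demonstrations_id"]>.jsonl``.
    """

    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict
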
    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        # Number of demonstrations to return; None means "use all of them".
        self.demonstrations_k = self.params.get("demonstrations_k", None)

        # Typically, the query is what the user (human) asks the assistant (LLM).
        self.query_prompt_template = query_prompt_template
        # Typically, the response is what the assistant (LLM) should answer to the user (human).
        self.response_prompt_template = response_prompt_template

        # Fall back to loading the demonstrations from disk when none were passed in.
        if self.data is None:
            self._load_data()
            
    @classmethod
    def _set_up_prompts(cls, config):
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs
        
    @classmethod
    def instantiate_from_config(cls, config):
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs["params"] = flow_config["params"]
        # "data" is optional: when it is absent, __init__ loads the demonstrations from disk.
        kwargs["data"] = flow_config.get("data")
        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)
            
    def _get_query_message_content(self, sample_data: Dict):
        input_variables = self.query_prompt_template.input_variables
        return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_response_message_content(self, sample_data: Dict):
        input_variables = self.response_prompt_template.input_variables
        return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})
        
    def _get_io_pair(self, idx):
        dp = self.data[idx]
        
        query_data = dp["query_data"]
        response_data = dp["response_data"]
        
        query = self._get_query_message_content(query_data)
        response = self._get_response_message_content(response_data)
        
        return {"idx": idx, "query": query,"response": response}
            
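    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        # Use every demonstration when demonstrations_k is unset, and never index past the data.
        num_available = len(self.data)
        if self.demonstrations_k is None:
            demonstrations_k = num_available
        else:
            demonstrations_k = min(self.demonstrations_k, num_available)
        return [self._get_io_pair(idx) for idx in range(demonstrations_k)]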
        
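    def _load_data(self):
        """Read the demonstrations from a JSON Lines file, optionally keeping only the ids in params["ids_to_keep"]."""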
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)
        
        if self.params.get("ids_to_keep", False):
            if isinstance(self.params["ids_to_keep"], str):
                ids_to_keep = set(self.params["ids_to_keep"].split(","))
            else:
                ids_to_keep = set(self.params["ids_to_keep"])

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        # Pass the input through unchanged and attach the formatted demonstrations.
        return {**input_data, "demonstrations": self._get_io_pairs(input_data=input_data)}