File size: 4,695 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Optional, Tuple
import json
from swift.llm import Messages
from swift.llm.template.utils import ContextType
from .utils import calculate_loss_scale
class LossScale:
loss_scale_config = None # path
def __init__(self):
if self.loss_scale_config is not None:
path = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(path, 'config', self.loss_scale_config)
with open(config_path, 'r', encoding='utf-8') as json_file:
self.loss_scale_map = json.load(json_file)
else:
self.loss_scale_map = None
def get_loss_scale(self,
context: str,
context_type: ContextType,
is_last_round: bool,
*,
query: Optional[str] = None) -> Tuple[List[str], List[float]]:
"""Calculate loss scale
Args:
context: The input context
context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.)
is_last_round: If this is the last round of messages.
query: The query of this round.
Returns:
A tuple, list of context and list of loss_scales
"""
if context_type in {ContextType.RESPONSE, ContextType.SUFFIX}:
loss_scale = 1.
else:
loss_scale = 0.
return [context], [loss_scale]
def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages,
**kwargs) -> Tuple[List[str], List[float]]:
res_context_list = []
res_loss_scale = []
i = 0
n_round = len(messages) // 2
for context, context_type in zip(context_list, context_types):
is_last_round = i + 1 == n_round
if context_type == ContextType.RESPONSE:
query = messages[2 * i]['content']
assert context == messages[2 * i + 1]['content']
kwargs = {'query': query}
i += 1
new_context, loss_scale = self.get_loss_scale(context, context_type, is_last_round, **kwargs)
res_context_list += new_context
res_loss_scale += loss_scale
return res_context_list, res_loss_scale
class LastRoundLossScale(LossScale):
def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
if context_type == ContextType.RESPONSE:
return [context], [float(is_last_round)]
return super().get_loss_scale(context, context_type, is_last_round)
class AgentFlanLossScale(LossScale):
loss_scale_config = 'agentflan.json'
def get_loss_scale(self,
context: str,
context_type: ContextType,
is_last_round: bool,
*,
query: Optional[str] = None):
if context_type == ContextType.RESPONSE:
return calculate_loss_scale(query, context, self.loss_scale_map['response'], self.loss_scale_map['query'])
return super().get_loss_scale(context, context_type, is_last_round)
class REACTLossScale(LossScale):
loss_scale_config = 'react.json'
def get_loss_scale(self,
context: str,
context_type: ContextType,
is_last_round: bool,
*,
query: Optional[str] = None):
if context_type == ContextType.RESPONSE:
return calculate_loss_scale(query, context, self.loss_scale_map)
return super().get_loss_scale(context, context_type, is_last_round)
class QwenLossScale(REACTLossScale):
loss_scale_config = 'qwen.json'
class HermesLossScale(REACTLossScale):
loss_scale_config = 'hermes.json'
class AlphaUmiLossScale(REACTLossScale):
loss_scale_config = 'alpha_umi.json'
class TrainAllLossScale(LossScale):
def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
return [context], [1.]
class IgnoreEmptyThink(REACTLossScale):
loss_scale_config = 'ignore_empty_think.json'
# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
'last_round': LastRoundLossScale(),
'default': LossScale(),
'all': TrainAllLossScale(),
'ignore_empty_think': IgnoreEmptyThink(),
# agent
'react': REACTLossScale(),
'hermes': HermesLossScale(),
'qwen': QwenLossScale(),
'agentflan': AgentFlanLossScale(),
'alpha_umi': AlphaUmiLossScale(),
}
|