Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
raw
history blame
4.7 kB
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Optional, Tuple
import json
from swift.llm import Messages
from swift.llm.template.utils import ContextType
from .utils import calculate_loss_scale
class LossScale:
    """Default strategy for assigning a loss weight to each context piece.

    A subclass may set ``loss_scale_config`` to the name of a JSON file located
    in the ``config`` directory next to this module. The file is parsed once in
    ``__init__`` and stored on the instance as ``self.loss_scale_map``; when no
    file is configured the attribute is ``None``.
    """
    loss_scale_config = None  # path

    def __init__(self):
        config_name = self.loss_scale_config
        if config_name is None:
            self.loss_scale_map = None
            return
        # Config files are resolved relative to this module, not the CWD.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config', config_name), 'r', encoding='utf-8') as fp:
            self.loss_scale_map = json.load(fp)

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None) -> Tuple[List[str], List[float]]:
        """Calculate the loss scale of a single context.

        Args:
            context: The input context.
            context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.)
            is_last_round: Whether this is the last round of messages (unused by this base implementation).
            query: The query of this round (unused by this base implementation).

        Returns:
            A tuple of two one-element lists: the context itself and its loss scale,
            which is 1.0 for response and suffix contexts and 0.0 for everything else.
        """
        contributes_to_loss = context_type in {ContextType.RESPONSE, ContextType.SUFFIX}
        return [context], [1. if contributes_to_loss else 0.]

    def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages,
                 **kwargs) -> Tuple[List[str], List[float]]:
        """Compute loss scales for a whole sequence of contexts.

        Args:
            context_list: The contexts, in order.
            context_types: The type of each context, parallel to ``context_list``.
            messages: The conversation; entries ``2 * i`` and ``2 * i + 1`` are read as the
                query and response of round ``i`` via their ``'content'`` key.
            **kwargs: Forwarded to ``get_loss_scale`` until the first response context is
                seen; from then on they are replaced by ``{'query': <query of that round>}``.

        Returns:
            A tuple of the concatenated contexts and the concatenated loss scales
            returned by ``get_loss_scale`` for each input context.
        """
        out_contexts: List[str] = []
        out_scales: List[float] = []
        total_rounds = len(messages) // 2
        round_idx = 0
        for context, context_type in zip(context_list, context_types):
            # Evaluated before the round counter advances, so the response of the
            # final round is still flagged as belonging to the last round.
            is_last_round = round_idx + 1 == total_rounds
            if context_type == ContextType.RESPONSE:
                offset = 2 * round_idx
                query = messages[offset]['content']
                # Sanity check: the response context must match the recorded reply.
                assert context == messages[offset + 1]['content']
                kwargs = {'query': query}
                round_idx += 1
            pieces, scales = self.get_loss_scale(context, context_type, is_last_round, **kwargs)
            out_contexts.extend(pieces)
            out_scales.extend(scales)
        return out_contexts, out_scales
class LastRoundLossScale(LossScale):
    """Loss scale that weights a response only when it belongs to the last round.

    Response contexts get ``1.0`` in the last round and ``0.0`` otherwise; every
    other context type is delegated to the parent implementation.
    """

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        is_response = context_type == ContextType.RESPONSE
        if not is_response:
            return super().get_loss_scale(context, context_type, is_last_round)
        # float(True) == 1.0, float(False) == 0.0
        return [context], [float(is_last_round)]
class AgentFlanLossScale(LossScale):
    """Loss scale driven by the ``'response'`` and ``'query'`` sections of ``agentflan.json``."""
    loss_scale_config = 'agentflan.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        """Return the contexts and loss scales for one context.

        Response contexts are handed to ``calculate_loss_scale`` together with the
        round's query and both config sections; all other context types fall back
        to the parent implementation.
        """
        is_response = context_type == ContextType.RESPONSE
        if not is_response:
            return super().get_loss_scale(context, context_type, is_last_round)
        response_rules = self.loss_scale_map['response']
        query_rules = self.loss_scale_map['query']
        return calculate_loss_scale(query, context, response_rules, query_rules)
class REACTLossScale(LossScale):
    """Loss scale driven by the mapping loaded from ``react.json``."""
    loss_scale_config = 'react.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        """Return the contexts and loss scales for one context.

        Response contexts are handed to ``calculate_loss_scale`` together with the
        round's query and the loaded config mapping; all other context types fall
        back to the parent implementation.
        """
        is_response = context_type == ContextType.RESPONSE
        if not is_response:
            return super().get_loss_scale(context, context_type, is_last_round)
        return calculate_loss_scale(query, context, self.loss_scale_map)
class QwenLossScale(REACTLossScale):
    """Same behaviour as ``REACTLossScale``, configured by ``qwen.json`` instead."""
    loss_scale_config = 'qwen.json'
class HermesLossScale(REACTLossScale):
    """Same behaviour as ``REACTLossScale``, configured by ``hermes.json`` instead."""
    loss_scale_config = 'hermes.json'
class AlphaUmiLossScale(REACTLossScale):
    """Same behaviour as ``REACTLossScale``, configured by ``alpha_umi.json`` instead."""
    loss_scale_config = 'alpha_umi.json'
class TrainAllLossScale(LossScale):
    """Loss scale that gives every context a weight of ``1.0``, whatever its type."""

    def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
        full_weight = 1.
        return [context], [full_weight]
class IgnoreEmptyThink(REACTLossScale):
    """Same behaviour as ``REACTLossScale``, configured by ``ignore_empty_think.json`` instead."""
    loss_scale_config = 'ignore_empty_think.json'
# Registry of ready-to-use loss-scale instances, keyed by strategy name.
# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
    name: strategy_cls()
    for name, strategy_cls in (
        ('last_round', LastRoundLossScale),
        ('default', LossScale),
        ('all', TrainAllLossScale),
        ('ignore_empty_think', IgnoreEmptyThink),
        # agent
        ('react', REACTLossScale),
        ('hermes', HermesLossScale),
        ('qwen', QwenLossScale),
        ('agentflan', AgentFlanLossScale),
        ('alpha_umi', AlphaUmiLossScale),
    )
}