|
|
|
import os |
|
from typing import List, Optional, Tuple |
|
|
|
import json |
|
|
|
from swift.llm import Messages |
|
from swift.llm.template.utils import ContextType |
|
from .utils import calculate_loss_scale |
|
|
|
|
|
class LossScale:
    """Base strategy that assigns a loss weight to each context segment of an encoded conversation.

    Subclasses customise :meth:`get_loss_scale`. A subclass may set ``loss_scale_config`` to the
    name of a JSON file in the ``config`` directory next to this module; the parsed JSON is stored
    in ``self.loss_scale_map`` when the instance is constructed (``None`` otherwise).
    """

    # File name under ``<this module's directory>/config``; ``None`` means no config is loaded.
    loss_scale_config = None

    def __init__(self):
        if self.loss_scale_config is not None:
            path = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(path, 'config', self.loss_scale_config)
            with open(config_path, 'r', encoding='utf-8') as json_file:
                self.loss_scale_map = json.load(json_file)
        else:
            self.loss_scale_map = None

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None) -> Tuple[List[str], List[float]]:
        """Calculate loss scale

        Args:
            context: The input context
            context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.)
            is_last_round: If this is the last round of messages.
            query: The query of this round.

        Returns:
            A tuple, list of context and list of loss_scales. The base implementation returns the
            context unchanged with weight 1.0 for response/suffix segments and 0.0 for all others.
        """
        if context_type in {ContextType.RESPONSE, ContextType.SUFFIX}:
            loss_scale = 1.
        else:
            loss_scale = 0.
        return [context], [loss_scale]

    def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages,
                 **kwargs) -> Tuple[List[str], List[float]]:
        """Apply :meth:`get_loss_scale` to every context segment and concatenate the results.

        Args:
            context_list: The context segments, in order.
            context_types: The ContextType of each segment (zipped with ``context_list``).
            messages: The conversation messages. Response segments are matched against
                ``messages[2 * i + 1]`` and their query is read from ``messages[2 * i]``.
                NOTE(review): this assumes strictly alternating query/response entries with no
                leading system message — confirm against the caller.
            **kwargs: Extra keyword arguments forwarded to ``get_loss_scale`` for every segment.

        Returns:
            A tuple of (flattened list of context pieces, matching list of loss scales).
        """
        res_context_list = []
        res_loss_scale = []
        # Keyword arguments for get_loss_scale. Kept separate from the caller's ``kwargs`` so that
        # the caller's values are not discarded once the first response has been seen.
        call_kwargs = dict(kwargs)
        i = 0  # index of the next (query, response) round
        n_round = len(messages) // 2
        for context, context_type in zip(context_list, context_types):
            is_last_round = i + 1 == n_round
            if context_type == ContextType.RESPONSE:
                query = messages[2 * i]['content']
                assert context == messages[2 * i + 1]['content'], (
                    f'Response segment does not match messages[{2 * i + 1}]["content"].')
                # The latest query stays in call_kwargs for the following segments until the next
                # response updates it.
                call_kwargs['query'] = query
                i += 1
            new_context, loss_scale = self.get_loss_scale(context, context_type, is_last_round, **call_kwargs)
            res_context_list += new_context
            res_loss_scale += loss_scale
        return res_context_list, res_loss_scale
|
|
|
|
|
class LastRoundLossScale(LossScale):
    """Only the response of the final round is trained on; earlier responses get weight 0.0.

    Segments that are not responses are weighted by the base-class rules.
    """

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        # Non-response segments fall through to the base rules; extra keyword arguments
        # (e.g. ``query``) are intentionally not forwarded.
        if context_type != ContextType.RESPONSE:
            return super().get_loss_scale(context, context_type, is_last_round)
        weight = float(is_last_round)  # 1.0 on the last round, 0.0 before it
        return [context], [weight]
|
|
|
|
|
class AgentFlanLossScale(LossScale):
    """Response weighting driven by rules loaded from ``config/agentflan.json``.

    Response segments are passed to ``calculate_loss_scale`` along with the ``'response'`` and
    ``'query'`` entries of the loaded config. All other segments use the base-class rules.
    """

    loss_scale_config = 'agentflan.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type != ContextType.RESPONSE:
            return super().get_loss_scale(context, context_type, is_last_round)
        rules = self.loss_scale_map
        response_rules = rules['response']
        query_rules = rules['query']
        return calculate_loss_scale(query, context, response_rules, query_rules)
|
|
|
|
|
class REACTLossScale(LossScale):
    """Response weighting driven by rules loaded from ``config/react.json``.

    Subclasses reuse this logic with a different ``loss_scale_config`` file. Response segments are
    passed, together with the whole loaded config, to ``calculate_loss_scale``. All other segments
    use the base-class rules.
    """

    loss_scale_config = 'react.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type != ContextType.RESPONSE:
            return super().get_loss_scale(context, context_type, is_last_round)
        rules = self.loss_scale_map
        return calculate_loss_scale(query, context, rules)
|
|
|
|
|
class QwenLossScale(REACTLossScale):
    """REACT-style response weighting using the rules in ``config/qwen.json``."""

    loss_scale_config = 'qwen.json'
|
|
|
|
|
class HermesLossScale(REACTLossScale):
    """REACT-style response weighting using the rules in ``config/hermes.json``."""

    loss_scale_config = 'hermes.json'
|
|
|
|
|
class AlphaUmiLossScale(REACTLossScale):
    """REACT-style response weighting using the rules in ``config/alpha_umi.json``."""

    loss_scale_config = 'alpha_umi.json'
|
|
|
|
|
class TrainAllLossScale(LossScale):
    """Give every segment a loss scale of 1.0, whatever its context type."""

    def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
        # The context type and any extra arguments are deliberately ignored.
        weights = [1.]
        return [context], weights
|
|
|
|
|
class IgnoreEmptyThink(REACTLossScale):
    """REACT-style response weighting using the rules in ``config/ignore_empty_think.json``.

    NOTE(review): judging by the name, the config presumably zero-weights empty "think" blocks;
    confirm against the JSON file.
    """

    loss_scale_config = 'ignore_empty_think.json'
|
|
|
|
|
|
|
# Registry of loss-scale strategies, keyed by the name used to select them.
# Every strategy is instantiated here at import time, so strategies that set
# ``loss_scale_config`` read their JSON file when this module is first imported.
loss_scale_map = {
    # Strategies whose weighting is decided in code (plus ignore_empty_think, which is config-based).
    'last_round': LastRoundLossScale(),
    'default': LossScale(),
    'all': TrainAllLossScale(),
    'ignore_empty_think': IgnoreEmptyThink(),
    # Config-driven response weighting (REACTLossScale and its subclasses, plus AgentFlan).
    'react': REACTLossScale(),
    'hermes': HermesLossScale(),
    'qwen': QwenLossScale(),
    'agentflan': AgentFlanLossScale(),
    'alpha_umi': AlphaUmiLossScale(),
}
|
|