from dataclasses import dataclass
from typing import Dict, Any, Union

from constants import (
    FIM_MIDDLE,
    FIM_PREFIX,
    FIM_SUFFIX,
    MIN_TEMPERATURE,
)
from settings import (
    FIM_INDICATOR,
)


@dataclass
class StarCoderRequestConfig:
    """Generation settings for a StarCoder request.

    Numeric fields are coerced with ``float()`` / ``int()`` in
    ``__post_init__``, and the temperature is clamped so it is never below
    ``MIN_TEMPERATURE``. Sampling is always enabled with a fixed seed of 42.
    """

    temperature: float
    max_new_tokens: int
    top_p: float
    repetition_penalty: float
    version: str

    def __post_init__(self):
        """Coerce numeric fields and attach the fixed sampling options."""
        # MIN_TEMPERATURE is a floor, so the larger value must win. Using
        # ``min`` here would cap every request at MIN_TEMPERATURE and silently
        # discard the caller's temperature.
        self.temperature = max(float(self.temperature), MIN_TEMPERATURE)
        self.max_new_tokens = int(self.max_new_tokens)
        self.top_p = float(self.top_p)
        self.repetition_penalty = float(self.repetition_penalty)
        # These are plain instance attributes (not dataclass fields). They are
        # assigned here so they show up in ``vars(self)`` and are therefore
        # included in ``kwargs()``.
        self.do_sample = True
        self.seed = 42

    def __repr__(self) -> str:
        """Return a compact string representation of the configuration."""
        values = dict(
            model=self.version,
            temp=self.temperature,
            tokens=self.max_new_tokens,
            p=self.top_p,
            penalty=self.repetition_penalty,
            sample=self.do_sample,
            seed=self.seed,
        )
        return f"StarCoderRequestConfig({values})"

    def kwargs(self) -> Dict[str, Union[Any, float, int]]:
        """Return the configuration as keyword arguments, without ``version``.

        The result is a copy of the instance's attribute dict (including
        ``do_sample`` and ``seed``) with the ``version`` key removed.
        """
        values = vars(self).copy()
        values.pop("version")
        return values


@dataclass
class StarCoderRequest:
    """A prompt plus its generation settings.

    If the prompt contains ``FIM_INDICATOR``, the request is put in
    fill-in-the-middle mode. The text before and after the indicator becomes
    ``prefix`` and ``suffix``, and ``prompt`` is rewritten as
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``.

    Raises:
        ValueError: if the prompt contains more than one ``FIM_INDICATOR``.
    """

    prompt: str
    settings: StarCoderRequestConfig

    def __post_init__(self):
        """Detect FIM mode and rewrite the prompt into FIM format if needed."""
        self.fim_mode = FIM_INDICATOR in self.prompt
        self.prefix, self.suffix = None, None
        if self.fim_mode:
            try:
                # The indicator is known to be present, so unpacking can only
                # fail (with ValueError) when there are two or more of them.
                self.prefix, self.suffix = self.prompt.split(FIM_INDICATOR)
            except ValueError as err:
                raise ValueError(
                    f"Only one {FIM_INDICATOR} allowed in prompt!"
                ) from err
            self.prompt = (
                f"{FIM_PREFIX}{self.prefix}{FIM_SUFFIX}{self.suffix}{FIM_MIDDLE}"
            )

    def __repr__(self) -> str:
        """Return a compact string representation of the request."""
        values = dict(
            prompt=self.prompt,
            configuration=self.settings,
        )
        return f"StarCoderRequest({values})"