---
license: apache-2.0
datasets:
- emozilla/booksum-summary-analysis_gptneox-8192
- kmfoda/booksum
---

# mpt-7b-storysummarizer

This is a fine-tuned version of [mosaicml/mpt-7b-storywriter](https://huggingface.co/mosaicml/mpt-7b-storywriter) on [emozilla/booksum-summary-analysis_gptneox-8192](https://huggingface.co/datasets/emozilla/booksum-summary-analysis_gptneox-8192), which is adapted from [kmfoda/booksum](https://huggingface.co/datasets/kmfoda/booksum).
The training run was performed using [llm-foundry](https://github.com/mosaicml/llm-foundry) on an 8xA100 80 GB node at a context length of 8192. The run can be viewed on [wandb](https://wandb.ai/emozilla/booksum/runs/457ym4r9).

## How to Use

This model is intended for summarization and literary analysis of fiction stories. It can be prompted in one of two ways:

```
SOME_FICTION

### SUMMARY:
```

or

```
SOME_FICTION

### ANALYSIS:
```

A `repetition_penalty` of ~1.04 seems to work best. For summary prompts, simple greedy search suffices, while a temperature of 0.8 works well for analysis.

The model often prints `#` to delineate the end of a summary or analysis. You can use a custom `StoppingCriteria` subclass, such as the `StopOnTokens` class below, to end generation at that point.

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is one of the stop tokens
        for stop_id in self.stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stop_ids = tokenizer("#").input_ids
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
```

Pass `stopping_criteria` as an argument to the model's `generate` function to stop on `#`.

The code for this model includes adaptations from [Birchlabs/mosaicml-mpt-7b-chat-qlora](https://huggingface.co/Birchlabs/mosaicml-mpt-7b-chat-qlora) which allow MPT models to be loaded with `device_map="auto"` and `load_in_8bit=True`. For longer contexts, the following is recommended:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("emozilla/mpt-7b-storysummarizer")
model = AutoModelForCausalLM.from_pretrained(
    "emozilla/mpt-7b-storysummarizer",
    load_in_8bit=True,
    trust_remote_code=True,
    device_map="auto")
```
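Putting the pieces together, here is a minimal end-to-end sketch of a summarization call. It assumes `model`, `tokenizer`, and `stopping_criteria` are defined as shown above; the `story` variable and the generation settings other than those mentioned in this card (such as `max_new_tokens`) are placeholders to adjust for your own use.

```python
# Minimal usage sketch: `story` is a placeholder for your own text, and the exact
# whitespace around the "### SUMMARY:" separator is an assumption based on the
# prompt format shown above.
story = "SOME_FICTION"

prompt = f"{story}\n\n### SUMMARY:\n"  # use "### ANALYSIS:" for literary analysis
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

output = model.generate(
    **inputs,
    max_new_tokens=512,                 # placeholder; set to taste
    repetition_penalty=1.04,
    do_sample=False,                    # greedy search for summaries
    # do_sample=True, temperature=0.8,  # sampling works well for analysis
    stopping_criteria=stopping_criteria,
    pad_token_id=tokenizer.eos_token_id,
)

# Print only the newly generated tokens
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```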