---
license: apache-2.0
datasets:
  - emozilla/booksum-summary-analysis_gptneox-8192
  - kmfoda/booksum
---

# mpt-7b-storysummarizer

This is a fine-tuned version of `mosaicml/mpt-7b-storywriter` on `emozilla/booksum-summary-analysis_gptneox-8192`, which is adapted from `kmfoda/booksum`. The training run was performed using llm-foundry on an 8xA100 80 GB node at a context length of 8192. The run can be viewed on wandb.

## How to Use

This model is intended for summarization and literary analysis of fiction stories. It can be prompted in one of two ways:

```
SOME_FICTION

### SUMMARY:
```

or

```
SOME_FICTION

### ANALYSIS:
```
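
For example, a prompt can be assembled by appending the task tag to the story text (a minimal sketch; `story` is an illustrative variable holding the fiction):

```python
story = "It was a dark and stormy night..."  # the fiction to summarize

# Use "### ANALYSIS:" instead of "### SUMMARY:" for literary analysis
prompt = f"{story}\n\n### SUMMARY:\n"
```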

A `repetition_penalty` of ~1.04 seems to work best. For summary prompts, simple greedy search suffices, while a temperature of 0.8 works well for analysis. The model often prints `#` to delineate the end of a summary or analysis. You can use a `transformers.StoppingCriteria` subclass such as the `StopOnTokens` class below to end a generation.

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token matches a stop token
        for stop_id in self.stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stop_ids = tokenizer("#").input_ids
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
```

Pass `stopping_criteria` as an argument to the model's `generate` function to stop on `#`.
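
Putting the pieces together, a generation call might look like the following (a sketch, assuming `model` and `tokenizer` are loaded as in the next section and `prompt` is built as above; `max_new_tokens=512` is an arbitrary choice):

```python
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding for summaries; for analysis, add do_sample=True, temperature=0.8
output = model.generate(
    **inputs,
    max_new_tokens=512,
    repetition_penalty=1.04,
    stopping_criteria=stopping_criteria,
)

# Decode only the newly generated tokens, skipping the prompt
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```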

The code for this model includes adaptations from `Birchlabs/mosaicml-mpt-7b-chat-qlora` which allow MPT models to be loaded with `device_map="auto"` and `load_in_8bit=True`. For longer contexts, the following is recommended:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("emozilla/mpt-7b-storysummarizer")
model = AutoModelForCausalLM.from_pretrained(
  "emozilla/mpt-7b-storysummarizer",
  load_in_8bit=True,
  trust_remote_code=True,
  device_map="auto")
```
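
Note that `load_in_8bit=True` requires the bitsandbytes package, and `device_map="auto"` relies on accelerate to place layers across the available devices; together these reduce the memory needed to run the model at long context lengths.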