[Experiment] MPT 7B + LangChain Custom LLM + transformers.accelerator, on a POTATO


So, there's good news and bad news after this experiment; full disclosure incoming:

  • Yes, you can run this on a personal computer.
  • No, it isn't fast at all. 200 tokens took 3m 16.9s per prompt on my machine, and that's only because I've got 32GB of RAM thanks to the studio work that I do. Okay, maaaybe I cheated, this machine's built to run Unity, buuut...

My graphics card is barely sufficient (GTX 1080 Ti with 8GB VRAM). But thanks to the llm-foundry GitHub repository, I was able to discern which blocks must not be split across devices by the accelerator, expand all of that out manually, and tie the weights so it works on a home PC with a crappy graphics card via: https://huggingface.co/docs/accelerate/usage_guides/big_modeling . It runs at about 1.09 it/s (a token a second on a home computer πŸ˜…πŸ”₯). My brave little toaster did it, so yours can, too.
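Stripped of the LangChain wrapper, the loading trick is essentially the big-model-inference pattern from that guide. Here's a minimal sketch, assuming the weights are already in your local HF cache (the full class further down does the same thing, plus logging and offload-cache handling):

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM

# Fetch (or reuse) the checkpoint files from the local HF cache.
checkpoint = snapshot_download("mosaicml/mpt-7b-storywriter")
config = AutoConfig.from_pretrained("mosaicml/mpt-7b-storywriter", trust_remote_code=True)

# Build the model skeleton on the meta device: no weights, (almost) no RAM used.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model.tie_weights()

# Stream the real weights in, spreading them over GPU / CPU / disk, while
# never splitting a single MPTBlock across devices.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint,
    device_map="auto",
    no_split_module_classes=["MPTBlock"],
    offload_folder="./offload",  # spills whatever doesn't fit to disk
)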

Be aware of these constraints when doing this on your own. If you want the full 60000 max_new_tokens, prepare to come back many hours later if you're poor like me....

I hacked this together in a night, and I managed to get StoryWriter working in WSL on LangChain, pulling the weights through the huggingface_hub API, with some (rudimentary) cache management for the accelerator's disk offloading. Not all of the parameters from transformers are supported; only a few made it in during my lab testing, so you'll have to add the rest yourself. Have fun adding 'em! I just didn't feel like stuffing in all the pydantic boilerplate for them without calling on HuggingFacePipeline from LangChain, which I'm trying to avoid because of the accelerator and weight tying (you could probably go as far as unwrapping those, too, if you want, for marginal performance gains).
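For instance, here's a rough sketch of what bolting on a couple more sampling knobs could look like. These exact fields are my own picks, not in the class below; they'd go inside the MosaicML class, and _mpt_default_params / _mpt_param_names would need to list them too:

    # Hypothetical extra fields for the MosaicML class below; they get forwarded
    # straight to model.generate() through the default params.
    top_p: Optional[float] = Field(0.95, alias='top_p')
    """Nucleus sampling probability mass."""

    top_k: Optional[int] = Field(50, alias='top_k')
    """Top-k sampling cutoff."""

    repetition_penalty: Optional[float] = Field(1.0, alias='repetition_penalty')
    """Penalty applied to repeated tokens (1.0 = off)."""

    def _mpt_default_params(self) -> Dict[str, Any]:
        """Default generate() kwargs, now including the extra knobs."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "do_sample": self.do_sample,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
        }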

This includes support for TextIteratorStreamer and will fire the LLM run manager's callback every time a token is added (the LangChain API side was a lot of pydantic boilerplate). I'm not going to guarantee thread safety; this is an experiment, like everything else in the ML world.
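To actually see those per-token callbacks, you can hand the class (once it's defined below) one of LangChain's stock handlers. A sketch, assuming the callbacks field from langchain 0.0.162's callback system:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# The handler's on_llm_new_token() gets called through run_manager, so each
# token is printed to stdout as the streamer yields it.
streaming_llm = MosaicML(
    model_name='mosaicml/mpt-7b-storywriter',
    max_new_tokens=200,
    callbacks=[StreamingStdOutCallbackHandler()],
    verbose=True,
)
streaming_llm("Tell me a short story about sabretooth tigers.")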

I had to run the code in my local Jupyter notebook and restart the kernel often; it took a long time to get this working correctly as a LangChain LLM.

Here's the code (for a custom LangChain LLM) that works on a home PC in WSL Ubuntu 22.04 with accelerators in a venv πŸ€—

from functools import partial
from typing import Any, Dict, List, Mapping, Optional, Set
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from accelerate import Accelerator, load_checkpoint_and_dispatch, init_empty_weights
from tqdm.auto import tqdm
from threading import Thread
from huggingface_hub import snapshot_download, cached_assets_path

"""Wrapper for the MosaicML MPT models."""
class MosaicML(LLM):
    model_name: str = Field("mosaicml/mpt-7b-storywriter", alias='model_name')
    """The name of the model to use."""

    tokenizer_name: str = Field("EleutherAI/gpt-neox-20b", alias='tokenizer_name')
    """The name of the sentence tokenizer to use."""

    config: Any = None #: :meta private:
    """The reference to the loaded configuration."""

    tokenizer: Any = None #: :meta private:
    """The reference to the loaded tokenizer."""

    model: Any = None #: :meta private:
    """The reference to the loaded model."""

    accelerator: Any = None #: :meta private:
    """The reference to the loaded hf device accelerator."""

    attn_impl: str = Field("torch", alias='attn_impl')
    """The attention implementation to use."""

    torch_dtype: Any = Field(torch.bfloat16, alias='torch_dtype')
    """The torch data type to use."""

    max_new_tokens: Optional[int] = Field(10000, alias='max_new_tokens')
    """The maximum number of tokens to generate."""

    do_sample: Optional[bool] = Field(True, alias='do_sample')
    """Whether to sample or not."""

    temperature: Optional[float] = Field(0.8, alias='temperature')
    """The temperature to use for sampling."""

    echo: Optional[bool] = Field(False, alias='echo')
    """Whether to echo the prompt."""
    
    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""


    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid


    def _mpt_default_params(self) -> Dict[str, Any]:
        """Get the default parameters."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "do_sample": self.do_sample,
        }
    
    @staticmethod
    def _mpt_param_names() -> Set[str]:
        """Get the identifying parameters."""
        return {
            "max_new_tokens",
            "temperature",
            "do_sample",
        }

    @staticmethod
    def _model_param_names(model_name: str) -> Set[str]:
        """Get the identifying parameters."""
        # TODO: fork for different parameters for different model variants.
        return MosaicML._mpt_param_names()
    
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters."""
        return self._mpt_default_params()
    
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment."""
        try:
            # This model is supermassive, so we use the accelerate library to load it.
            values['accelerator'] = Accelerator()
            print("[" + values["model_name"] + "] Downloading model (or fetching from cache)...")
            download_location = snapshot_download(repo_id=values["model_name"], use_auth_token=True, local_files_only=True)
            print("[" + values["model_name"] + "] Model location: " + str(download_location))
            offload_cache_location = cached_assets_path(library_name="langchain", namespace=values["model_name"], subfolder="offload")
            print("[" + values["model_name"] + "] Offload cache location: " + str(offload_cache_location))
            print("[" + values["model_name"] + "] AutoConfiguring...")
            values["config"] = AutoConfig.from_pretrained(values["model_name"], trust_remote_code=True)
            values["config"].attn_config['attn_impl'] = values["attn_impl"]
            values["tokenizer"] = AutoTokenizer.from_pretrained(values["tokenizer_name"])
            print("[" + values["model_name"] + "] Initializing empty weights for model...")
            with init_empty_weights():
                values["model"] = AutoModelForCausalLM.from_pretrained(
                    values["model_name"],
                    config=values["config"],
                    torch_dtype=values["torch_dtype"],
                    trust_remote_code=True
                )
            print("[" + values["model_name"] + "] Tying weights...")
            values["model"].tie_weights()
            print("[" + values["model_name"] + "] Dispatching checkpoint...")
            values["model"] = load_checkpoint_and_dispatch(
                values["model"], 
                download_location, 
                device_map="auto", 
                no_split_module_classes=["MPTBlock"],
                offload_folder=offload_cache_location
            )
            print("[" + values["model_name"] + "] Loaded successfully!")
        except Exception as e:
            raise Exception(f"MosaicML failed to load with error: {e}")
        return values
    
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model_name,
            **self._default_params(),
            **{
                k: v
                for k, v in self.__dict__.items()
                if k in self._model_param_names(self.model_name)
            },
        }
    
    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "mosaicml"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        r"""Call out to MosiacML's generate method via transformers.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "This is a story about a big sabre tooth tiger: "
                response = model(prompt)
        """
        text_callback = None
        if run_manager:
            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
        text = ""
        inputs = self.tokenizer([prompt], return_tensors='pt')
        inputs = inputs.to(self.accelerator.device)
        streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(inputs, streamer=streamer, **self._mpt_default_params())
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        text = ""
        pbar = tqdm(total=self.max_new_tokens, desc="Thinking", leave=False)
        for new_text in streamer:
            if text_callback:
                text_callback(new_text)
            text += new_text
            pbar.update(1)
        pbar.close()
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

You can initialize it like this (StoryWriter 7B F16 takes anywhere from 6m 45s to 10m 15s to load on my 4-year-old slow HDD, so it's probably much faster on your new SSD):

llm = MosaicML(model_name='mosaicml/mpt-7b-storywriter', attn_impl='torch', torch_dtype=torch.bfloat16, max_new_tokens=200, echo=True)
/home/saber7ooth/llama-index/venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[mosaicml/mpt-7b-storywriter] Downloading model (or fetching from cache)...
[mosaicml/mpt-7b-storywriter] Model location: /home/saber7ooth/.cache/huggingface/hub/models--mosaicml--mpt-7b-storywriter/snapshots/6ba8d09107c76220faae00653ed11bcde44b3152
[mosaicml/mpt-7b-storywriter] Offload cache location: /home/saber7ooth/.cache/huggingface/assets/langchain/mosaicml--mpt-7b-storywriter/offload
[mosaicml/mpt-7b-storywriter] AutoConfiguring...
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
[mosaicml/mpt-7b-storywriter] Initializing empty weights for model...
/home/saber7ooth/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/6ba8d09107c76220faae00653ed11bcde44b3152/attention.py:148: UserWarning: Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` otherwise we recommend using `attn_impl: triton`.
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2/2 [02:55<00:00, 87.57s/it] 
[mosaicml/mpt-7b-storywriter] Tying weights...
[mosaicml/mpt-7b-storywriter] Dispatching checkpoint...
[mosaicml/mpt-7b-storywriter] Loaded successfully!

And you can do text completion with it (what this model was designed for: helping you write stories by continuing them):

llm("Tell me a short story about sabretooth tigers.")
' Or about the mummy who can\'t move."\n\n"Sabretooth tigers are extinct," I said. I knew that. I\'d read about them in school. "And I don\'t know any short stories about mummies. Let\'s talk about something else, okay?"\n\n"Fine," he said. "I\'ve got a little story about my Uncle Pike. He\'s a doctor. And he doesn\'t believe in ghosts."\n\n"No?" I said. "I thought you said he was a doctor."\n\n"He is," he said. "He\'s a doctor of medicine. And a doctor of philosophy, too. He\'s a doctor of everything. But he doesn\'t believe in ghosts. And anyway, he doesn\'t believe in ghosts because he\'s a doctor, not because he\'s a doctor of medicine."\n\n"Oh," I said. I don\'t like to argue with anyone. "I guess ghosts would'

The model does as it says: it takes the little blurb you give it and continues it into a story. If you'd like to see your original prompt in the output, you can change this line:

        streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True)

And set skip_prompt to False.
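And because it behaves like any other LangChain LLM, it drops straight into a chain. A quick sketch reusing the llm object from above (standard PromptTemplate/LLMChain usage, nothing model-specific):

from langchain import LLMChain, PromptTemplate

# The chain just formats the prompt and calls the already-loaded llm.
story_prompt = PromptTemplate(
    input_variables=["topic"],
    template="This is a story about {topic}: ",
)
story_chain = LLMChain(llm=llm, prompt=story_prompt)
print(story_chain.run("a big sabre tooth tiger"))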

Final notes:

Have fun playing around with this, and of course if you want to do the full transformers implementation with all the properties, have at it (writing that much pydantic boilerplate is torture...).

I also recommend making the model loader itself a singleton, so that if someone calls the constructor on this more than once, the weights won't load twice...
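A minimal sketch of that idea, using a module-level cache keyed on the model name (a hypothetical helper, not part of the class above):

# Hypothetical singleton-style loader: pay the multi-minute load cost once and
# reuse the same MosaicML instance everywhere else.
_llm_cache: dict = {}

def get_mosaicml(model_name: str = "mosaicml/mpt-7b-storywriter", **kwargs) -> MosaicML:
    """Return a cached MosaicML LLM, constructing it only on the first call."""
    if model_name not in _llm_cache:
        _llm_cache[model_name] = MosaicML(model_name=model_name, **kwargs)
    return _llm_cache[model_name]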

I have not tried the llama.cpp LangChain wrapper, and given the recent changes I doubt it works. This is a very new model.

I like it, and the little story it came up with was funny to me. πŸ‘ The fact that I was able to hack it enough to work on a dodgy home computer was enough of an achievement.

P.S.: because I know people will yell, here's my pip freeze:

transformers==4.28.1
huggingface-hub==0.14.1
langchain==0.0.162
asyncio==3.4.3
colorama==0.4.6
torch==2.0.1
einops==0.6.1
accelerate==0.19.0
aiohttp==3.8.4
aiosignal==1.3.1
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
backcall==0.2.0
certifi==2023.5.7
charset-normalizer==3.1.0
cmake==3.26.3
comm==0.1.3
dataclasses-json==0.5.7
debugpy==1.6.7
decorator==5.1.1
executing==1.2.0
filelock==3.12.0
frozenlist==1.3.3
fsspec==2023.5.0
greenlet==2.0.2
idna==3.4
ipykernel==6.23.0
ipython==8.13.2
jedi==0.18.2
Jinja2==3.1.2
jupyter_client==8.2.0
jupyter_core==5.3.0
lit==16.0.3
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib-inline==0.1.6
mpmath==1.3.0
multidict==6.0.4
mypy-extensions==1.0.0
nest-asyncio==1.5.6
networkx==3.1
numexpr==2.8.4
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
openapi-schema-pydantic==1.2.4
packaging==23.1
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.5.0
prompt-toolkit==3.0.38
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pydantic==1.10.7
Pygments==2.15.1
python-dateutil==2.8.2
PyYAML==6.0
pyzmq==25.0.2
regex==2023.5.5
requests==2.30.0
six==1.16.0
SQLAlchemy==2.0.12
stack-data==0.6.2
sympy==1.11.1
tenacity==8.2.2
tokenizers==0.13.3
tornado==6.3.1
tqdm==4.65.0
traitlets==5.9.0
triton==2.0.0
typing-inspect==0.8.0
typing_extensions==4.5.0
urllib3==2.0.2
wcwidth==0.2.6
yarl==1.9.2

And my WSL instance has CUDA 12.1 from here -> https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local

My LangChain loader isn't using the GitHub repo at all (but it helped a ton!) and just grabs the weights barebones and stuffs them into transformers. Enjoy these backyard shenanigans.

Because I'm using hf accelerate dispatch with the MPTBlock no-split constraint, you can expect better performance with multiple GPUs, a better CPU, faster storage, and more RAM. But this should make the model somewhat more accessible to "hobbyists" through LangChain, at the very least.
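If you do have more hardware to throw at it, the knob to reach for is max_memory on the dispatch call in validate_environment. A sketch with made-up budgets for one 8GB card plus 32GB of RAM; tune the numbers for your own box:

# Same dispatch call as in validate_environment, but with explicit per-device
# memory budgets so accelerate knows how much of each device it may fill.
values["model"] = load_checkpoint_and_dispatch(
    values["model"],
    download_location,
    device_map="auto",
    no_split_module_classes=["MPTBlock"],
    offload_folder=offload_cache_location,
    max_memory={0: "7GiB", "cpu": "28GiB"},  # GPU 0 budget, then system RAM
)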

You can even load big models with this voodoo magic, if you don't mind being patient and waiting on your prompt at home...

Luv,

~ 7oothy

