Optimizing Inference Time for Chat Conversations on Falcon

Discussion #73, opened by humza-sami

How can I optimize the inference time of my 4-bit quantized Falcon 7B, which I fine-tuned on a chat dataset using QLoRA + PEFT? For inference, I load the model in 4-bit with the bitsandbytes library. The model's answers are good, but the inference time increases significantly as the chat grows longer. Here is an example of how I'm using the model:

First message prompt:
< user>: Hi ..
< bot>:

Second message prompt:
< user>: Hi ..
< bot>: How are you ?
< user>: I am good thanks. I need your help !
< bot>:
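For reference, each prompt is the whole conversation so far followed by an empty "< bot>:" cue, so it grows with every turn. A minimal sketch of assembling it from a list of turns; `build_prompt` and its `max_turns` option are hypothetical helpers for illustration, not part of any library. `max_turns` shows one way to stop the prompt from growing, at the cost of dropping older context:

```python
def build_prompt(turns, max_turns=None):
    """Assemble the chat prompt in the "< user>:" / "< bot>:" format.

    turns: list of (speaker, text) pairs, speaker being "user" or "bot".
    max_turns: hypothetical option; if set, keep only the most recent
    max_turns turns so the prompt length stays bounded.
    """
    if max_turns is not None:
        # Slice from the end; max(0, ...) also handles max_turns=0 correctly
        turns = turns[max(0, len(turns) - max_turns):]
    lines = [f"< {speaker}>: {text}" for speaker, text in turns]
    lines.append("< bot>:")  # empty cue for the model's next reply
    return "\n".join(lines)


history = [
    ("user", "Hi .."),
    ("bot", "How are you ?"),
    ("user", "I am good thanks. I need your help !"),
]
print(build_prompt(history))               # full history, as in the second prompt
print(build_prompt(history, max_turns=1))  # only the latest user turn
```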

As the length of the chat increases, the inference time sometimes doubles and can take up to 2-3 minutes per prompt. I'm using an NVIDIA RTX 3090 Ti for inference. Below is the code snippet I'm using for prediction:

import torch
import transformers
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "tiiuae/falcon-7b"
SAVED_MODEL = "Saved_models/model_path"

# 4-bit NF4 quantization with double quantization; computation runs in bfloat16
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True,
                                bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)

# Falcon's tokenizer has no pad token, so reuse EOS
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Load the base model named in the adapter config, quantized to 4 bits
saved_model_config = PeftConfig.from_pretrained(SAVED_MODEL)
saved_model = AutoModelForCausalLM.from_pretrained(saved_model_config.base_model_name_or_path,
                                                   return_dict=True,
                                                   quantization_config=bnb_config,
                                                   device_map="auto",
                                                   trust_remote_code=True)

# Attach the QLoRA adapter weights saved with PEFT
saved_model = PeftModel.from_pretrained(saved_model, SAVED_MODEL)

# The model is already loaded and placed, so the pipeline needs no dtype or device options
pipeline = transformers.pipeline(
    "text-generation",
    model=saved_model,
    tokenizer=tokenizer,
)

# Triple quotes so the multi-line prompt is valid Python
prompt = """< user>: Hi ..
< bot>: How are you ?
< user>: I am good thanks. I need your help !
< bot>:"""

response = pipeline(
    prompt,
    bos_token_id=11,
    max_length=2000,         # total tokens: prompt + generated
    temperature=0.7,
    top_p=0.7,
    do_sample=True,
    num_return_sequences=1,
    eos_token_id=[15564]     # stop once token id 15564 is sampled
)[0]['generated_text']
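For context, `max_length` in the call above counts the prompt plus the generated tokens, while `max_new_tokens` counts only generated tokens (and takes precedence if both are set). The rough cost model below is my own illustration, not transformers code: `new_token_budget` and `decode_attention_work` are hypothetical helpers, and the second counts only the attention reads over cached positions, ignoring the per-token projection and MLP work, which does not depend on context length.

```python
def new_token_budget(prompt_len, max_length=None, max_new_tokens=None):
    """Upper bound on sampled tokens under the two generation limits.

    max_length counts prompt + generated tokens; max_new_tokens counts only
    generated tokens and wins when both are given. Simplified illustration.
    """
    if max_new_tokens is not None:
        return max_new_tokens
    if max_length is not None:
        return max(0, max_length - prompt_len)
    raise ValueError("set max_length or max_new_tokens")


def decode_attention_work(prompt_len, new_tokens):
    """Context positions attended to while producing new_tokens tokens.

    The forward pass that produces new token i sees prompt_len + i tokens,
    so the total is new_tokens * prompt_len + new_tokens * (new_tokens - 1) // 2.
    Attention only; a rough model, not a benchmark.
    """
    return sum(prompt_len + i for i in range(new_tokens))


prompt_len = 300  # hypothetical prompt length in tokens
for limits in ({"max_length": 2000}, {"max_new_tokens": 200}):
    budget = new_token_budget(prompt_len, **limits)
    print(limits, budget, decode_attention_work(prompt_len, budget))
```

If token 15564 is never sampled, generation runs until the budget is used up, so with `max_length=2000` and a hypothetical 300-token prompt up to 1700 tokens can be sampled, each step attending over a longer context than the last; `max_new_tokens` caps the number of sampled tokens regardless of how long the chat gets.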

I would greatly appreciate any insights or suggestions on how to improve the model's inference time for longer chat interactions. Thank you in advance for your help!
