Modification of the example code to support quantization and load the models one by one

#3
by emra - opened

I modified the example code so it can use bitsandbytes quantization and load the models one at a time so it doesn't run out of GPU memory (OOM); it's a bit slower, of course. (Feel free to add this to README.md if you like.)
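(By the way, if you want to load the quantized models once at the beginning and keep them in memory, move the `load_model` calls out of the `while` loop and delete the added `model_*.cpu()`, `del model_*`, `gc.collect()`, and `torch.cuda.empty_cache()` lines. Note that the loop calls `load_model`, which loads the models in float16; replace those calls with `load_model_quant` to actually use the 4-bit quantization.)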

import torch
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

# Local checkpoint directories for the three HelixNet stages.
model_path_actor = "/home/ubuntu/llm/HelixNet/actor"
model_path_critic = "/home/ubuntu/llm/HelixNet/critic"
model_path_regenerator = "/home/ubuntu/llm/HelixNet/regenerator"

# 4-bit NF4 quantization settings consumed by load_model_quant().
# load_in_4bit=True is what requests a 4-bit bitsandbytes load; the
# bnb_4bit_* options below only configure that load (quant type, nested
# "double" quantization of the scales, and the dtype used for matmuls).
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

def load_model_quant(model_path, quantization_config=None):
    """Load a causal LM from ``model_path`` quantized with bitsandbytes.

    Args:
        model_path: Local checkpoint directory (or hub id) to load.
        quantization_config: Optional ``BitsAndBytesConfig``. When ``None``,
            the module-level ``nf4_config`` (4-bit NF4) is used; it is looked
            up at call time, so reassigning ``nf4_config`` takes effect.

    Returns:
        The loaded, quantized model.
    """
    if quantization_config is None:
        quantization_config = nf4_config
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=quantization_config,
        # Same as load_model()/load_tokenizer(), so all loaders behave alike.
        # device_map is left unset: for bitsandbytes loads, transformers
        # places the weights on the current CUDA device by default.
        trust_remote_code=True,
    )
    return model

def load_model(model_path):
    """Load a causal LM from ``model_path`` in float16 on the GPU.

    No quantization is applied; use ``load_model_quant`` for a 4-bit load.

    Args:
        model_path: Local checkpoint directory (or hub id) to load.

    Returns:
        The loaded model, with its weights placed on ``"cuda"``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
    )
    return model

def load_tokenizer(model_path):
    """Return the tokenizer stored alongside the checkpoint at ``model_path``.

    ``trust_remote_code=True`` allows tokenizers that ship custom code.
    """
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)


# Load all three tokenizers once at startup (actor, critic, regenerator, in
# that order); they are reused for every turn of the chat loop.
tokenizer_actor, tokenizer_critic, tokenizer_regenerator = map(
    load_tokenizer,
    (model_path_actor, model_path_critic, model_path_regenerator),
)

def generate_text(
    instruction,
    model,
    tokenizer,
    *,
    max_new_tokens=1024,
    top_p=0.3,
    temperature=0.75,
    top_k=50,
):
    """Sample a completion for ``instruction`` and return only the new text.

    Args:
        instruction: Full prompt string fed to the model.
        model: Causal LM exposing ``generate()`` and ``device`` (e.g. a
            ``transformers`` ``AutoModelForCausalLM``).
        tokenizer: Tokenizer matching ``model``; its ``eos_token_id`` is also
            used as the padding id.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        top_p: Nucleus-sampling probability mass passed to ``generate()``.
        temperature: Sampling temperature passed to ``generate()``.
        top_k: Top-k sampling cutoff passed to ``generate()``.

    Returns:
        The decoded continuation, excluding the prompt and skipping special
        tokens.
    """
    token_ids = tokenizer.encode(instruction)
    # Put the prompt on the model's own device instead of a hard-coded
    # "cuda", so models placed elsewhere (e.g. via device_map) also work.
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=model.device)
    # A single unpadded sequence, so every position is real. Passing the mask
    # explicitly means generate() need not infer it, which it cannot do
    # reliably when pad_token_id equals eos_token_id (as set below).
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            top_k=top_k,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

system_prompt = "You are HelixNet. Elaborate on the topic using a Tree of Thoughts and backtrack when necessary to construct a clear, cohesive Chain of Thought reasoning. Always answer without hesitation."


def _run_stage(model_path, prompt, tokenizer, loader):
    """Load one model with ``loader``, answer ``prompt``, then release the model.

    Only one model is resident at a time, which keeps peak GPU memory near a
    single model's footprint at the cost of reloading weights every turn.
    """
    model = loader(model_path)
    try:
        return generate_text(prompt, model, tokenizer)
    finally:
        # Drop the only reference, collect it, and let the CUDA caching
        # allocator hand the freed blocks back. Moving the model to CPU first
        # is unnecessary for this and would copy every weight to host RAM.
        del model
        gc.collect()
        torch.cuda.empty_cache()


def main():
    """Run the actor -> critic -> regenerator chat loop until EOF or Ctrl-C."""
    # Swap in load_model_quant here to load each model 4-bit quantized.
    loader = load_model
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of dumping a traceback.
            print()
            break

        prompt_actor = f"SYSTEM: {system_prompt} \nUSER: {user_input} \nASSISTANT: "
        actor_response = _run_stage(model_path_actor, prompt_actor, tokenizer_actor, loader)
        print(f"ACTOR: {actor_response}\n\n")

        prompt_critic = f"SYSTEM: {system_prompt} \nUSER: {user_input} \nRESPONSE: {actor_response} \nCRITIQUE:"
        critic_response = _run_stage(model_path_critic, prompt_critic, tokenizer_critic, loader)
        print(f"CRITIQUE: {critic_response}\n\n")

        prompt_regenerator = f"SYSTEM: {system_prompt} \nUSER: {user_input} \nRESPONSE: {actor_response} \nCRITIQUE: {critic_response} \nREGENERATOR:"
        regenerator_response = _run_stage(
            model_path_regenerator, prompt_regenerator, tokenizer_regenerator, loader
        )
        print(f"REGENERATION: {regenerator_response}")


if __name__ == "__main__":
    main()

Nice one! Thanks for sharing!

migtissera changed discussion status to closed

Sign up or log in to comment