import os
import sys

import fire
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter


def _detect_device():
    """Return the torch device name to run on: "cuda", "mps", or "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # Older PyTorch builds have no `torch.backends.mps`; accessing it raises
    # AttributeError, in which case we fall back to the CPU.
    try:
        if torch.backends.mps.is_available():
            return "mps"
    except AttributeError:
        pass
    return "cpu"


device = _detect_device()


def _load_model(base_model, lora_weights, load_8bit):
    """Load the LLaMA base model and wrap it with the LoRA adapter for `device`.

    Args:
        base_model: Name or path of the base LLaMA checkpoint.
        lora_weights: Name or path of the LoRA adapter weights.
        load_8bit: Load the base model in 8-bit (only honoured on CUDA).

    Returns:
        A `PeftModel` wrapping the base model.
    """
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        return PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    if device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        return PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    # CPU: keep the default (float32) dtype.
    model = LlamaForCausalLM.from_pretrained(
        base_model, device_map={"": device}, low_cpu_mem_usage=True
    )
    return PeftModel.from_pretrained(
        model,
        lora_weights,
        device_map={"": device},
    )


def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "DSMI/LLaMA-E/7b",
    prompt_template: str = "",
):
    """Load the base model plus LoRA adapter and print responses to two demo prompts.

    Args:
        load_8bit: Load the base model in 8-bit precision (CUDA only).
        base_model: Base model name/path; falls back to the BASE_MODEL env var.
        lora_weights: LoRA adapter name/path.
        prompt_template: Template name passed to `Prompter`.

    Raises:
        ValueError: If no base model is given via argument or BASE_MODEL.
    """
    print("lora_weights: " + str(lora_weights))
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    if not base_model:
        raise ValueError(
            "Please specify --base_model or set the BASE_MODEL environment variable."
        )

    prompter = Prompter(prompt_template)
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    model = _load_model(base_model, lora_weights, load_8bit)

    # Hard-coded LLaMA special-token ids: 0 = <unk> (reused as pad), 1 = <s>, 2 = </s>.
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    # fp16 is "seems to fix bugs for some users" on GPU/MPS, but many fp16
    # matmul kernels are not implemented on CPU, so leave CPU models in fp32.
    if not load_8bit and device != "cpu":
        model.half()

    model.eval()
    # torch.compile exists from PyTorch 2.0 on and is not supported on Windows.
    if hasattr(torch, "compile") and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        input_text=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate a response for `instruction` (plus optional `input_text`)."""
        prompt = prompter.generate_prompt(instruction, input_text)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        output = tokenizer.decode(generation_output.sequences[0])
        response = prompter.get_response(output)
        # Drop the end-of-sequence marker and anything after it (e.g. padding).
        # str.split rejects an empty separator, so only split when EOS is set.
        eos = tokenizer.eos_token
        return response.split(eos)[0] if eos else response

    print()
    instruction = "Where can I buy the handmade jewellery?"
    print("Instruction:", instruction)
    print("Response:", evaluate(instruction))
    print()

    instruction = "Generate an ad for the following product."
    product_input = "Emerald Teardrop Necklace.May Birthstone Pendant.Dainty Gift for Her.925 Sterling Silver.Spring Sale"
    print("Instruction:", instruction)
    print("Input:", product_input)
    print("Response:", evaluate(instruction, product_input))
    print()


if __name__ == "__main__":
    fire.Fire(main)