import os
import sys

import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter

# Prefer CUDA, then Apple Silicon (MPS), and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:  # older torch builds may not expose the MPS backend
    pass


def main(
    load_8bit: bool = True,
    base_model: str = "decapoda-research/llama-7b-hf",
    lora_weights: str = "tiedong/goat-lora-7b",
    prompt_template: str = "goat",
    server_name: str = "0.0.0.0",
    share_gradio: bool = True,
):
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    prompter = Prompter(prompt_template)
    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # Load the base LLaMA model and apply the Goat LoRA adapter on top of it.
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )

    if not load_8bit:
        model.half()  # use fp16 weights when not loading in 8-bit

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=512,
        stream_output=True,
        **kwargs,
    ):
        prompt = prompter.generate_prompt_inference(instruction)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )

        generate_params = {
            "input_ids": input_ids,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "output_scores": True,
            "max_new_tokens": max_new_tokens,
        }

        if stream_output:
            # Stream the reply 1 token at a time.
            # This is based on the trick of using 'stopping_criteria' to create an iterator,
            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
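            # (Assumed behavior of the local utils.callbacks helpers, not verified here:
            # Stream acts as a stopping criterion that invokes its callback after each
            # decoding step, and Iteratorize turns that callback-driven model.generate()
            # call into a Python generator that is iterated over below.)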
            def generate_with_callback(callback=None, **kwargs):
                kwargs.setdefault(
                    "stopping_criteria", transformers.StoppingCriteriaList()
                )
                kwargs["stopping_criteria"].append(
                    Stream(callback_func=callback)
                )
                with torch.no_grad():
                    model.generate(**kwargs)

            def generate_with_streaming(**kwargs):
                return Iteratorize(
                    generate_with_callback, kwargs, callback=None
                )

            with generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    # new_tokens = len(output) - len(input_ids[0])
                    decoded_output = tokenizer.decode(output)

                    if output[-1] in [tokenizer.eos_token_id]:
                        break

                    yield prompter.get_response(decoded_output)
            return  # early return for stream_output

        # Without streaming
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s, skip_special_tokens=True).strip()
        yield prompter.get_response(output)

    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Arithmetic",
                placeholder="What is 63303235 + 20239503",
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.1, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.75, label="Top p"
            ),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=4, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=1024, step=1, value=512, label="Max tokens"
            ),
            gr.components.Checkbox(label="Stream output"),
        ],
        outputs=[
            gr.components.Textbox(
                lines=5,
                label="Output",
            )
        ],
        title="Goat-LoRA-7b",
        description="Goat-LoRA-7b is a 7B-parameter LLaMA model fine-tuned to perform arithmetic tasks, including addition, subtraction, multiplication, and division of integers. It is trained on a synthetic dataset (https://github.com/liutiedong/goat) and uses the Hugging Face LLaMA implementation. For more information, please visit [the project's website](https://github.com/liutiedong/goat).",  # noqa: E501
    ).queue().launch(server_name=server_name, share=share_gradio)


if __name__ == "__main__":
    fire.Fire(main)
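
# A minimal launch sketch, assuming this file is saved as app.py and the default
# model IDs above are reachable on the Hugging Face Hub (the filename and the exact
# flag values are illustrative, not verified here):
#
#   python app.py \
#       --base_model 'decapoda-research/llama-7b-hf' \
#       --lora_weights 'tiedong/goat-lora-7b' \
#       --prompt_template 'goat' \
#       --load_8bit True
#
# fire.Fire exposes each keyword argument of main() as a --flag of the same name,
# so any of the defaults above can be overridden from the command line.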