""" Usage: python3 -m fastchat.serve.huggingface_api --model ~/model_weights/vicuna-7b/ """ import argparse import json import torch from transformers import AutoTokenizer, AutoModelForCausalLM from fastchat.conversation import get_default_conv_template, compute_skip_echo_len from fastchat.serve.inference import load_model @torch.inference_mode() def main(args): model, tokenizer = load_model( args.model_path, args.device, args.num_gpus, args.max_gpu_memory, args.load_8bit, debug=args.debug, ) msg = args.message conv = get_default_conv_template(args.model_path).copy() conv.append_message(conv.roles[0], msg) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() inputs = tokenizer([prompt]) output_ids = model.generate( torch.as_tensor(inputs.input_ids).cuda(), do_sample=True, temperature=0.7, max_new_tokens=1024, ) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] skip_echo_len = compute_skip_echo_len(args.model_path, conv, prompt) outputs = outputs[skip_echo_len:] print(f"{conv.roles[0]}: {msg}") print(f"{conv.roles[1]}: {outputs}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--model-path", type=str, default="facebook/opt-350m", help="The path to the weights", ) parser.add_argument( "--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda" ) parser.add_argument("--num-gpus", type=str, default="1") parser.add_argument( "--max-gpu-memory", type=str, help="The maximum memory per gpu. Use a string like '13Gib'", ) parser.add_argument( "--load-8bit", action="store_true", help="Use 8-bit quantization." ) parser.add_argument( "--conv-template", type=str, default=None, help="Conversation prompt template." ) parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--debug", action="store_true") parser.add_argument("--message", type=str, default="Hello! Who are you?") args = parser.parse_args() main(args)