# -*- encoding: utf-8 -*-
"""Interactive command-line chat demo for VisualGLM-6B.

The user supplies an image path/URL (or presses Enter for a text-only
conversation), then keeps chatting about it. Typing ``clear`` starts a new
conversation and ``stop`` ends the program.
"""
import os
import sys
import torch
import argparse
from transformers import AutoTokenizer
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize

# NOTE(review): VisualGLMModel / FineTuneVisualGLMModel are not referenced
# directly; presumably importing them registers the model classes so that
# AutoModel.from_pretrained can resolve them — keep these imports.
from model import VisualGLMModel, chat
from finetune_visualglm import FineTuneVisualGLMModel
from sat.model import AutoModel


def _parse_args():
    """Build and parse the command-line options for the demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt')
    parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
    parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
    return parser.parse_args()


def _load_model_and_tokenizer(args):
    """Load the checkpoint (optionally quantized) and the ChatGLM tokenizer.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode and the cached
        auto-regressive mixin attached.
    """
    # Without quantization the weights are initialised directly on the GPU.
    # With quantization they are loaded on the CPU, quantized, and only then
    # moved to the GPU (below).
    on_gpu = torch.cuda.is_available() and args.quant is None
    model, _model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            use_gpu_initialization=on_gpu,
            device='cuda' if on_gpu else 'cpu',
        ))
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)
        if torch.cuda.is_available():
            model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    return model, tokenizer


def _ask(english, prompt_zh, prompt_en):
    """Read one line from stdin, prompting in the selected interface language."""
    return input(prompt_en if english else prompt_zh)


def _print_banner(english):
    """Print the welcome/usage message in the selected interface language."""
    if english:
        print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
    else:
        print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')


def _chat_loop(args, model, tokenizer):
    """Run conversations until the user types ``stop``.

    Each outer iteration is a fresh conversation: history and cached image
    are reset. ``clear`` (or a failed ``chat`` call) returns to the image
    prompt; ``stop`` at either prompt returns from this function.
    """
    english = args.english
    # The model's reply is cut after the last answer marker.
    sep = 'A:' if english else '答:'

    with torch.no_grad():
        while True:
            history = None
            cache_image = None

            image_path = _ask(english,
                              "请输入图像路径或URL(回车进入纯文本对话): ",
                              "Please enter the image path or URL (press Enter for plain text conversation): ")
            if image_path == 'stop':
                return

            if image_path:
                # An image was given: open with the configured first-round prompt.
                query = args.prompt_en if english else args.prompt_zh
            else:
                query = _ask(english, "用户:", "User: ")

            while True:
                if query == "clear":
                    break
                if query == "stop":
                    return
                try:
                    response, history, cache_image = chat(
                        image_path,
                        model,
                        tokenizer,
                        query,
                        history=history,
                        image=cache_image,
                        max_length=args.max_length,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        english=english,
                        # NOTE(review): token-id range masked out in English
                        # mode — presumably the non-English part of the
                        # ChatGLM-6B vocabulary; verify against the tokenizer.
                        invalid_slices=[slice(63823, 130000)] if english else []
                    )
                except Exception as e:
                    # Best effort: report the failure and restart the conversation.
                    print(e)
                    break
                print("VisualGLM-6B:" + response.split(sep)[-1].strip())
                # Later rounds reuse the image cached by chat() instead of the path.
                image_path = None
                query = _ask(english, "用户:", "User: ")


def main():
    """Entry point: parse options, load the model, and run the interactive chat."""
    args = _parse_args()
    model, tokenizer = _load_model_and_tokenizer(args)
    _print_banner(args.english)
    try:
        _chat_loop(args, model, tokenizer)
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C: exit quietly instead of printing a traceback.
        print()


if __name__ == "__main__":
    main()