| import argparse |
| import torch |
| from PIL import Image |
| from transformers import AutoTokenizer |
| from earthdial.model.internvl_chat import InternVLChatModel |
| from earthdial.train.dataset import build_transform |
|
|
def run_single_inference(args):
    """Answer one visual question about one image and print the result.

    Loads the tokenizer and ``InternVLChatModel`` from ``args.checkpoint``,
    preprocesses the image at ``args.image_path`` with the evaluation
    transform, asks ``args.question`` through ``model.chat`` and prints the
    question and the answer to stdout.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``checkpoint``, ``image_path``, ``question``, ``num_beams``,
            ``temperature``, ``load_in_8bit``, ``load_in_4bit`` and ``auto``.
    """
    print(f"Loading model and tokenizer from Hugging Face: {args.checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint, trust_remote_code=True, use_fast=False
    )
    model = InternVLChatModel.from_pretrained(
        args.checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto" if args.auto else None,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    ).eval()

    # Only a plain (non-quantized, non-device-mapped) load is moved to the GPU
    # by hand; presumably the other modes place the weights themselves.
    if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
        model = model.cuda()

    # Image.open is lazy and keeps the file open; convert() reads the pixels
    # and returns an independent image, so the handle can be closed right away.
    with Image.open(args.image_path) as raw_image:
        image = raw_image.convert("RGB")

    image_size = model.config.force_image_size or model.config.vision_config.image_size
    transform = build_transform(is_train=False, input_size=image_size, normalize_type='imagenet')
    # Add a leading batch dimension and match the dtype the model was loaded with.
    pixel_values = transform(image).unsqueeze(0).cuda().to(torch.bfloat16)

    do_sample = args.temperature > 0
    generation_config = {
        "num_beams": args.num_beams,
        "max_new_tokens": 100,
        "min_new_tokens": 1,
        "do_sample": do_sample,
    }
    # Forward the temperature only when sampling: with do_sample=False it is
    # unused, and a value of 0.0 is not a valid temperature, so passing it
    # along only triggers generation-config warnings.
    if do_sample:
        generation_config["temperature"] = args.temperature

    answer = model.chat(
        tokenizer=tokenizer,
        pixel_values=pixel_values,
        question=args.question,
        generation_config=generation_config,
        verbose=True,
    )

    print("\n=== Inference Result ===")
    print(f"Question: {args.question}")
    print(f"Answer: {answer}")
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--checkpoint', type=str, required=True, help='Model repo ID on Hugging Face Hub') |
| parser.add_argument('--image-path', type=str, required=True, help='Path to input image') |
| parser.add_argument('--question', type=str, required=True, help='Visual question to ask') |
| parser.add_argument('--num-beams', type=int, default=5) |
| parser.add_argument('--temperature', type=float, default=0.0) |
| parser.add_argument('--load-in-8bit', action='store_true') |
| parser.add_argument('--load-in-4bit', action='store_true') |
| parser.add_argument('--auto', action='store_true') |
|
|
| args = parser.parse_args() |
| run_single_inference(args) |
|
|