Structure Recognition Model Based on FormalGeo7K
Quick Start
Before running the script, install the following required dependencies:
```shell
pip install --upgrade pip
pip install torch transformers==4.40.0
pip install sentencepiece protobuf
pip install accelerate pillow
pip install ninja
pip install packaging
pip install flash-attn --no-build-isolation
```
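Because `transformers` is pinned to 4.40.0 and the model loads custom remote code, a version mismatch is a plausible cause of load errors. The following is a minimal sketch that uses the standard-library `importlib.metadata` to report missing distributions and a mismatched pin. The helper name `check_environment` is hypothetical (it is not part of the model repository), and the name matching assumes Python 3.10 or later.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip commands above; only transformers is pinned.
REQUIRED = {
    "torch": None,
    "transformers": "4.40.0",
    "sentencepiece": None,
    "protobuf": None,
    "accelerate": None,
    "pillow": None,
    "ninja": None,
    "packaging": None,
    "flash-attn": None,
}


def check_environment(required=REQUIRED):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    for dist, pinned in required.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            problems.append(f"{dist}: not installed")
            continue
        if pinned is not None and installed != pinned:
            problems.append(f"{dist}: found {installed}, expected {pinned}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("environment OK" if not issues else "\n".join(issues))
```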
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image

# set device
device = 'cuda'  # or 'cpu'
torch.set_default_device(device)

# create model
model = AutoModelForCausalLM.from_pretrained(
    'NaughtyDog97/GeoFormalizer',
    torch_dtype=torch.float16,  # use torch.float32 on CPU
    device_map='auto',
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    'NaughtyDog97/GeoFormalizer',
    trust_remote_code=True)

# text prompt
img_path = 'sample/4927.png'
prompt = 'Based on the image, first describe what you see in the figure, then predict the construction_cdl and image_cdl and calibrate it.'
text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
# tokenize the text before and after <image>; -200 marks the image position for the custom model code
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)

# image; sample images can be found in the sample folder
image = Image.open(img_path).convert('RGB')
image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)

# generate (greedy decoding)
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
        num_beams=1,
        max_new_tokens=3500,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=None,
        use_cache=True
    )[0]

# decode the response
response = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
print(response)
```
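The quick-start code builds `input_ids` by tokenizing the chat text on either side of `<image>` and inserting the id -200 between the two parts. The sketch below isolates that step in plain Python so it can be inspected with a toy tokenizer. The function names `build_chat_text` and `splice_image_token` are hypothetical helpers, and reading -200 as an image placeholder (rather than a vocabulary id) is an assumption based on how the quick-start code uses it. Like the original, it drops the first token of the text after the image (`text_chunks[1][1:]`), and it expects exactly one `<image>` in the text.

```python
from typing import Callable, List

# Placeholder id used in the quick-start code; the custom model code is
# assumed to replace it with image features during generation.
IMAGE_TOKEN_INDEX = -200

SYSTEM = "You are a helpful assistant."


def build_chat_text(prompt: str) -> str:
    """Wrap a user prompt in the chat template used by the quick-start code."""
    return (
        f"<|im_start|>system\n{SYSTEM}<|im_end|>\n"
        f"<|im_start|>user\n<image>\n{prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def splice_image_token(text: str, tokenize: Callable[[str], List[int]]) -> List[int]:
    """Tokenize around '<image>' and insert the placeholder id between the parts.

    Raises ValueError if the text does not contain exactly one '<image>'.
    Mirrors the quick-start code, including dropping the first token of the
    second chunk.
    """
    before, after = text.split("<image>")
    return tokenize(before) + [IMAGE_TOKEN_INDEX] + tokenize(after)[1:]
```

With the real tokenizer, `torch.tensor(splice_image_token(build_chat_text(prompt), lambda s: tokenizer(s).input_ids), dtype=torch.long).unsqueeze(0)` should reproduce `input_ids` from the quick start.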
The model supports the following recognition modes, each selected by the prompt. The Chinese prompts are equivalents of the English ones and are kept verbatim because they are passed to the model as-is:
- Natural-language description:
  - Describe what you see in the figure.
  - Tell me what you observe in the image.
  - 使用自然语言描述这幅图像。
- Predict construction_cdl only:
  - Based on the image, predict the construction_cdl.
  - 根据图像识别出construction_cdl。
  - Based on the image, predict the construction_cdl and calibrate it.
  - 根据图像识别出construction_cdl并进行矫正。
  - Based on the image, first describe what you see in the figure, then predict the construction_cdl.
  - 根据图像,首先描述图像,之后识别出construction_cdl。
  - Based on the image, first describe what you see in the figure, then predict the construction_cdl and calibrate it.
  - 根据图像,首先描述图像,之后识别出construction_cdl并进行矫正。
- Predict image_cdl only:
  - Based on the image, predict the image_cdl.
  - 根据图像识别出image_cdl。
  - Based on the image, predict the image_cdl and calibrate it.
  - 根据图像识别出image_cdl并进行矫正。
  - Based on the image, first describe what you see in the figure, then predict the image_cdl.
  - 根据图像,首先描述图像,之后识别出image_cdl。
  - Based on the image, first describe what you see in the figure, then predict the image_cdl and calibrate it.
  - 根据图像,首先描述图像,之后识别出image_cdl并进行矫正。
- Predict both construction_cdl and image_cdl:
  - Based on the image, predict the construction_cdl and image_cdl.
  - 根据图像识别出construction_cdl和image_cdl。
  - Based on the image, first predict the construction_cdl and image_cdl and calibrate it.
  - 根据图像识别出construction_cdl和image_cdl并进行矫正。
  - Based on the image, first describe what you see in the figure, then predict the construction_cdl and image_cdl.
  - 根据图像,首先描述图像,之后识别出construction_cdl和image_cdl。
  - Based on the image, first describe what you see in the figure, then predict the construction_cdl and image_cdl and calibrate it.
  - 根据图像,首先描述图像,之后识别出construction_cdl和image_cdl并矫正。
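Since the mode is chosen purely by the prompt string, switching modes in the quick-start code only requires changing `prompt`. The sketch below collects the English prompts from the list above into a lookup table keyed by (target, describe first, calibrate). The names `PROMPTS` and `select_prompt` and the target labels `"description"` and `"both"` are hypothetical; only the prompt strings come from this card. Only one description prompt is included.

```python
# English prompts copied verbatim from the list above.
# Key: (target, describe_first, calibrate); "description" has a single entry.
PROMPTS = {
    ("description", False, False): "Describe what you see in the figure.",
    ("construction_cdl", False, False): "Based on the image, predict the construction_cdl.",
    ("construction_cdl", False, True): "Based on the image, predict the construction_cdl and calibrate it.",
    ("construction_cdl", True, False): "Based on the image, first describe what you see in the figure, then predict the construction_cdl.",
    ("construction_cdl", True, True): "Based on the image, first describe what you see in the figure, then predict the construction_cdl and calibrate it.",
    ("image_cdl", False, False): "Based on the image, predict the image_cdl.",
    ("image_cdl", False, True): "Based on the image, predict the image_cdl and calibrate it.",
    ("image_cdl", True, False): "Based on the image, first describe what you see in the figure, then predict the image_cdl.",
    ("image_cdl", True, True): "Based on the image, first describe what you see in the figure, then predict the image_cdl and calibrate it.",
    ("both", False, False): "Based on the image, predict the construction_cdl and image_cdl.",
    ("both", False, True): "Based on the image, first predict the construction_cdl and image_cdl and calibrate it.",
    ("both", True, False): "Based on the image, first describe what you see in the figure, then predict the construction_cdl and image_cdl.",
    ("both", True, True): "Based on the image, first describe what you see in the figure, then predict the construction_cdl and image_cdl and calibrate it.",
}


def select_prompt(target: str, describe: bool = False, calibrate: bool = False) -> str:
    """Return the listed English prompt for a mode, or raise ValueError if none is listed."""
    key = (target, describe, calibrate)
    if key not in PROMPTS:
        raise ValueError(
            f"no listed prompt for target={target!r}, describe={describe}, calibrate={calibrate}"
        )
    return PROMPTS[key]
```

For example, `select_prompt("both", describe=True, calibrate=True)` returns the prompt used in the quick-start code.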
Performance
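Model | ConsCdlAcc | ConsCdlPerfect | ImgCdlAcc | ImgCdlPerfect | BothPerfect
---|---|---|---|---|---
siglip-0.4B-qwen2-0.5B | 90.254 | 72.286 | 92.880 | 84.381 | 65.048

ConsCdl and ImgCdl abbreviate construction_cdl and image_cdl; BothPerfect refers to both being perfect.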