---
language:
- zh
pipeline_tag: image-to-text
tags:
- vit
- gpt
---

```python
import html
import io
import os
import pathlib

import numpy as np
import pandas as pd
import requests
import torch
from IPython.display import HTML
from PIL import Image
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTFeatureExtractor


class Image2Caption:
    """Caption images with a ``VisionEncoderDecoderModel`` (ViT encoder + GPT decoder).

    The image preprocessor and the tokenizer can be loaded from checkpoints
    other than ``model_path``. This is useful when the captioning checkpoint
    does not ship its own preprocessor/tokenizer files, as in the Chinese
    example at the bottom of this snippet.
    """

    def __init__(
        self,
        model_path="nlpconnect/vit-gpt2-image-captioning",
        device=None,
        overwrite_encoder_checkpoint_path=None,
        overwrite_token_model_path=None,
    ):
        """Load the model, feature extractor and tokenizer.

        Args:
            model_path: Hub id or local directory of the VisionEncoderDecoderModel.
            device: torch device (or device string) to run on. ``None`` selects
                CUDA when available, otherwise CPU.
            overwrite_encoder_checkpoint_path: checkpoint for the ViT feature
                extractor; defaults to ``model_path``.
            overwrite_token_model_path: checkpoint for the tokenizer; defaults to
                ``model_path``.

        Raises:
            TypeError: if an ``overwrite_*`` argument is not a str/path or None.
        """
        for arg_name, value in (
            ("overwrite_encoder_checkpoint_path", overwrite_encoder_checkpoint_path),
            ("overwrite_token_model_path", overwrite_token_model_path),
        ):
            # Raise instead of assert: asserts disappear under `python -O`.
            if value is not None and not isinstance(value, (str, os.PathLike)):
                raise TypeError(
                    f"{arg_name} must be a str, a path or None, got {type(value).__name__}"
                )
        if overwrite_token_model_path is None:
            overwrite_token_model_path = model_path
        if overwrite_encoder_checkpoint_path is None:
            overwrite_encoder_checkpoint_path = model_path

        # Resolve the default device at construction time, not at import time.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model = VisionEncoderDecoderModel.from_pretrained(model_path).to(self.device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(
            overwrite_encoder_checkpoint_path
        )
        self.tokenizer = AutoTokenizer.from_pretrained(overwrite_token_model_path)

    @staticmethod
    def _as_path_list(image_paths):
        """Normalise a single path/URL, or an iterable of them, to a list."""
        if isinstance(image_paths, (str, os.PathLike)):
            return [image_paths]
        # Materialise once so a generator can be iterated again later.
        return list(image_paths)

    @staticmethod
    def _load_image(image_path, timeout=30):
        """Open a local file or an http(s) URL and return an RGB PIL image."""
        if isinstance(image_path, str) and image_path.lower().startswith(("http://", "https://")):
            with requests.get(image_path, timeout=timeout) as response:
                response.raise_for_status()
                source = io.BytesIO(response.content)
        else:
            source = image_path
        # convert() returns a fully loaded copy, so the opened file can be closed here.
        with Image.open(source) as image:
            return image.convert("RGB")

    def predict_to_df(self, image_paths, max_length=128, num_beams=4):
        """Caption the images and return a DataFrame with columns ``img`` and ``caption``.

        ``max_length`` and ``num_beams`` are forwarded to :meth:`predict_step`.
        To render thumbnails in a notebook, use::

            HTML(df.to_html(escape=False, formatters={"img": path_to_image_html}))
        """
        paths = self._as_path_list(image_paths)
        captions = self.predict_step(paths, max_length=max_length, num_beams=num_beams)
        return pd.DataFrame({"img": paths, "caption": captions}, columns=["img", "caption"])

    def predict_step(self, image_paths, max_length=128, num_beams=4, timeout=30):
        """Generate one caption per image.

        Args:
            image_paths: a local path or http(s) URL, or an iterable of them.
            max_length: ``max_length`` passed to ``model.generate`` (in tokens).
                Captions longer than this are cut off.
            num_beams: beam-search width passed to ``model.generate``.
            timeout: seconds to wait for each URL download.

        Returns:
            A list of stripped caption strings, in the same order as ``image_paths``.
        """
        paths = self._as_path_list(image_paths)
        if not paths:
            # Nothing to caption; avoid calling the extractor/model with an empty batch.
            return []
        images = [self._load_image(path, timeout=timeout) for path in paths]
        pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device)
        output_ids = self.model.generate(pixel_values, max_length=max_length, num_beams=num_beams)
        captions = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [caption.strip() for caption in captions]


def path_to_image_html(path, width=None):
    """Return an HTML ``<img>`` tag that displays ``path``.

    ``path`` is HTML-escaped so it is safe inside the ``src`` attribute. Pass
    ``width`` (pixels) to shrink thumbnails. This is meant as a
    ``DataFrame.to_html`` formatter for the ``img`` column produced by
    ``Image2Caption.predict_to_df``.
    """
    width_attr = "" if width is None else f' width="{int(width)}"'
    return f'<img src="{html.escape(str(path), quote=True)}"{width_attr}>'


i2c_tiny_zh_obj = Image2Caption(
    "svjack/vit-gpt-diffusion-zh",
    overwrite_encoder_checkpoint_path="google/vit-base-patch16-224",
    overwrite_token_model_path="IDEA-CCNL/Wenzhong-GPT2-110M",
)
i2c_tiny_zh_obj.predict_step(
    ["https://datasets-server.huggingface.co/assets/poloclub/diffusiondb/--/2m_all/train/28/image/image.jpg"]
)
```
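Example output. Generation stops at the default `max_length=128`, so the caption is cut off, and the last character is an incomplete, undecodable token (`�`):

```python
['"一个年轻男人的肖像,由Greg Rutkowski创作"。Artstation上的趋势"。"《刀锋战士》的艺术作品"。高度细节化。"电影般的灯光"。超现实主义。锐利的焦点。辛烷�']
```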