svjack commited on
Commit
9a85fd1
1 Parent(s): 7d9efca

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +91 -0
README.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - zh
4
+ pipeline_tag: image-to-text
5
+ tags:
6
+ - vit
7
+ - gpt
8
+ ---
9
+ ```python
10
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
11
+ import torch
12
+ from PIL import Image
13
+ import pathlib
14
+ import pandas as pd
15
+ import numpy as np
16
+ from IPython.core.display import HTML
17
+ import os
18
+ import requests
19
+
20
+ class Image2Caption(object):
21
+ def __init__(self ,model_path = "nlpconnect/vit-gpt2-image-captioning",
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
23
+ overwrite_encoder_checkpoint_path = None,
24
+ overwrite_token_model_path = None
25
+ ):
26
+ assert type(overwrite_token_model_path) == type("") or overwrite_token_model_path is None
27
+ assert type(overwrite_encoder_checkpoint_path) == type("") or overwrite_encoder_checkpoint_path is None
28
+ if overwrite_token_model_path is None:
29
+ overwrite_token_model_path = model_path
30
+ if overwrite_encoder_checkpoint_path is None:
31
+ overwrite_encoder_checkpoint_path = model_path
32
+ self.device = device
33
+ self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
34
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(overwrite_encoder_checkpoint_path)
35
+ self.tokenizer = AutoTokenizer.from_pretrained(overwrite_token_model_path)
36
+ self.model = self.model.to(self.device)
37
+
38
+ def predict_to_df(self, image_paths):
39
+ img_caption_pred = self.predict_step(image_paths)
40
+ img_cation_df = pd.DataFrame(list(zip(image_paths, img_caption_pred)))
41
+ img_cation_df.columns = ["img", "caption"]
42
+ return img_cation_df
43
+ #img_cation_df.to_html(escape=False, formatters=dict(Country=path_to_image_html))
44
+
45
+ def predict_step(self ,image_paths, max_length = 128, num_beams = 4):
46
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
47
+ images = []
48
+ for image_path in image_paths:
49
+ #i_image = Image.open(image_path)
50
+ if image_path.startswith("http"):
51
+ i_image = Image.open(
52
+ requests.get(image_path, stream=True).raw
53
+ )
54
+ else:
55
+ i_image = Image.open(image_path)
56
+
57
+ if i_image.mode != "RGB":
58
+ i_image = i_image.convert(mode="RGB")
59
+ images.append(i_image)
60
+
61
+ pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
62
+ pixel_values = pixel_values.to(self.device)
63
+
64
+ output_ids = self.model.generate(pixel_values, **gen_kwargs)
65
+
66
+ preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
67
+ preds = [pred.strip() for pred in preds]
68
+ return preds
69
+
70
+ def path_to_image_html(path):
71
+ return '<img src="'+ path + '" width="60" >'
72
+
73
+ i2c_tiny_zh_obj = Image2Caption("svjack/vit-gpt-diffusion-zh",
74
+ overwrite_encoder_checkpoint_path = "google/vit-base-patch16-224",
75
+ overwrite_token_model_path = "IDEA-CCNL/Wenzhong-GPT2-110M"
76
+ )
77
+
78
+ i2c_tiny_zh_obj.predict_step(
79
+ ["https://datasets-server.huggingface.co/assets/poloclub/diffusiondb/--/2m_all/train/28/image/image.jpg"]
80
+ )
81
+ ```
82
+
83
+ </br>
84
+
85
+ <div><img src='https://datasets-server.huggingface.co/assets/poloclub/diffusiondb/--/2m_all/train/28/image/image.jpg' width="550" height="450" /></div>
86
+
87
+ </br>
88
+
89
+ ```json
90
+ ['"一个年轻男人的肖像,由Greg Rutkowski创作"。Artstation上的趋势"。"《刀锋战士》的艺术作品"。高度细节化。"电影般的灯光"。超现实主义。锐利的焦点。辛烷�']
91
+ ```