matjarm committed on
Commit
b90a4c8
·
1 Parent(s): 043d7db
Files changed (2)
  1. main.py +127 -0
  2. requirement.txt → requirements.txt +0 -0
main.py ADDED
@@ -0,0 +1,127 @@
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
+
+ # Model 1: ViT encoder + GPT-2 decoder captioner.
+ model1 = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ feature_extractor1 = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ tokenizer1 = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+ device1 = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model1.to(device1)
+
+ # Shared generation settings for model 1.
+ max_length = 16
+ num_beams = 4
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+ def image_to_text_model_1(image_url):
+     # Download the image and force RGB (handles grayscale/RGBA inputs).
+     raw_image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
+
+     pixel_values = feature_extractor1(images=[raw_image], return_tensors="pt").pixel_values
+     pixel_values = pixel_values.to(device1)
+
+     output_ids = model1.generate(pixel_values, **gen_kwargs)
+
+     preds = tokenizer1.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     return preds
+
+ def bytes_to_text_model_1(bts):
+     # `bts` is an in-memory RGB image (e.g. a numpy array), not a URL.
+     pixel_values = feature_extractor1(images=[bts], return_tensors="pt").pixel_values
+     pixel_values = pixel_values.to(device1)
+
+     output_ids = model1.generate(pixel_values, **gen_kwargs)
+
+     preds = tokenizer1.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     print(preds[0])
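+
+ # Example (hypothetical URL, for illustration only):
+ #   image_to_text_model_1("https://example.com/photo.jpg")
+ #   returns a one-element list containing the generated caption.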
+
+ # Model 2: BLIP captioner fine-tuned on FuseCap captions.
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+
+ device2 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ processor2 = BlipProcessor.from_pretrained("noamrot/FuseCap")
+ model2 = BlipForConditionalGeneration.from_pretrained("noamrot/FuseCap").to(device2)
+
+ def image_to_text_model_2(img_url):
+     raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
+     # Prompted captioning: the text becomes the prefix of the caption.
+     text = "a picture of "
+     inputs = processor2(raw_image, text, return_tensors="pt").to(device2)
+
+     out = model2.generate(**inputs, num_beams=3)
+     print(processor2.decode(out[0], skip_special_tokens=True))
+
+ def bytes_to_text_model_2(byts):
+     text = "a picture of "
+     inputs = processor2(byts, text, return_tensors="pt").to(device2)
+
+     out = model2.generate(**inputs, num_beams=3)
+     print(processor2.decode(out[0], skip_special_tokens=True))
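+
+ # Example (hypothetical URL, for illustration only):
+ #   image_to_text_model_2("https://example.com/photo.jpg")
+ #   prints a caption that begins with "a picture of".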
+
+ # Model 3: stock BLIP large captioner (runs on CPU; it is never moved to a GPU).
+ processor3 = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ model3 = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+ def image_to_text_model_3(img_url):
+     raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
+     # Unconditional captioning: no text prompt is passed.
+     inputs = processor3(raw_image, return_tensors="pt")
+
+     out = model3.generate(**inputs)
+     print(processor3.decode(out[0], skip_special_tokens=True))
+
+ def bytes_to_text_model_3(byts):
+     inputs = processor3(byts, return_tensors="pt")
+
+     out = model3.generate(**inputs)
+     print(processor3.decode(out[0], skip_special_tokens=True))
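+
+ # Note: BLIP also supports prompted captioning (pass a text prefix to the
+ # processor, as model 2 does); model 3 captions unconditionally.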
+
+ import cv2
+
+ def FrameCapture(path):
+     # Caption every 20th frame of the video with all three models.
+     vidObj = cv2.VideoCapture(path)
+     count = 0
+
+     while True:
+         success, image = vidObj.read()
+         if not success:
+             break
+
+         if count % 20 == 0:
+             # OpenCV decodes frames as BGR; convert to RGB for the captioners.
+             rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+             print("NEW FRAME")
+             print("MODEL 1")
+             bytes_to_text_model_1(rgb)
+             print("MODEL 2")
+             bytes_to_text_model_2(rgb)
+             print("MODEL 3")
+             bytes_to_text_model_3(rgb)
+             print("\n\n")
+
+         count += 1
+
+     vidObj.release()
+
+ FrameCapture("animation.mp4")
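+
+ # Note: sampling every 20th frame means roughly 1.5 captioned frames per
+ # second of a 30 fps video.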
requirement.txt → requirements.txt RENAMED
File without changes
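The renamed file's contents are not shown in this diff; judging from the imports in main.py, a minimal requirements.txt would presumably read:
requests
transformers
torch
pillow
opencv-python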