autonomous019 committed
Commit 79e0b51
1 Parent(s): 2ac8087

Create app.py

Files changed (1)
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
+ from transformers import ViTConfig, ViTForImageClassification
+ from transformers import ViTFeatureExtractor
+ from PIL import Image
+ import requests
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ from gradio.mix import Parallel
+ from transformers import ImageClassificationPipeline, PerceiverForImageClassificationConvProcessing, PerceiverFeatureExtractor
+ from transformers import VisionEncoderDecoderModel
+ from transformers import AutoTokenizer
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     LogitsProcessorList,
+     MinLengthLogitsProcessor,
+     StoppingCriteriaList,
+     MaxLengthCriteria,
+ )
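+
+ # Overview: this Gradio app classifies an uploaded image with a Perceiver IO model,
+ # captions it with the ydshieh/vit-gpt2-coco-en encoder-decoder, and expands that
+ # caption into a short story with EleutherAI/gpt-neo-125M.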
+
+ # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb
+ # option 1: load with randomly initialized weights (train from scratch)
+
+ #tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+ #model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+
+
+ config = ViTConfig(num_hidden_layers=12, hidden_size=768)
+ model = ViTForImageClassification(config)
+
+ #print(config)
+
+ feature_extractor = ViTFeatureExtractor()
+ # or, to load one that corresponds to a checkpoint on the hub:
+ #feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
+
+ # the Perceiver pipeline below (used by classify_image()) replaces the randomly initialized ViT above
+ feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-conv")
+ model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
+ # alternative checkpoints: google/vit-base-patch16-224, deepmind/vision-perceiver-conv
+ image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
+
+ def create_story(text_seed):
+     #tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     #model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+     # EleutherAI GPT-Neo (GPT-3-style architecture)
+     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
+     model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
+
+     # set pad_token_id to eos_token_id because GPT-style models do not have a PAD token
+     model.config.pad_token_id = model.config.eos_token_id
+
+     #input_prompt = "It might be possible to"
+     input_prompt = text_seed
+     input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+
+     # instantiate logits processors
+     logits_processor = LogitsProcessorList(
+         [
+             MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
+         ]
+     )
+     stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=100)])
+
+     # greedy decoding, capped at 100 tokens by the stopping criteria
+     outputs = model.greedy_search(
+         input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
+     )
+
+     result_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     return result_text
+
+
+
+
+
+
+ def self_caption(image):
+     repo_name = "ydshieh/vit-gpt2-coco-en"
+     #test_image = "cats.jpg"
+     test_image = image
+     #url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
+     #test_image = Image.open(requests.get(url, stream=True).raw)
+     #test_image.save("cats.png")
+
+     feature_extractor2 = ViTFeatureExtractor.from_pretrained(repo_name)
+     tokenizer = AutoTokenizer.from_pretrained(repo_name)
+     model2 = VisionEncoderDecoderModel.from_pretrained(repo_name)
+     pixel_values = feature_extractor2(test_image, return_tensors="pt").pixel_values
+     print("Pixel Values")
+     print(pixel_values)
+     # autoregressively generate the caption (beam search with 4 beams)
+     generated_ids = model2.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)
+
+     # decode into text
+     preds = tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     print("Predictions")
+     print(preds)
+     print("The preds type is:", type(preds))
+     pred_keys = ["Prediction"]
+     pred_value = preds
+
+     pred_dictionary = dict(zip(pred_keys, pred_value))
+     print("Pred dictionary")
+     print(pred_dictionary)
+     #return(pred_dictionary)
+     # use the generated caption as the seed text for the story generator
+     preds = ' '.join(preds)
+     story = create_story(preds)
+     story = ' '.join(story)
+     return story
+
+
+ def classify_image(image):
+     results = image_pipe(image)
+
+     print("RESULTS")
+     print(results)
+     # convert to the {label: score} format Gradio's Label output expects
+     output = {}
+     for prediction in results:
+         predicted_label = prediction['label']
+         score = prediction['score']
+         output[predicted_label] = score
+     print("OUTPUT")
+     print(output)
+     return output
+
+
+ image = gr.inputs.Image(type="pil")
+ label = gr.outputs.Label(num_top_classes=5)
+ examples = [["cats.jpg"], ["batter.jpg"], ["drinkers.jpg"]]
+ title = "Generate a Story from an Image"
+ description = "Demo for classifying images with Perceiver IO. To use it, simply upload an image and click 'Submit'; a story is generated from the image as well."
+ article = "<p style='text-align: center'></p>"
+
+ img_info1 = gr.Interface(
+     fn=classify_image,
+     inputs=image,
+     outputs=label,
+ )
+
+ img_info2 = gr.Interface(
+     fn=self_caption,
+     inputs=image,
+     #outputs=label,
+     outputs=[
+         gr.outputs.Textbox(label='Story')
+     ],
+ )
+
+ Parallel(img_info1, img_info2, inputs=image, title=title, description=description, examples=examples, enable_queue=True).launch(debug=True)
+ #Parallel(img_info1, img_info2, inputs=image, outputs=label, title=title, description=description, examples=examples, enable_queue=True).launch(debug=True)
+