Yair committed on
Commit
4034b15
1 Parent(s): 18c5ad2
Files changed (2) hide show
  1. .gitattributes +1 -25
  2. app.py +24 -4
.gitattributes CHANGED
@@ -1,27 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ftz filter=lfs diff=lfs merge=lfs -text
6
- *.gz filter=lfs diff=lfs merge=lfs -text
7
  *.h5 filter=lfs diff=lfs merge=lfs -text
8
- *.joblib filter=lfs diff=lfs merge=lfs -text
9
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
- *.model filter=lfs diff=lfs merge=lfs -text
11
- *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.onnx filter=lfs diff=lfs merge=lfs -text
13
- *.ot filter=lfs diff=lfs merge=lfs -text
14
- *.parquet filter=lfs diff=lfs merge=lfs -text
15
- *.pb filter=lfs diff=lfs merge=lfs -text
16
  *.pt filter=lfs diff=lfs merge=lfs -text
17
- *.pth filter=lfs diff=lfs merge=lfs -text
18
- *.rar filter=lfs diff=lfs merge=lfs -text
19
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
- *.tar.* filter=lfs diff=lfs merge=lfs -text
21
- *.tflite filter=lfs diff=lfs merge=lfs -text
22
- *.tgz filter=lfs diff=lfs merge=lfs -text
23
- *.wasm filter=lfs diff=lfs merge=lfs -text
24
- *.xz filter=lfs diff=lfs merge=lfs -text
25
- *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
1
  *.h5 filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
3
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,7 +1,27 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
4
 
 
 
5
 
6
+ def create_caption_transformer(img):
7
+ """
8
+ create_caption_transformer() create a caption for an image using a transformer model
9
+ that was trained on 'Flickr image dataset'
10
+ :param img: a numpy array of the image
11
+ :return: a string of the image caption
12
+ """
13
+
14
+ sample = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
15
+ caption_ids = model.generate(sample, max_length=15)[0] # TODO: take care of the caption length
16
+ caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
17
+ caption_text = caption_text.split('.')[0]
18
+ return caption_text
19
+
20
+
21
+ model = VisionEncoderDecoderModel.from_pretrained(os.getcwd() + '/transformer').to('cpu')
22
+ feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
23
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
24
+ iface = gr.Interface(fn=create_caption_transformer,
25
+ inputs="image",
26
+ outputs='text',
27
+ ).launch(share=True)