AlexWortega commited on
Commit
a9069a7
1 Parent(s): 445e60e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from rudalle import get_tokenizer, get_vae
3
+ from rudalle.utils import seed_everything
4
+
5
+ import sys
6
+ from rudolph.model.utils import get_i2t_attention_mask, get_t2t_attention_mask
7
+ from rudolph.model import get_rudolph_model, ruDolphModel, FP16Module
8
+ from rudolph.pipelines import generate_codebooks, self_reranking_by_image, self_reranking_by_text, show, generate_captions, generate_texts from rudolph.pipelines import zs_clf
9
+ import gradio as gr
10
+ from rudolph import utils
11
+ device = 'cuda'
12
+
13
+ model = get_rudolph_model('350M', fp16=True, device='cuda')
14
+ tokenizer = get_tokenizer()
15
+ vae = get_vae(dwt=False).to(device)
16
+
17
+
18
+
19
+
20
+
21
+
22
+ # Download human-readable labels for ImageNet.
23
+
24
+
25
+
26
+ def classify_image(inp):
27
+ print(type(inp))
28
+ inp = Image.fromarray(inp)
29
+ texts = generate_captions(inp, tokenizer, model, vae, template=template, top_k=16, captions_num=1, bs=16, top_p=0.6, seed=43, temperature=0.8)
30
+
31
+
32
+ return texts
33
+
34
+ image = gr.inputs.Image(shape=(128, 128))
35
+ label = gr.outputs.Label(num_top_classes=3)
36
+
37
+
38
+ iface = gr.Interface(fn=classify_image, inputs=image, outputs="text",examples=[
39
+ ['b9c277a3.jpeg']])
40
+ iface.launch(share=True)