Spaces:
Build error
Build error
AlexWortega
commited on
Commit
•
a9069a7
1
Parent(s):
445e60e
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from rudalle import get_tokenizer, get_vae
|
3 |
+
from rudalle.utils import seed_everything
|
4 |
+
|
5 |
+
import sys
|
6 |
+
from rudolph.model.utils import get_i2t_attention_mask, get_t2t_attention_mask
|
7 |
+
from rudolph.model import get_rudolph_model, ruDolphModel, FP16Module
|
8 |
+
from rudolph.pipelines import generate_codebooks, self_reranking_by_image, self_reranking_by_text, show, generate_captions, generate_texts from rudolph.pipelines import zs_clf
|
9 |
+
import gradio as gr
|
10 |
+
from rudolph import utils
|
11 |
+
device = 'cuda'
|
12 |
+
|
13 |
+
model = get_rudolph_model('350M', fp16=True, device='cuda')
|
14 |
+
tokenizer = get_tokenizer()
|
15 |
+
vae = get_vae(dwt=False).to(device)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
# Download human-readable labels for ImageNet.
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
def classify_image(inp):
|
27 |
+
print(type(inp))
|
28 |
+
inp = Image.fromarray(inp)
|
29 |
+
texts = generate_captions(inp, tokenizer, model, vae, template=template, top_k=16, captions_num=1, bs=16, top_p=0.6, seed=43, temperature=0.8)
|
30 |
+
|
31 |
+
|
32 |
+
return texts
|
33 |
+
|
34 |
+
image = gr.inputs.Image(shape=(128, 128))
|
35 |
+
label = gr.outputs.Label(num_top_classes=3)
|
36 |
+
|
37 |
+
|
38 |
+
iface = gr.Interface(fn=classify_image, inputs=image, outputs="text",examples=[
|
39 |
+
['b9c277a3.jpeg']])
|
40 |
+
iface.launch(share=True)
|