File size: 3,658 Bytes
fef0a8d
99eb93c
fef0a8d
 
 
0c034e2
45099c6
fef0a8d
 
 
 
 
 
 
 
 
5268082
45099c6
5268082
fef0a8d
 
 
96f2f76
 
 
 
 
 
 
 
fef0a8d
45099c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5268082
e7322bf
3538bec
baefccb
 
e7322bf
 
32f0fe9
 
 
 
45099c6
 
32f0fe9
5268082
 
 
 
 
 
 
45099c6
 
5268082
 
8a86647
32f0fe9
8a86647
32f0fe9
5268082
 
 
 
 
 
 
 
201c325
5268082
 
 
32f0fe9
5268082
 
eb3d2f3
5268082
 
eb3d2f3
5268082
 
 
 
 
 
 
e7322bf
5268082
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Imports
import gradio as gr
import spaces
import torch

from PIL import Image
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer

# Pre-Initialize
# "auto" picks CUDA when torch reports it available, otherwise CPU;
# replace with an explicit device string to override.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print("[SYSTEM] | Using " + DEVICE + " type compute device.")

# Variables
DEFAULT_INPUT = "Describe in one paragraph."  # default prompt for the Instruction box
MAX_FRAMES = 64  # upper bound on video frames passed to the model

# Model weights and tokenizer come from the same Hub repository; the model is
# loaded in bfloat16 and executes the repository's custom code.
MODEL_NAME = "openbmb/MiniCPM-V-2_6"
repo = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Page styling: cap the container width at 560px, center h1, hide the footer.
css = """
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
    visibility: hidden
}
"""

# Functions
def encode_video(video_path):
    """Decode a video into PIL frames sampled at roughly one frame per second.

    Frame indices are stepped by the video's average FPS (rounded, at least 1),
    which yields about one frame per second of footage. If that produces more
    than ``MAX_FRAMES`` indices, they are thinned to ``MAX_FRAMES`` evenly
    spaced ones.

    Args:
        video_path: Path to a video file readable by decord's ``VideoReader``.

    Returns:
        list[PIL.Image.Image]: The sampled frames in temporal order.

    Raises:
        ValueError: If the video contains no frames.
    """
    def uniform_sample(items, count):
        # Pick `count` evenly spaced elements, taking the middle of each bucket.
        gap = len(items) / count
        return [items[int(i * gap + gap / 2)] for i in range(count)]

    vr = VideoReader(video_path, ctx=cpu(0))
    # One second's worth of frames per step. Clamp to 1 so that clips whose
    # average FPS rounds to 0 do not make range() raise on a zero step.
    step = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), step))
    if not frame_idx:
        raise ValueError("Video contains no frames.")
    # Cap the number of frames handed to the model.
    if len(frame_idx) > MAX_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_FRAMES)
    frames = vr.get_batch(frame_idx).asnumpy()
    return [Image.fromarray(frame.astype('uint8')) for frame in frames]
    
@spaces.GPU(duration=60)
def generate(image, video, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
    """Run MiniCPM-V on an image or a video and return the model's reply.

    When ``video`` is set, its sampled frames (from ``encode_video``) are sent
    and ``image`` is ignored; otherwise ``image`` is used.

    Args:
        image: Value of the ``gr.Image`` component. It is presumably an HxWx3
            numpy array (Gradio's default ``type="numpy"``), or None when
            empty — confirm if the component type changes.
        video: Value of the ``gr.Video`` component (a file path), or None.
        instruction: Text prompt placed after the visual input.
        sampling, temperature, top_p, top_k, repetition_penalty: Decoding
            options forwarded to ``repo.chat``.
        max_tokens: Forwarded to ``repo.chat`` as ``max_new_tokens``.

    Returns:
        Whatever ``repo.chat`` returns; this app displays it as text.

    Raises:
        gr.Error: If neither an image nor a video was provided.
    """
    # Fail with a readable UI message instead of an AttributeError on None.
    if not video and image is None:
        raise gr.Error("Please provide an image or a video.")

    # NOTE(review): presumably the model is moved here, rather than at load
    # time, because the GPU is only attached inside the @spaces.GPU call.
    repo.to(DEVICE)

    # Log a compact summary rather than the full pixel array.
    print(f"[GENERATE] | image={getattr(image, 'shape', None)} video={video!r}")

    if video:
        content = encode_video(video) + [instruction]
    else:
        content = [Image.fromarray(image.astype('uint8'), 'RGB'), instruction]
    inputs = [{"role": "user", "content": content}]

    parameters = {
        "sampling": sampling,
        "temperature": temperature,
        "top_p": top_p,
        # Slider values may arrive as floats; token counts and top-k must be ints.
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": int(max_tokens),
        "use_image_id": False,
        "max_slice_nums": 2,
    }

    # Visual inputs travel inside `msgs`, so the separate `image` argument is None.
    output = repo.chat(image=None, msgs=inputs, tokenizer=tokenizer, **parameters)

    print(output)

    return output
    
def cloud():
    """Print a maintenance heartbeat line; bound to the ☁️ button in the UI."""
    message = "[CLOUD] | Space maintained."
    print(message)

# Initialize
# Build the UI: inputs and decoding controls in one column, output below,
# with the ▶ button wired to generate() and the ☁️ button to cloud().
with gr.Blocks(css=css) as main:
    with gr.Column():
        gr.Markdown("🪄 Analyze images and caption them using state-of-the-art openbmb/MiniCPM-V-2_6.")

    with gr.Column():
        # Named image_input/video_input so the builtin `input` is not shadowed.
        image_input = gr.Image(label="Image")
        video_input = gr.Video(label="Video")
        instruction = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Instruction")
        sampling = gr.Checkbox(value=False, label="Sampling")
        temperature = gr.Slider(minimum=0.01, maximum=1.99, step=0.01, value=0.7, label="Temperature")
        top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label="Top P")
        top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="Top K")
        repetition_penalty = gr.Slider(minimum=0.01, maximum=1.99, step=0.01, value=1.05, label="Repetition Penalty")
        max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    with gr.Column():
        output = gr.Textbox(lines=1, value="", label="Output")

    # The order of `inputs` must match generate()'s positional parameters.
    submit.click(fn=generate, inputs=[image_input, video_input, instruction, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens], outputs=[output], queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)