File size: 4,216 Bytes
76578bc
70a6a62
76578bc
70a6a62
97eae29
76578bc
70a6a62
 
 
 
 
76578bc
70a6a62
76578bc
 
 
 
 
 
70a6a62
 
 
 
 
 
 
 
 
 
 
 
 
76578bc
70a6a62
 
 
 
 
 
 
 
76578bc
70a6a62
 
 
 
6e0c709
70a6a62
6e0c709
76578bc
70a6a62
 
 
 
 
 
 
 
 
 
 
 
 
97eae29
 
70a6a62
97eae29
70a6a62
 
 
 
 
 
6e0c709
76578bc
 
70a6a62
76578bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6978e20
f651569
6978e20
f651569
3388a44
76578bc
 
 
6e0c709
 
 
76578bc
 
 
 
97eae29
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import requests
import json
import os
import time

# Credential sent as the X-API-Key header to the remote inference endpoint.
# Raising at module import time surfaces a missing/empty key immediately
# instead of on the first user request.
API_KEY = os.environ.get("API_KEY")
if API_KEY is None or API_KEY == "":
    raise ValueError("API_KEY environment variable must be set")

def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Process image with streaming response via HTTP
    """
    if not image_path:
        yield "Please upload an image first."
        return
    
    try:
        # Read and prepare image file
        with open(image_path, 'rb') as img_file:
            files = {
                'image': ('image.jpg', img_file, 'image/jpeg')
            }
            data = {
                'prompt': prompt,
                'task': 'instruct',
                'max_tokens': max_tokens
            }
            headers = {
                'X-API-Key': API_KEY
            }
            
            # Make streaming request
            response = requests.post(
                'https://nexa-omni.nexa4ai.com/process-image/',
                files=files,
                data=data,
                headers=headers,
                stream=True
            )
            
            if response.status_code != 200:
                yield f"Error: Server returned status code {response.status_code}"
                return

            # Initialize response and token counter
            response_text = ""
            token_count = 0
            
            # Process the streaming response
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        try:
                            data = json.loads(line[6:])  # Skip 'data: ' prefix
                            if data["status"] == "generating":
                                # Skip first three tokens if they match specific patterns
                                if token_count < 3 and data["token"] in [" ", " \n", "\n", "<|im_start|>", "assistant"]:
                                    token_count += 1
                                    continue
                                response_text += data["token"]
                                # Add explicit Gradio update
                                gr.update(value=response_text)
                                yield response_text
                                time.sleep(0.01)  # Small delay to ensure updates are visible
                            elif data["status"] == "complete":
                                break
                            elif data["status"] == "error":
                                yield f"Error: {data['error']}"
                                break
                        except json.JSONDecodeError:
                            continue
                
    except Exception as e:
        yield f"Error processing request: {str(e)}"

# Create Gradio interface. The order of `inputs` matches the positional
# parameters of process_image_stream (image_path, prompt, max_tokens), and
# each `examples` row follows the same order.
demo = gr.Interface(
    fn=process_image_stream,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(
            label="Question", 
            placeholder="Ask a question about the image...",
            value="Describe this image"
        ),
        # The UI caps generation at 50-200 tokens.
        gr.Slider(
            minimum=50,
            maximum=200,
            value=200,
            step=1,
            label="Max Tokens"
        )
    ],
    outputs=gr.Textbox(label="Response", interactive=False),
    title="NEXA OmniVLM-968M",
    # Plain string (no interpolation needed); the "*...*" emphasis is closed so
    # the note does not render with a stray asterisk.
    description="""
    Model Repo: <a href="https://huggingface.co/NexaAIDev/OmniVLM-968M">NexaAIDev/OmniVLM-968M</a>

    *Model updated on Nov 21, 2024*\n
    Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
    """,
    examples=[
        ["example_images/example_1.jpg", "What kind of cat is this?", 128],
        ["example_images/example_2.jpg", "What color is this dress? ", 128],
        ["example_images/example_3.jpg", "What is this image about?", 128],
    ]
)

if __name__ == "__main__":
    # Turn on request queuing (the queue holds at most 20 events) so the
    # streaming generator can serve several users. Then start the server on
    # every network interface on port 7860, with a public share link and a
    # thread limit of 1.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        max_threads=1,
    )