File size: 2,746 Bytes
8ec545b
 
 
 
 
 
 
 
 
2de9ebb
8ec545b
 
 
 
 
 
 
 
 
 
 
 
 
2de9ebb
 
 
 
 
8ec545b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de9ebb
7d01662
2de9ebb
7d01662
 
 
 
 
 
8ec545b
2de9ebb
8ec545b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import gradio as gr
import base64
import requests
import io
from PIL import Image
import numpy as np
# Chat-completions endpoint of the vision model. Read once at import time, so a
# missing variable raises KeyError immediately instead of on the first request.
URL: str = os.environ["URL"]

# Upper bound (seconds) on one API call, so a stalled connection cannot hang
# the Gradio event handler indefinitely.
_REQUEST_TIMEOUT_S = 60

_PROMPT = (
    "You are playing a game of pictionary. Please guess what I am trying to draw. "
    "Answer in short words only."
)


def _encode_png_base64(image_data):
    """Return *image_data* (array-like pixel values) as a base64-encoded PNG string.

    Values are cast to ``uint8`` before encoding.
    """
    pil_image = Image.fromarray(np.asarray(image_data).astype(np.uint8))
    with io.BytesIO() as buffered:
        pil_image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()


def _build_payload(img_str):
    """Return the chat-completions request body asking the model to guess the drawing."""
    return {
        "model": "Llama-3.2-11B-Vision-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": _PROMPT},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img_str}"},
                    },
                ],
            }
        ],
        "max_tokens": 300,
    }


def sketch_to_text(image, api_key):
    """Ask the vision model at ``URL`` to guess what the sketch depicts.

    Args:
        image: Value of the ``gr.ImageEditor``. Expected to be a dict whose
            ``"composite"`` entry holds the flattened drawing as array-like
            pixel data. It may be ``None`` or lack a drawing when the canvas is
            empty or cleared.
        api_key: Key typed by the user. When empty or ``None``, the
            ``API_KEY`` environment variable is used instead.

    Returns:
        One of the following strings:

        * the model's guess, on success;
        * ``"Please draw something first."``, when there is no drawing;
        * a message starting with ``"Error"``, on any failure.

        Expected failures are returned rather than raised, so they appear in
        the output textbox instead of crashing the handler.
    """
    # A cleared editor can report {"composite": None}; treat it like no drawing.
    if not isinstance(image, dict) or image.get("composite") is None:
        return "Please draw something first."

    # Prefer the user's key (ignoring stray whitespace from pasting); otherwise
    # fall back to the server's key. Report a missing key instead of raising.
    key = (api_key or "").strip() or os.environ.get("API_KEY")
    if not key:
        return "Error: no API key provided and the API_KEY environment variable is not set."

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {key}",
    }
    payload = _build_payload(_encode_png_base64(image["composite"]))

    try:
        response = requests.post(
            URL, headers=headers, json=payload, timeout=_REQUEST_TIMEOUT_S
        )
    except requests.RequestException as exc:
        # Connection failure or timeout: surface it in the UI.
        return f"Error: request failed: {exc}"

    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    try:
        return response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # HTTP 200 but a non-JSON body or an unexpected response schema.
        return f"Error: unexpected response format: {response.text}"

# Create the Gradio interface: an optional API-key field on top, then the
# model's guess beside the drawing canvas.
with gr.Blocks() as iface:
    gr.Markdown("# Pictionary with Llama3.2 Instruct")
    gr.Markdown("Draw something and let Llama3.2 guess it! [Powered by SambaNova Cloud, Get Your API Key Here](https://cloud.sambanova.ai/apis)")
    # gr.Row has no `scale` parameter (scale sizes Columns/components placed
    # inside a Row), so the rows are created without it.
    with gr.Row():
        api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")
    with gr.Row():
        with gr.Column(scale=1):
            output = gr.Textbox(label="Description", lines=5)

        with gr.Column(scale=1):
            # type="numpy" is the ImageEditor default; stated explicitly because
            # sketch_to_text expects array pixel data in the "composite" entry.
            input_image = gr.ImageEditor(type="numpy")

    # Fires whenever the editor's value changes (presumably on every stroke),
    # so each change sends a new request to the model.
    input_image.change(fn=sketch_to_text, inputs=[input_image, api_key], outputs=output)

    gr.Markdown("How to use: 1. Draw your sketch in the box above. 2. See guessing in real time. Have fun sketching!")

# Launch the app
iface.launch()