File size: 3,898 Bytes
ade70cf
1d51385
d69fd19
d7f29ce
 
2ff3a1c
39ae23a
 
ade70cf
d59f119
 
 
ade70cf
3b56f9e
ca16909
beec895
d59f119
 
9c53151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae9be1
53581ac
d59f119
144ba4b
 
 
 
53581ac
d59f119
 
5ae9be1
 
 
1d51385
5ae9be1
 
 
 
 
2ff3a1c
84d0e49
6172e67
 
 
 
d59f119
dd36999
 
6172e67
 
 
92e51e9
 
 
 
 
 
 
 
 
 
 
 
 
53581ac
144ba4b
 
 
53581ac
6172e67
 
63ed30f
1d51385
6172e67
 
144ba4b
6172e67
 
8dc80bf
63ed30f
 
 
 
6172e67
63ed30f
 
 
 
 
 
 
 
 
 
 
 
 
53581ac
6172e67
e5ad09f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import html
import io
import os
import subprocess

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Install flash-attn at startup (best effort: the return code is not checked).
# The build flag is merged into a copy of the current environment rather than
# replacing it, so the child process still sees PATH, HOME, proxy settings, etc.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is presumably read by
# flash-attn's build script to skip compiling CUDA kernels — confirm upstream.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Hub repository id used for both the model weights and the matching processor.
model_id = 'J-LAB/Florence_2_B_FluxiAI_Product_Caption'

# Loaded once at import time. trust_remote_code=True is passed to both loaders.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
# Move the weights to the GPU and switch to evaluation mode.
model = model.to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Markdown header shown at the top of the page (rendered via gr.Markdown).
# The link is written as [text](url) with no space between "]" and "(";
# with a space Markdown does not treat it as an inline link.
DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2](https://huggingface.co/microsoft/Florence-2-large)"

@spaces.GPU
def run_example(task_prompt, image):
    """Run one task prompt against one image and return the parsed output.

    Args:
        task_prompt: Task token given to the processor as text and to the
            post-processing step as ``task`` (``process_image`` passes
            ``'<PC>'`` or ``'<OCR>'``).
        image: Image object exposing ``width`` and ``height``
            (``process_image`` passes a PIL image).

    Returns:
        The value produced by ``processor.post_process_generation``.
        ``process_image`` indexes it by task token, so it is presumably a
        dict keyed by that token.
    """
    encoded = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")

    # Deterministic beam search: 3 beams, no sampling, at most 1024 new tokens.
    generation_kwargs = dict(
        input_ids=encoded["input_ids"],
        pixel_values=encoded["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    output_ids = model.generate(**generation_kwargs)

    # Special tokens are kept in the decoded text handed to post-processing.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)
    width_height = (image.width, image.height)
    return processor.post_process_generation(
        decoded[0],
        task=task_prompt,
        image_size=width_height,
    )

def process_image(image, task_prompt):
    """Run the selected task on an uploaded image and return display-ready HTML.

    Args:
        image: NumPy array delivered by the image input, or ``None`` when the
            form is submitted without an image.
        task_prompt: Dropdown label (``'Product Caption'`` or ``'OCR'``). Any
            other value is forwarded to the model unchanged.

    Returns:
        The generated text with HTML special characters escaped and newlines
        converted to ``<br>``. An empty string is returned when no image was
        supplied or when the model result has no entry for the task.
    """
    # Without an image there is nothing to run; Image.fromarray(None) would raise.
    if image is None:
        return ""

    pil_image = Image.fromarray(image)  # Convert NumPy array to PIL Image

    # Translate the UI label into the model's task token.
    task_tokens = {'Product Caption': '<PC>', 'OCR': '<OCR>'}
    task_token = task_tokens.get(task_prompt, task_prompt)

    results = run_example(task_token, pil_image)

    # Pull the text stored under the task token, if any.
    if results and task_token in results:
        output_text = results[task_token]
    else:
        output_text = ""

    # The text is derived from a user-supplied image and is rendered by an HTML
    # component, so escape it first: characters such as "<" and "&" must show
    # up literally instead of being interpreted as markup.
    safe_text = html.escape(output_text)

    # Convert newline characters to HTML line breaks.
    return safe_text.replace("\n", "<br>")

# Stylesheet passed to gr.Blocks(css=...). It styles the element with id
# "output" (the gr.HTML result panel is created with elem_id="output"):
# scrollable overflow, a thin border, padding, dark background, white text.
css = """
  #output {
    overflow: auto; 
    border: 1px solid #ccc; 
    padding: 10px;
    background-color: rgb(31 41 55);
    color: #fff; 
  }
"""

# JavaScript handed to demo.load(..., js=js). It defines adjustHeight(), which
# resizes the #output element to its scrollHeight, and registers a click
# listener that calls adjustHeight() after a 500 ms delay.
# NOTE(review): document.querySelector('button') returns the FIRST <button> in
# the document, which is not necessarily the Submit button — confirm it matches
# the intended element, or target the button by an explicit id.
# NOTE(review): this string is a sequence of statements, not a single function
# expression — confirm the installed Gradio version accepts that form for `js`.
js = """
function adjustHeight() {
    var outputElement = document.getElementById('output');
    outputElement.style.height = 'auto';  // Reset height to auto to get the actual content height
    var height = outputElement.scrollHeight + 'px';  // Get the scrollHeight
    outputElement.style.height = height;  // Set the height
}

// Attach the adjustHeight function to the click event of the submit button
document.querySelector('button').addEventListener('click', function() {
    setTimeout(adjustHeight, 500);  // Adjust the height after a small delay to ensure content is loaded
});
"""

# Labels offered in the "Task Prompt" dropdown; process_image maps them to the
# '<PC>' and '<OCR>' task tokens.
single_task_list =[
    'Product Caption', 'OCR'
]

# Page layout: header, then a single tab with the inputs in the left column,
# the HTML result in the right column, and API usage notes underneath.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab("Product Image Select"):
        with gr.Row():
            # Left column: image upload, task selector and submit button.
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                task_prompt = gr.Dropdown(
                    label="Task Prompt",
                    choices=single_task_list,
                    value="Product Caption",
                )
                submit_btn = gr.Button("Submit")
            # Right column: result panel; its elem_id is what the css/js target.
            with gr.Column():
                output_text = gr.HTML(elem_id="output", label="Output Text")

        gr.Markdown("""
        ## How to use via API
        To use this model via API, you can follow the example code below:

        ```python
        !pip install gradio_client
        from gradio_client import Client, handle_file

        client = Client("J-LAB/Fluxi-IA")
        result = client.predict(
            image=handle_file('https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'),
            api_name="/process_image"
        )
        print(result)
        ```
        """)

        # Submit runs process_image on (image, task label) and fills the result panel.
        submit_btn.click(
            fn=process_image,
            inputs=[input_img, task_prompt],
            outputs=[output_text],
        )

    # No-op Python callback; the load event is only used to inject the js snippet.
    demo.load(fn=lambda: None, inputs=None, outputs=None, js=js)

demo.launch(debug=True)