File size: 3,898 Bytes
ade70cf 1d51385 d69fd19 d7f29ce 2ff3a1c 39ae23a ade70cf d59f119 ade70cf 3b56f9e ca16909 beec895 d59f119 9c53151 5ae9be1 53581ac d59f119 144ba4b 53581ac d59f119 5ae9be1 1d51385 5ae9be1 2ff3a1c 84d0e49 6172e67 d59f119 dd36999 6172e67 92e51e9 53581ac 144ba4b 53581ac 6172e67 63ed30f 1d51385 6172e67 144ba4b 6172e67 8dc80bf 63ed30f 6172e67 63ed30f 53581ac 6172e67 e5ad09f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import io
import os
import subprocess

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Install flash-attn at app startup (the usual Hugging Face Spaces workaround:
# it cannot go in requirements.txt because its CUDA build step fails there).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the package skip compiling CUDA
# kernels at install time.
# Fix: merge the flag into os.environ instead of passing it as the ONLY
# environment variable — the original dict wiped PATH/HOME/proxy settings for
# the child process, which can make `pip` unresolvable.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Florence-2 fine-tune for product captioning. trust_remote_code is required
# because Florence-2 checkpoints ship their own modeling code.
model_id = 'J-LAB/Florence_2_B_FluxiAI_Product_Caption'
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Markdown header rendered at the top of the Gradio page.
DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2] (https://huggingface.co/microsoft/Florence-2-large)"
@spaces.GPU
def run_example(task_prompt, image):
    """Run one Florence-2 generation for *task_prompt* on a PIL *image*.

    Returns the processor's post-processed answer (a dict keyed by the
    task tag, e.g. '<PC>').
    """
    # Tokenize the prompt, preprocess the image, and move both to the GPU.
    model_inputs = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")

    # Deterministic beam search (no sampling), capped at 1024 new tokens.
    output_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept: Florence-2's post-processor relies on them to
    # parse the task-specific output format.
    raw_text = processor.batch_decode(output_ids, skip_special_tokens=False)[0]

    return processor.post_process_generation(
        raw_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
# Maps the dropdown's human-readable task names to the Florence-2 prompt tags
# the model expects. Unrecognized names pass through unchanged (preserves the
# original fall-through behavior).
TASK_TAGS = {
    'Product Caption': '<PC>',
    'OCR': '<OCR>',
}


def format_output_html(results, task_key):
    """Extract the text for *task_key* from a Florence-2 result dict and
    convert newlines to HTML line breaks.

    Returns "" when *results* is falsy or the key is missing.
    """
    if not results or task_key not in results:
        return ""
    # post_process_generation normally yields a str here, but coerce
    # defensively so .replace() cannot raise on an unexpected type.
    text = str(results[task_key])
    # Paragraph breaks first, then single newlines.
    return text.replace("\n\n", "<br><br>").replace("\n", "<br>")


def process_image(image, task_prompt):
    """Gradio handler: caption or OCR *image* and return HTML text.

    Parameters
    ----------
    image : numpy.ndarray or None
        Pixel data from the gr.Image component (None when nothing was
        uploaded).
    task_prompt : str
        Dropdown choice ('Product Caption' or 'OCR').

    Returns
    -------
    str
        Model output with newlines rendered as <br> tags; "" on empty input.
    """
    # Guard: submitting with no image would otherwise crash in fromarray().
    if image is None:
        return ""
    image = Image.fromarray(image)  # Convert NumPy array to PIL Image
    task_tag = TASK_TAGS.get(task_prompt, task_prompt)
    results = run_example(task_tag, image)
    return format_output_html(results, task_tag)
# CSS for the HTML output panel (elem_id="output"): scrollable dark box.
css = """
#output {
overflow: auto;
border: 1px solid #ccc;
padding: 10px;
background-color: rgb(31 41 55);
color: #fff;
}
"""
# JS injected at app load: after the Submit button is clicked, waits 500 ms
# and then grows the #output box to fit its scrollHeight.
# NOTE(review): document.querySelector('button') grabs the FIRST button on
# the page — fragile if other buttons are ever added before Submit.
js = """
function adjustHeight() {
var outputElement = document.getElementById('output');
outputElement.style.height = 'auto'; // Reset height to auto to get the actual content height
var height = outputElement.scrollHeight + 'px'; // Get the scrollHeight
outputElement.style.height = height; // Set the height
}
// Attach the adjustHeight function to the click event of the submit button
document.querySelector('button').addEventListener('click', function() {
setTimeout(adjustHeight, 500); // Adjust the height after a small delay to ensure content is loaded
});
"""
# Choices shown in the task dropdown; process_image maps them to Florence-2
# prompt tags.
single_task_list =[
'Product Caption', 'OCR'
]
# Build the Gradio UI: image + task dropdown on the left, HTML output on the
# right, followed by API usage instructions.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Product Image Select"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Product Caption")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id="output" is the hook for the css/js defined above.
                output_text = gr.HTML(label="Output Text", elem_id="output")
        # NOTE(review): the API example below omits the task_prompt argument;
        # presumably it relies on the dropdown's default — verify the
        # generated /process_image API signature.
        gr.Markdown("""
## How to use via API
To use this model via API, you can follow the example code below:

```python
!pip install gradio_client
from gradio_client import Client, handle_file

client = Client("J-LAB/Fluxi-IA")
result = client.predict(
image=handle_file('https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'),
api_name="/process_image"
)
print(result)
```
""")
        # Wire the Submit button to the inference handler.
        submit_btn.click(process_image, [input_img, task_prompt], [output_text])
    # Run the resize JS once when the page loads (no Python-side effect).
    demo.load(lambda: None, inputs=None, outputs=None, js=js)

demo.launch(debug=True)