Fluxi-IA / app.py
J-LAB's picture
Update app.py
11a83f2 verified
raw
history blame
4.51 kB
import functools
import html
import io
import os
import subprocess
import sys

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Best-effort install of flash-attn at startup (the result is not checked, as before).
# FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's build to skip compiling its CUDA
# extension. The variable is ADDED to the inherited environment rather than replacing
# it, so PATH/HOME/pip configuration survive. The pip of the running interpreter is
# used explicitly, so the package lands where this app imports from.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Dropdown label -> Hugging Face Hub repo id. Every checkpoint lives under the
# J-LAB organisation with a repo name equal to its label.
model_ids = {
    label: f"J-LAB/{label}"
    for label in (
        "Florence-vl2",
        "Florence-vl3",
        "Florence_2_F_FluxiAI_Product_Caption",
    )
}
# Load model and processor based on the selected model
def load_model(model_name):
model_id = model_ids[model_name]
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
return model, processor
# Markdown header shown at the top of the page. There must be no space between
# "]" and "(" or Markdown renders the link text and URL literally.
DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2](https://huggingface.co/microsoft/Florence-2-large)"
@spaces.GPU
def run_example(model, processor, task_prompt, image):
    """Generate a result for *task_prompt* on *image* and return it post-processed.

    Args:
        model: A loaded model exposing ``generate`` (see ``load_model``).
        processor: The matching processor; it encodes the prompt and image,
            decodes token ids, and post-processes the generated text.
        task_prompt: The task token passed to the processor, e.g. ``"<OCR>"``.
        image: A PIL image; its width/height are forwarded to post-processing.

    Returns:
        Whatever ``processor.post_process_generation`` returns for this task
        (presumably a dict keyed by the task token — verify against the
        processor's remote code).
    """
    # NOTE(review): @spaces.GPU presumably requests a (ZeroGPU) device for the
    # duration of this call; the tensors below are moved to "cuda" explicitly.
    encoded = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")

    # Deterministic beam search (no sampling), up to 1024 new tokens.
    decoding_options = dict(
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        pixel_values=encoded["pixel_values"],
        **decoding_options,
    )

    # Special tokens are kept in the decoded text; presumably the post-processor
    # needs them to parse the task output.
    decoded_batch = processor.batch_decode(output_ids, skip_special_tokens=False)
    raw_text = decoded_batch[0]

    size = (image.width, image.height)
    return processor.post_process_generation(raw_text, task=task_prompt, image_size=size)
# User-facing task labels (the "Task Prompt" dropdown) -> Florence-2 task tokens.
# Labels not listed here are passed through unchanged, as in the original if/elif.
_TASK_TOKENS = {
    "Product Caption": "<MORE_DETAILED_CAPTION>",
    "OCR": "<OCR>",
}


def process_image(image, task_prompt, model_name):
    """Run the selected model/task on *image* and return the result as safe HTML.

    Args:
        image: The uploaded picture. Gradio's ``gr.Image`` passes a NumPy array
            by default; a PIL image is also accepted. ``None`` means no upload.
        task_prompt: A label such as ``"Product Caption"`` or ``"OCR"``, or a raw
            task token.
        model_name: A key of ``model_ids``.

    Returns:
        The generated text, HTML-escaped, with newlines turned into ``<br>``.
        Returns ``""`` when the result has no entry for the task.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)  # NumPy array -> PIL Image

    model, processor = load_model(model_name)
    task_token = _TASK_TOKENS.get(task_prompt, task_prompt)
    results = run_example(model, processor, task_token, image)

    # Keep only the text stored under the task token.
    output_text = results.get(task_token, "") if results else ""

    # The text comes from the model (OCR reads it straight off an untrusted image)
    # and is rendered by gr.HTML, so escape it before adding our own markup.
    # Replacing every "\n" with "<br>" also turns "\n\n" into "<br><br>".
    return html.escape(str(output_text)).replace("\n", "<br>")
# Page CSS passed to gr.Blocks. The #output rule styles the element with
# id "output" (the gr.HTML result panel): scroll on overflow, a 1px grey
# border, padding, white text on a dark background.
css = """
#output {
overflow: auto;
border: 1px solid #ccc;
padding: 10px;
background-color: rgb(31 41 55);
color: #fff;
}
"""
# Frontend hook run once in the browser via ``demo.load(..., js=js)``. Gradio's
# ``js`` argument must be a single JavaScript *function* expression, so the logic
# is wrapped in an arrow function instead of being a bare script. Rather than
# hooking "the first <button> on the page" (which is the tab header, not Submit),
# it watches the #output element and resizes it whenever its content changes.
js = """
() => {
    const output = document.getElementById('output');
    if (!output) {
        return;
    }
    const adjustHeight = () => {
        output.style.height = 'auto'; // Reset height to get the actual content height
        output.style.height = output.scrollHeight + 'px';
    };
    // Only content mutations are observed (not attributes), so setting
    // style.height inside the callback does not retrigger it.
    new MutationObserver(adjustHeight).observe(output, {
        childList: true,
        subtree: true,
        characterData: true,
    });
}
"""
# Choices for the "Task Prompt" dropdown; process_image maps these labels to
# Florence-2 task tokens.
single_task_list = [
    "Product Caption",
    "OCR",
]

# Choices for the "Model" dropdown, derived from ``model_ids`` (same order)
# so the dropdown and the loader can never disagree.
model_list = list(model_ids)
# Build the Gradio UI: a header, one tab with the inputs on the left and the
# rendered HTML result on the right, and usage notes for the HTTP API.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Product Image Select"):
        with gr.Row():
            # Left column: model choice, image upload, task choice, submit.
            with gr.Column():
                model_name = gr.Dropdown(choices=model_list, label="Model", value="Florence-vl3")
                input_img = gr.Image(label="Input Picture")
                task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Product Caption")
                submit_btn = gr.Button(value="Submit")
            # Right column: the result. elem_id="output" is the id targeted by
            # the #output CSS rule and by getElementById('output') in `js`.
            with gr.Column():
                output_text = gr.HTML(label="Output Text", elem_id="output")
        gr.Markdown("""
## How to use via API
To use this model via API, you can follow the example code below:
```python
!pip install gradio_client
from gradio_client import Client, handle_file
client = Client("J-LAB/Fluxi-IA")
result = client.predict(
image=handle_file('https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'),
api_name="/process_image"
)
print(result)
```
""")
    # Inputs are passed positionally in process_image's parameter order:
    # (image, task_prompt, model_name).
    submit_btn.click(process_image, [input_img, task_prompt, model_name], [output_text])
    # Run the `js` snippet in the browser when the page loads; the Python side is a no-op.
    demo.load(lambda: None, inputs=None, outputs=None, js=js)

# Start the Gradio server.
demo.launch(debug=True)