# Fluxi-IA / app.py
import os
import subprocess

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling the
# CUDA extension, and merging os.environ keeps PATH visible to pip.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)
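
# Fine-tuned Florence-2 checkpoint for product captioning (base model:
# microsoft/Florence-2-large), loaded once at startup and kept in eval mode on the GPU.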
model_id = 'J-LAB/Florence_2_B_FluxiAI_Product_Caption'
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model: [Florence-2](https://huggingface.co/microsoft/Florence-2-large)"
@spaces.GPU
def run_example(task_prompt, image):
    # Tokenize the task prompt and preprocess the image, then move both to the GPU.
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    # Keep special tokens: Florence-2's post-processing needs them to parse the task output.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    return parsed_answer
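
# post_process_generation returns a dict keyed by the task prompt, e.g.
# {'<PC>': 'A stainless-steel water bottle with a flip-top lid.'} (illustrative value only).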
def process_image(image):
    image = Image.fromarray(image)  # Gradio passes a NumPy array; convert to a PIL Image
    task_prompt = '<PC>'
    results = run_example(task_prompt, image)
    # The parsed answer is keyed by the task prompt; pull out the caption text.
    if results and task_prompt in results:
        output_text = results[task_prompt]
    else:
        output_text = ""
    # Convert newline characters to HTML line breaks for the HTML output component.
    output_text = output_text.replace("\n\n", "<br><br>").replace("\n", "<br>")
    return output_text
css = """
#output {
    height: 250px;
    overflow: auto;
    border: 1px solid #ccc;
    padding: 10px;
    background-color: rgb(31 41 55);
    color: #fff;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Product Image Select"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.HTML(label="Output Text", elem_id="output")
        gr.Markdown("""
## How to use via API
To use this model via API, you can follow the example code below:

```python
# pip install gradio_client
from gradio_client import Client, handle_file

client = Client("J-LAB/Fluxi-IA")
result = client.predict(
    image=handle_file('https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'),
    api_name="/process_image"
)
print(result)
```
""")
    submit_btn.click(process_image, [input_img], [output_text])

demo.launch(debug=True)