J-LAB commited on
Commit
11a83f2
·
verified ·
1 Parent(s): 69b4940

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -9
app.py CHANGED
@@ -6,15 +6,24 @@ import io
6
  from PIL import Image
7
  import subprocess
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
- #
10
- model_id = 'J-LAB/Florence-vl3'
11
- model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
12
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
13
 
14
  DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2] (https://huggingface.co/microsoft/Florence-2-large)"
15
 
16
  @spaces.GPU
17
- def run_example(task_prompt, image):
18
  inputs = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")
19
  generated_ids = model.generate(
20
  input_ids=inputs["input_ids"],
@@ -32,14 +41,16 @@ def run_example(task_prompt, image):
32
  )
33
  return parsed_answer
34
 
35
- def process_image(image, task_prompt):
36
  image = Image.fromarray(image) # Convert NumPy array to PIL Image
 
 
37
  if task_prompt == 'Product Caption':
38
  task_prompt = '<MORE_DETAILED_CAPTION>'
39
  elif task_prompt == 'OCR':
40
  task_prompt = '<OCR>'
41
 
42
- results = run_example(task_prompt, image)
43
 
44
  # Remove the key and get the text value
45
  if results and task_prompt in results:
@@ -80,11 +91,16 @@ single_task_list =[
80
  'Product Caption', 'OCR'
81
  ]
82
 
 
 
 
 
83
  with gr.Blocks(css=css) as demo:
84
  gr.Markdown(DESCRIPTION)
85
  with gr.Tab(label="Product Image Select"):
86
  with gr.Row():
87
  with gr.Column():
 
88
  input_img = gr.Image(label="Input Picture")
89
  task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Product Caption")
90
  submit_btn = gr.Button(value="Submit")
@@ -108,8 +124,8 @@ with gr.Blocks(css=css) as demo:
108
  ```
109
  """)
110
 
111
- submit_btn.click(process_image, [input_img, task_prompt], [output_text])
112
 
113
  demo.load(lambda: None, inputs=None, outputs=None, js=js)
114
 
115
- demo.launch(debug=True)
 
6
  from PIL import Image
7
  import subprocess
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
+
10
+ model_ids = {
11
+ "Florence-vl2": 'J-LAB/Florence-vl2',
12
+ "Florence-vl3": 'J-LAB/Florence-vl3',
13
+ "Florence_2_F_FluxiAI_Product_Caption": 'J-LAB/Florence_2_F_FluxiAI_Product_Caption'
14
+ }
15
+
16
+ # Load model and processor based on the selected model
17
+ def load_model(model_name):
18
+ model_id = model_ids[model_name]
19
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
20
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
21
+ return model, processor
22
 
23
  DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2] (https://huggingface.co/microsoft/Florence-2-large)"
24
 
25
  @spaces.GPU
26
+ def run_example(model, processor, task_prompt, image):
27
  inputs = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")
28
  generated_ids = model.generate(
29
  input_ids=inputs["input_ids"],
 
41
  )
42
  return parsed_answer
43
 
44
+ def process_image(image, task_prompt, model_name):
45
  image = Image.fromarray(image) # Convert NumPy array to PIL Image
46
+ model, processor = load_model(model_name)
47
+
48
  if task_prompt == 'Product Caption':
49
  task_prompt = '<MORE_DETAILED_CAPTION>'
50
  elif task_prompt == 'OCR':
51
  task_prompt = '<OCR>'
52
 
53
+ results = run_example(model, processor, task_prompt, image)
54
 
55
  # Remove the key and get the text value
56
  if results and task_prompt in results:
 
91
  'Product Caption', 'OCR'
92
  ]
93
 
94
+ model_list = [
95
+ 'Florence-vl2', 'Florence-vl3', 'Florence_2_F_FluxiAI_Product_Caption'
96
+ ]
97
+
98
  with gr.Blocks(css=css) as demo:
99
  gr.Markdown(DESCRIPTION)
100
  with gr.Tab(label="Product Image Select"):
101
  with gr.Row():
102
  with gr.Column():
103
+ model_name = gr.Dropdown(choices=model_list, label="Model", value="Florence-vl3")
104
  input_img = gr.Image(label="Input Picture")
105
  task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Product Caption")
106
  submit_btn = gr.Button(value="Submit")
 
124
  ```
125
  """)
126
 
127
+ submit_btn.click(process_image, [input_img, task_prompt, model_name], [output_text])
128
 
129
  demo.load(lambda: None, inputs=None, outputs=None, js=js)
130
 
131
+ demo.launch(debug=True)