donb-hf commited on
Commit
4038683
1 Parent(s): 640c7a9

add examples

Browse files
__pycache__/model_utils.cpython-310.pyc ADDED
Binary file (2.35 kB). View file
 
app.py CHANGED
@@ -3,6 +3,10 @@ import logging
3
  import sys
4
  from config import WEAVE_PROJECT, WANDB_API_KEY
5
  import weave
 
 
 
 
6
 
7
  weave.init(WEAVE_PROJECT)
8
 
@@ -186,9 +190,15 @@ def detect_brain_tumor_florence2(image, seg_input, debug: bool = True):
186
 
187
  return (image_with_bboxes, annotations)
188
 
 
 
 
 
 
 
189
  INTRO_TEXT="# 🔬🧠 OmniScience -- Agentic Imaging Analysis 🤖🧫"
190
 
191
- with gr.Blocks(css="style.css") as demo:
192
  gr.Markdown(INTRO_TEXT)
193
  with gr.Tab("Object Detection - Owl V2"):
194
  with gr.Row():
@@ -296,7 +306,29 @@ with gr.Blocks(css="style.css") as demo:
296
  fn=detect_brain_tumor_dino,
297
  inputs=seg_inputs,
298
  outputs=seg_outputs,
299
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  if __name__ == "__main__":
302
  demo.queue(max_size=10).launch(debug=True)
 
3
  import sys
4
  from config import WEAVE_PROJECT, WANDB_API_KEY
5
  import weave
6
+ from model_utils import get_model_summary, install_flash_attn
7
+
8
+ # Install required package
9
+ install_flash_attn()
10
 
11
  weave.init(WEAVE_PROJECT)
12
 
 
190
 
191
  return (image_with_bboxes, annotations)
192
 
193
+ def handle_model_summary(model_name):
194
+ model_summary, error_message = get_model_summary(model_name)
195
+ if error_message:
196
+ return error_message, ""
197
+ return model_summary, ""
198
+
199
  INTRO_TEXT="# 🔬🧠 OmniScience -- Agentic Imaging Analysis 🤖🧫"
200
 
201
+ with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
202
  gr.Markdown(INTRO_TEXT)
203
  with gr.Tab("Object Detection - Owl V2"):
204
  with gr.Row():
 
306
  fn=detect_brain_tumor_dino,
307
  inputs=seg_inputs,
308
  outputs=seg_outputs,
309
+ )
310
+
311
+ with gr.Tab("Model Explorer"):
312
+ gr.Markdown("## Retrieve and Display Model Architecture")
313
+ model_name_input = gr.Textbox(label="Model Name", placeholder="Enter the model name to retrieve its architecture...")
314
+ vision_examples = gr.Examples(
315
+ examples=[
316
+ ["facebook/sam-vit-huge"],
317
+ ["google/owlv2-base-patch16-ensemble"],
318
+ ["IDEA-Research/grounding-dino-base"],
319
+ ["microsoft/Florence-2-large-ft"],
320
+ ["google/paligemma-3b-mix-224"],
321
+ ["llava-hf/llava-v1.6-mistral-7b-hf"],
322
+ ["vikhyatk/moondream2"],
323
+ ["microsoft/Phi-3-vision-128k-instruct"],
324
+ ["HuggingFaceM4/idefics2-8b-chatty"]
325
+ ],
326
+ inputs=model_name_input
327
+ )
328
+ model_output = gr.Textbox(label="Model Architecture", lines=20, placeholder="Model architecture will appear here...", show_copy_button=True)
329
+ error_output = gr.Textbox(label="Error", lines=10, placeholder="Exceptions will appear here...", show_copy_button=True)
330
+ model_submit_button = gr.Button("Submit")
331
+ model_submit_button.click(fn=handle_model_summary, inputs=model_name_input, outputs=[model_output, error_output])
332
 
333
  if __name__ == "__main__":
334
  demo.queue(max_size=10).launch(debug=True)
model_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import os
3
+ import torch
4
+ from transformers import BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM, LlavaNextForConditionalGeneration, LlavaForConditionalGeneration, PaliGemmaForConditionalGeneration, Idefics2ForConditionalGeneration, Owlv2ForObjectDetection, GroundingDinoForObjectDetection, SamModel
5
+ import spaces
6
+
7
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
8
+
9
+ def install_flash_attn():
10
+ subprocess.run(
11
+ "pip install flash-attn --no-build-isolation",
12
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
13
+ shell=True,
14
+ )
15
+
16
+ ARCHITECTURE_MAP = {
17
+ "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
18
+ "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
19
+ "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
20
+ "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
21
+ "Owlv2ForObjectDetection": Owlv2ForObjectDetection,
22
+ "GroundingDinoForObjectDetection": GroundingDinoForObjectDetection,
23
+ "SamModel": SamModel,
24
+ "AutoModelForCausalLM": AutoModelForCausalLM
25
+ }
26
+
27
+
28
+ @spaces.GPU
29
+ def get_model_summary(model_name):
30
+ try:
31
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
32
+ architecture = config.architectures[0]
33
+ quantization_config = getattr(config, 'quantization_config', None)
34
+
35
+ if quantization_config:
36
+ bnb_config = BitsAndBytesConfig(
37
+ load_in_4bit=quantization_config.get('load_in_4bit', False),
38
+ load_in_8bit=quantization_config.get('load_in_8bit', False),
39
+ bnb_4bit_compute_dtype=quantization_config.get('bnb_4bit_compute_dtype', torch.float16),
40
+ bnb_4bit_quant_type=quantization_config.get('bnb_4bit_quant_type', 'nf4'),
41
+ bnb_4bit_use_double_quant=quantization_config.get('bnb_4bit_use_double_quant', False),
42
+ llm_int8_enable_fp32_cpu_offload=quantization_config.get('llm_int8_enable_fp32_cpu_offload', False),
43
+ llm_int8_has_fp16_weight=quantization_config.get('llm_int8_has_fp16_weight', False),
44
+ llm_int8_skip_modules=quantization_config.get('llm_int8_skip_modules', None),
45
+ llm_int8_threshold=quantization_config.get('llm_int8_threshold', 6.0),
46
+ )
47
+ else:
48
+ bnb_config = None
49
+
50
+ model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)
51
+ model = model_class.from_pretrained(
52
+ model_name, config=bnb_config, trust_remote_code=True
53
+ )
54
+
55
+ if model and not quantization_config:
56
+ model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
57
+
58
+ model_summary = str(model) if model else "Model architecture not found."
59
+ config_content = config.to_json_string() if config else "Configuration not found."
60
+ return f"## Model Architecture\n\n{model_summary}\n\n## Configuration\n\n{config_content}", ""
61
+ except ValueError as ve:
62
+ return "", f"ValueError: {ve}"
63
+ except EnvironmentError as ee:
64
+ return "", f"EnvironmentError: {ee}"
65
+ except Exception as e:
66
+ return "", str(e)
requirements.txt CHANGED
@@ -4,4 +4,6 @@ pillow
4
  pillow-heif
5
  weave
6
  huggingface-hub
7
- gradio
 
 
 
4
  pillow-heif
5
  weave
6
  huggingface-hub
7
+ gradio
8
+ transformers
9
+ spaces