vishalkatheriya
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,52 @@
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
import inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
# One-time session-state setup: the 'has_run' flag guards the
# run-once section of the app against Streamlit re-runs.
if "has_run" not in st.session_state:
    st.session_state["has_run"] = False
|
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
import inference
|
4 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
5 |
+
from PIL import Image
|
6 |
+
import requests
|
7 |
+
import copy
|
8 |
+
import os
|
9 |
+
from unittest.mock import patch
|
10 |
+
from transformers.dynamic_module_utils import get_imports
|
11 |
+
import torch
|
12 |
+
|
13 |
+
# Workaround for loading Florence-2 on CPU: the model's remote code
# unconditionally imports flash_attn, which is GPU-only and not installed.
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports for *filename*, dropping ``flash_attn``.

    Drop-in replacement for ``transformers.dynamic_module_utils.get_imports``
    (installed via ``unittest.mock.patch``) so the Florence-2 remote modeling
    code can be loaded on machines without the ``flash_attn`` package.
    """
    # Only the Florence-2 modeling file needs the workaround.
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # Bug fix: the original called imports.remove("flash_attn") unconditionally,
    # which raises ValueError if the entry is absent (e.g. after an upstream
    # update to the remote code). Guard the removal instead.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
|
20 |
+
|
21 |
+
# Session-state bootstrap: 'model_loaded' records whether the expensive
# model load has already happened in this Streamlit session.
if "model_loaded" not in st.session_state:
    st.session_state["model_loaded"] = False
|
24 |
+
|
25 |
+
# Load the Florence-2 processor and model into Streamlit session state.
def load_model():
    """Load the Florence-2 processor and a dynamically int8-quantized model.

    Stores the processor in ``st.session_state.processor`` and the quantized
    model in ``st.session_state.model``, then sets
    ``st.session_state.model_loaded`` so the load runs only once per session.
    """
    model_id = "microsoft/Florence-2-large"

    # Patch get_imports so the remote Florence-2 code loads without flash_attn
    # (CPU-only environment). The patch covers BOTH loads: with
    # trust_remote_code=True, either call can trigger the remote module import.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        # Bug fix: the original passed torch_dtype=torch.qint8 here.
        # AutoProcessor.from_pretrained takes no torch_dtype, and qint8 is
        # not a valid load dtype in any case — quantization happens below.
        st.session_state.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation="sdpa", trust_remote_code=True
        )

    # Dynamic int8 quantization of the Linear layers: smaller footprint and
    # faster CPU inference.
    Qmodel = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    del model  # release the full-precision weights immediately
    st.session_state.model = Qmodel
    st.session_state.model_loaded = True
    st.write("model loaded complete")
|
44 |
+
# Trigger the one-time model load, showing a spinner while it runs.
model_ready = st.session_state.model_loaded
if not model_ready:
    with st.spinner('Loading model...'):
        load_model()
|
48 |
+
|
49 |
+
|
50 |
# Guard flag: ensures the run-once section of the app executes a single
# time per Streamlit session.
if "has_run" not in st.session_state:
    st.session_state["has_run"] = False
|