Vokturz committed on
Commit
a0b9dac
·
1 Parent(s): 74c26d6

improved how memory is managed

Browse files
Files changed (1) hide show
  1. src/app.py +8 -2
src/app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  from utils import extract_from_url, get_model, calculate_memory
4
  import plotly.express as px
5
  import numpy as np
 
6
 
7
  st.set_page_config(page_title='Can you run it? LLM version', layout="wide", initial_sidebar_state="expanded")
8
 
@@ -64,8 +65,13 @@ if not model_name:
64
 
65
  model_name = extract_from_url(model_name)
66
  if model_name not in st.session_state:
 
 
 
 
67
  model = get_model(model_name, library="transformers", access_token=access_token)
68
- st.session_state[model_name] = (model, calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"]))
 
69
 
70
 
71
  gpu_vendor = st.sidebar.selectbox("GPU Vendor", ["NVIDIA", "AMD", "Intel"])
@@ -86,7 +92,7 @@ lora_pct = st.sidebar.slider("LoRa % trainable parameters", 0.1, 100.0, 2.0, ste
86
 
87
  st.sidebar.dataframe(gpu_spec.T)
88
 
89
- memory_table = pd.DataFrame(st.session_state[model_name][1]).set_index('dtype')
90
  memory_table['LoRA Fine-Tuning (GB)'] = (memory_table["Total Size (GB)"] +
91
  (memory_table["Parameters (Billion)"]* lora_pct/100 * (16/8)*4)) * 1.2
92
 
 
3
  from utils import extract_from_url, get_model, calculate_memory
4
  import plotly.express as px
5
  import numpy as np
6
+ import gc
7
 
8
  st.set_page_config(page_title='Can you run it? LLM version', layout="wide", initial_sidebar_state="expanded")
9
 
 
65
 
66
  model_name = extract_from_url(model_name)
67
  if model_name not in st.session_state:
68
+ if 'actual_model' in st.session_state:
69
+ del st.session_state[st.session_state['actual_model']]
70
+ del st.session_state['actual_model']
71
+ gc.collect()
72
  model = get_model(model_name, library="transformers", access_token=access_token)
73
+ st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
74
+ st.session_state['actual_model'] = model_name
75
 
76
 
77
  gpu_vendor = st.sidebar.selectbox("GPU Vendor", ["NVIDIA", "AMD", "Intel"])
 
92
 
93
  st.sidebar.dataframe(gpu_spec.T)
94
 
95
+ memory_table = pd.DataFrame(st.session_state[model_name]).set_index('dtype')
96
  memory_table['LoRA Fine-Tuning (GB)'] = (memory_table["Total Size (GB)"] +
97
  (memory_table["Parameters (Billion)"]* lora_pct/100 * (16/8)*4)) * 1.2
98