wenjiao commited on
Commit
d57a7e2
1 Parent(s): a623a77

add dynamic and no quantization

Browse files
Files changed (1) hide show
  1. app.py +74 -3
app.py CHANGED
@@ -49,6 +49,59 @@ from src.tools.plots import (
49
  # Start ephemeral Spaces on PRs (see config in README.md)
50
  #enable_space_ci()
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def restart_space():
53
  API.restart_space(repo_id=REPO_ID, token=H4_TOKEN)
54
 
@@ -115,9 +168,15 @@ def update_table(
115
  double_quant: str,
116
  group_dtype: str
117
  ):
 
 
 
118
 
119
- if weight_dtype == 'All':
120
- weight_dtype = ['int2', 'int3', 'int4', 'nf4', 'fp4']
 
 
 
121
  else:
122
  weight_dtype = [weight_dtype]
123
 
@@ -285,7 +344,7 @@ with demo:
285
  filter_columns_type = gr.CheckboxGroup(
286
  label="Quantization types",
287
  choices=[t.to_str() for t in QuantType],
288
- value=[t.to_str() for t in QuantType],
289
  interactive=True,
290
  elem_id="filter-columns-type",
291
  )
@@ -378,6 +437,18 @@ with demo:
378
  demo.load(load_query, inputs=[], outputs=[search_bar, hidden_search_bar])
379
 
380
  """
 
 
 
 
 
 
 
 
 
 
 
 
381
  for selector in [shown_columns, filter_columns_type, filter_columns_precision, filter_columns_size, filter_columns_parameters, hide_models, filter_columns_computeDtype, filter_columns_weightDtype, filter_columns_doubleQuant, filter_columns_groupDtype]:
382
  selector.change(
383
  update_table,
 
49
  # Start ephemeral Spaces on PRs (see config in README.md)
50
  #enable_space_ci()
51
 
52
+ precision_to_dtype = {
53
+ "2bit": ["int2"],
54
+ "3bit": ["int3"],
55
+ "4bit": ["int4", "nf4", "fp4"],
56
+ "?": ["?"]
57
+ }
58
+
59
+ current_weightDtype = ["All", "int2", "int3", "int4", "nf4", "fp4", "?"]
60
+
61
+ # Global variable to store the selected dtypes
62
+ selected_dtypes = ["All"]
63
+ init_select = False
64
+
65
+ def quant_update_Weight_Dtype(selected_precisions):
66
+ global current_weightDtype
67
+ if '✖ None' in selected_precisions:
68
+ if not any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in current_weightDtype):
69
+ current_weightDtype += ['float16', 'bfloat16', 'float32']
70
+ else:
71
+ if any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in current_weightDtype):
72
+ current_weightDtype = [dtype for dtype in current_weightDtype if dtype not in ['float16', 'bfloat16', 'float32']]
73
+ return gr.Dropdown.update(choices=current_weightDtype, value="All")
74
+
75
+
76
+ def update_Weight_Dtype(selected_precisions):
77
+ global selected_dtypes
78
+ global current_weightDtype
79
+ global init_select
80
+ init_select = True
81
+
82
+ if not selected_precisions: # If no precision is selected, return "All"
83
+ selected_dtypes = ["All"]
84
+ return gr.Dropdown.update(choices=["All"], value="All")
85
+
86
+ selected_dtypes_set = set()
87
+ for precision in selected_precisions:
88
+ if precision in precision_to_dtype:
89
+ selected_dtypes_set.update(precision_to_dtype[precision])
90
+
91
+
92
+ # Convert set to sorted list to maintain order
93
+ selected_dtypes = sorted(selected_dtypes_set)
94
+ if any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in current_weightDtype) and not any(dtype in ['float16', 'bfloat16', 'float32'] for dtype in selected_dtypes):
95
+ selected_dtypes += ['float16', 'bfloat16', 'float32']
96
+ # Add "All" to the beginning of the list for display purposes
97
+ display_choices = ["All"] + selected_dtypes
98
+
99
+
100
+ current_weightDtype = display_choices
101
+ return gr.Dropdown.update(choices=display_choices, value="All")
102
+
103
+
104
+
105
  def restart_space():
106
  API.restart_space(repo_id=REPO_ID, token=H4_TOKEN)
107
 
 
168
  double_quant: str,
169
  group_dtype: str
170
  ):
171
+ global init_select
172
+ global current_weightDtype
173
+
174
 
175
+ if selected_dtypes == ['All']:
176
+ weight_dtype = current_weightDtype
177
+ elif weight_dtype == ['All'] or weight_dtype == 'All' or init_select:
178
+ weight_dtype = selected_dtypes
179
+ init_select = False
180
  else:
181
  weight_dtype = [weight_dtype]
182
 
 
344
  filter_columns_type = gr.CheckboxGroup(
345
  label="Quantization types",
346
  choices=[t.to_str() for t in QuantType],
347
+ value=[t.to_str() for t in QuantType if t != QuantType.QuantType_None],
348
  interactive=True,
349
  elem_id="filter-columns-type",
350
  )
 
437
  demo.load(load_query, inputs=[], outputs=[search_bar, hidden_search_bar])
438
 
439
  """
440
+ filter_columns_precision.change(
441
+ update_Weight_Dtype,
442
+ [filter_columns_precision],
443
+ [filter_columns_weightDtype]
444
+ )
445
+
446
+ filter_columns_type.change(
447
+ quant_update_Weight_Dtype,
448
+ [filter_columns_type],
449
+ [filter_columns_weightDtype]
450
+ )
451
+
452
  for selector in [shown_columns, filter_columns_type, filter_columns_precision, filter_columns_size, filter_columns_parameters, hide_models, filter_columns_computeDtype, filter_columns_weightDtype, filter_columns_doubleQuant, filter_columns_groupDtype]:
453
  selector.change(
454
  update_table,