wenjiao commited on
Commit
625f2d5
1 Parent(s): 090da45

Update src/display/utils.py

Browse files
Files changed (1) hide show
  1. src/display/utils.py +7 -6
src/display/utils.py CHANGED
@@ -240,16 +240,15 @@ class WeightDtype(Enum):
240
  int2 = ModelDetails("int2")
241
  int3 = ModelDetails("int3")
242
  int4 = ModelDetails("int4")
 
243
  nf4 = ModelDetails("nf4")
244
  fp4 = ModelDetails("fp4")
245
- fp16 = ModelDetails("float16")
246
  bf16 = ModelDetails("bfloat16")
247
- fp32 = ModelDetails("float32")
248
 
249
  Unknown = ModelDetails("?")
250
 
251
-
252
-
253
  def from_str(weight_dtype):
254
  if weight_dtype in ["int2"]:
255
  return WeightDtype.int2
@@ -257,6 +256,8 @@ class WeightDtype(Enum):
257
  return WeightDtype.int3
258
  if weight_dtype in ["int4"]:
259
  return WeightDtype.int4
 
 
260
  if weight_dtype in ["nf4"]:
261
  return WeightDtype.nf4
262
  if weight_dtype in ["fp4"]:
@@ -264,11 +265,11 @@ class WeightDtype(Enum):
264
  if weight_dtype in ["All"]:
265
  return WeightDtype.all
266
  if weight_dtype in ["float16"]:
267
- return WeightDtype.fp16
268
  if weight_dtype in ["bfloat16"]:
269
  return WeightDtype.bf16
270
  if weight_dtype in ["float32"]:
271
- return WeightDtype.fp32
272
  return WeightDtype.Unknown
273
 
274
  class ComputeDtype(Enum):
 
240
  int2 = ModelDetails("int2")
241
  int3 = ModelDetails("int3")
242
  int4 = ModelDetails("int4")
243
+ int8 = ModelDetails("int8")
244
  nf4 = ModelDetails("nf4")
245
  fp4 = ModelDetails("fp4")
246
+ f16 = ModelDetails("float16")
247
  bf16 = ModelDetails("bfloat16")
248
+ f32 = ModelDetails("float32")
249
 
250
  Unknown = ModelDetails("?")
251
 
 
 
252
  def from_str(weight_dtype):
253
  if weight_dtype in ["int2"]:
254
  return WeightDtype.int2
 
256
  return WeightDtype.int3
257
  if weight_dtype in ["int4"]:
258
  return WeightDtype.int4
259
+ if weight_dtype in ["int8"]:
260
+ return WeightDtype.int8
261
  if weight_dtype in ["nf4"]:
262
  return WeightDtype.nf4
263
  if weight_dtype in ["fp4"]:
 
265
  if weight_dtype in ["All"]:
266
  return WeightDtype.all
267
  if weight_dtype in ["float16"]:
268
+ return WeightDtype.f16
269
  if weight_dtype in ["bfloat16"]:
270
  return WeightDtype.bf16
271
  if weight_dtype in ["float32"]:
272
+ return WeightDtype.f32
273
  return WeightDtype.Unknown
274
 
275
  class ComputeDtype(Enum):