p1atdev commited on
Commit
8ae3068
1 Parent(s): 63d1c50

chore: remove flag of model backend

Browse files
Files changed (2) hide show
  1. app.py +38 -23
  2. requirements.txt +3 -1
app.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
- # from optimum.onnxruntime import ORTModelForCausalLM
9
 
10
  from huggingface_hub import login
11
 
@@ -17,8 +17,14 @@ MODEL_NAME = (
17
  else "p1atdev/dart-v1-sft"
18
  )
19
  HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
 
 
 
 
 
20
 
21
  assert isinstance(MODEL_NAME, str)
 
22
 
23
  tokenizer = AutoTokenizer.from_pretrained(
24
  MODEL_NAME,
@@ -30,19 +36,19 @@ model = {
30
  MODEL_NAME,
31
  token=HF_READ_TOKEN,
32
  ),
33
- # "ort": ORTModelForCausalLM.from_pretrained(
34
- # MODEL_NAME,
35
- # ),
36
- # "ort_qantized": ORTModelForCausalLM.from_pretrained(
37
- # MODEL_NAME,
38
- # file_name="model_quantized.onnx",
39
- # ),
40
  }
41
 
42
  MODEL_BACKEND_MAP = {
43
  "Default": "default",
44
- # "ONNX (normal)": "ort",
45
- # "ONNX (quantized)": "ort_qantized",
46
  }
47
 
48
  try:
@@ -288,7 +294,7 @@ def handle_inputs(
288
  top_p: float = 1.0,
289
  top_k: int = 20,
290
  num_beams: int = 1,
291
- model_backend: str = "Default",
292
  ):
293
  """
294
  Returns:
@@ -340,7 +346,7 @@ def handle_inputs(
340
 
341
  generated_ids = generate(
342
  prompt,
343
- model_backend=model_backend,
344
  max_new_tokens=max_new_tokens,
345
  min_new_tokens=min_new_tokens,
346
  do_sample=True,
@@ -395,21 +401,30 @@ def demo():
395
  with gr.Blocks() as ui:
396
  gr.Markdown(
397
  """\
398
- # Danbooru Tags Transformer Demo """
 
 
 
 
 
 
 
 
 
399
  )
400
 
401
  with gr.Row():
402
  with gr.Column():
403
 
404
- with gr.Group(
405
- visible=False,
406
- ):
407
- model_backend_radio = gr.Radio(
408
- label="Model backend",
409
- choices=list(MODEL_BACKEND_MAP.keys()),
410
- value="Default",
411
- interactive=True,
412
- )
413
 
414
  with gr.Group():
415
  rating_dropdown = gr.Dropdown(
@@ -663,7 +678,7 @@ def demo():
663
  top_p_slider,
664
  top_k_slider,
665
  num_beams_slider,
666
- model_backend_radio,
667
  ],
668
  outputs=[
669
  output_tags_natural,
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
+ from optimum.onnxruntime import ORTModelForCausalLM
9
 
10
  from huggingface_hub import login
11
 
 
17
  else "p1atdev/dart-v1-sft"
18
  )
19
  HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
20
+ MODEL_BACKEND = (
21
+ os.environ.get("MODEL_BACKEND")
22
+ if os.environ.get("MODEL_BACKEND") is not None
23
+ else "ONNX (quantized)"
24
+ )
25
 
26
  assert isinstance(MODEL_NAME, str)
27
+ assert isinstance(MODEL_BACKEND, str)
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(
30
  MODEL_NAME,
 
36
  MODEL_NAME,
37
  token=HF_READ_TOKEN,
38
  ),
39
+ "ort": ORTModelForCausalLM.from_pretrained(
40
+ MODEL_NAME,
41
+ ),
42
+ "ort_qantized": ORTModelForCausalLM.from_pretrained(
43
+ MODEL_NAME,
44
+ file_name="model_quantized.onnx",
45
+ ),
46
  }
47
 
48
  MODEL_BACKEND_MAP = {
49
  "Default": "default",
50
+ "ONNX (normal)": "ort",
51
+ "ONNX (quantized)": "ort_qantized",
52
  }
53
 
54
  try:
 
294
  top_p: float = 1.0,
295
  top_k: int = 20,
296
  num_beams: int = 1,
297
+ # model_backend: str = "Default",
298
  ):
299
  """
300
  Returns:
 
346
 
347
  generated_ids = generate(
348
  prompt,
349
+ model_backend=MODEL_BACKEND,
350
  max_new_tokens=max_new_tokens,
351
  min_new_tokens=min_new_tokens,
352
  do_sample=True,
 
401
  with gr.Blocks() as ui:
402
  gr.Markdown(
403
  """\
404
+ # Danbooru Tags Transformer Demo
405
+
406
+ Collection: [Dart (Danbooru Tags Transformer)](https://huggingface.co/collections/p1atdev/dart-danbooru-tags-transformer-65d687604ff57dc62ae40945)
407
+
408
+ Models:
409
+
410
+ - [p1atdev/dart-v1-sft](https://huggingface.co/p1atdev/dart-v1-sft)
411
+ - [p1atdev/dart-v1-base](https://huggingface.co/p1atdev/dart-v1-base)
412
+
413
+ """
414
  )
415
 
416
  with gr.Row():
417
  with gr.Column():
418
 
419
+ # with gr.Group(
420
+ # visible=False,
421
+ # ):
422
+ # model_backend_radio = gr.Radio(
423
+ # label="Model backend",
424
+ # choices=list(MODEL_BACKEND_MAP.keys()),
425
+ # value="Default",
426
+ # interactive=True,
427
+ # )
428
 
429
  with gr.Group():
430
  rating_dropdown = gr.Dropdown(
 
678
  top_p_slider,
679
  top_k_slider,
680
  num_beams_slider,
681
+ # model_backend_radio,
682
  ],
683
  outputs=[
684
  output_tags_natural,
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  torch==2.1.0
2
- transformers==4.38.0
 
 
 
1
  torch==2.1.0
2
+ accelerate==0.26.1
3
+ transformers==4.38.0
4
+ optimum[onnxruntime]==1.17.1