lxe committed on
Commit ecf29d8
1 Parent(s): 2e551c8

Refactor; fix model/lora loading/reloading in inference. Fixes #10, #6
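The heart of the fix is in generate_text() in main.py: the app now remembers which LoRA adapter is currently attached (current_peft_model) and, when the selection changes, tears the model down and reloads the clean base weights before attaching the newly selected adapter, instead of wrapping an already-patched model again. A condensed sketch of that logic, using the helper functions introduced in this commit (the full function in the diff below also reloads the model and tokenizer if they are missing):

    if peft_model == 'None': peft_model = None

    if current_peft_model != peft_model:
        if current_peft_model is None:
            if model is None: load_base_model()
        else:
            reset_model()       # drop the previously patched model
            load_base_model()   # reload clean base weights
            load_tokenizer()

        current_peft_model = peft_model
        if peft_model is not None:
            load_peft_model(peft_model)  # attach the newly selected LoRA adapter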

Files changed (3)
  1. .gitignore +2 -1
  2. Inference.ipynb +174 -0
  3. main.py +90 -117
.gitignore CHANGED
@@ -6,4 +6,5 @@ lora-*
checkpoint**
minimal-llama**
upload.py
- models/
+ models/
+ .ipynb_checkpoints/
Inference.ipynb ADDED
@@ -0,0 +1,174 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "26eca0b2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "===================================BUG REPORT===================================\n",
+ "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
+ "================================================================================\n",
+ "CUDA SETUP: CUDA runtime path found: /root/miniconda3/envs/llama/lib/libcudart.so\n",
+ "CUDA SETUP: Highest compute capability among GPUs detected: 8.6\n",
+ "CUDA SETUP: Detected CUDA version 117\n",
+ "CUDA SETUP: Loading binary /root/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import transformers\n",
+ "import peft"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "3c2f7268",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a9779bdda9d54ce8adcfc3cf3c61b6ef",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/33 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model = transformers.LlamaForCausalLM.from_pretrained(\n",
+ " 'decapoda-research/llama-7b-hf', \n",
+ " load_in_8bit=True,\n",
+ " torch_dtype=torch.float16,\n",
+ " device_map='auto'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e8a19a75",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
+ "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n",
+ "The class this function is called from is 'LlamaTokenizer'.\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = transformers.LlamaTokenizer.from_pretrained('decapoda-research/llama-7b-hf')\n",
+ "tokenizer.pad_token_id = 0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "240a9c8f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = peft.PeftModel.from_pretrained(\n",
+ " model,\n",
+ " 'lora-assistant',\n",
+ " torch_dtype=torch.float16\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "4f944f46",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Human: What does the fox say?\n",
+ "Assistant: The Fox says \\\"la la la\\\"!Human: That's not what it means. It is a song by Ylvis, and they are saying that this particular animal makes noises like these words when trying to communicate with humans in\n"
+ ]
+ }
+ ],
+ "source": [
+ "inputs = tokenizer(\"Human: What does the fox say?\\nAssistant:\", return_tensors=\"pt\")\n",
+ "input_ids = inputs[\"input_ids\"].to('cuda')\n",
+ "\n",
+ "generation_config = transformers.GenerationConfig(\n",
+ " do_sample = True,\n",
+ " temperature = 0.3,\n",
+ " top_p = 0.1,\n",
+ " top_k = 50,\n",
+ " repetition_penalty = 1.5,\n",
+ " max_new_tokens = 50\n",
+ ")\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " generation_output = model.generate(\n",
+ " input_ids=input_ids,\n",
+ " attention_mask=torch.ones_like(input_ids),\n",
+ " generation_config=generation_config,\n",
+ " )\n",
+ " \n",
+ "output_text = tokenizer.decode(generation_output[0].cuda())\n",
+ "print(output_text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "5fc13b1a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5f19b3a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
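For readability, the same flow as the notebook cells above, condensed into a plain script (model name, adapter directory, and sampling settings are exactly those used in the notebook; 'lora-assistant' is an adapter directory produced by the finetuning tab):

    import torch
    import transformers
    import peft

    # Load the base model in 8-bit along with its tokenizer
    model = transformers.LlamaForCausalLM.from_pretrained(
        'decapoda-research/llama-7b-hf',
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map='auto'
    )
    tokenizer = transformers.LlamaTokenizer.from_pretrained('decapoda-research/llama-7b-hf')
    tokenizer.pad_token_id = 0

    # Attach the finetuned LoRA adapter on top of the base model
    model = peft.PeftModel.from_pretrained(model, 'lora-assistant', torch_dtype=torch.float16)

    inputs = tokenizer("Human: What does the fox say?\nAssistant:", return_tensors="pt")
    input_ids = inputs["input_ids"].to('cuda')

    generation_config = transformers.GenerationConfig(
        do_sample=True,
        temperature=0.3,
        top_p=0.1,
        top_k=50,
        repetition_penalty=1.5,
        max_new_tokens=50
    )

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            generation_config=generation_config,
        )

    print(tokenizer.decode(output[0]))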
main.py CHANGED
@@ -2,134 +2,106 @@ import os
import argparse
import random
import torch
- import gradio as gr
import transformers
- 
- from datasets import Dataset
- from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
- from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, PeftModel
+ import peft
+ import datasets
+ import gradio as gr

model = None
tokenizer = None
- peft_model = None
- 
- def random_hyphenated_word():
-     word_list = ['apple', 'banana', 'cherry', 'date', 'elderberry', 'fig']
-     word1 = random.choice(word_list)
-     word2 = random.choice(word_list)
-     return word1 + '-' + word2
+ current_peft_model = None

- def maybe_load_models():
+ def load_base_model():
    global model
-     global tokenizer
- 
-     if model is None:
-         model = LlamaForCausalLM.from_pretrained(
-             "decapoda-research/llama-7b-hf",
-             load_in_8bit=True,
-             torch_dtype=torch.float16,
-             device_map="auto",
-         )
- 
-     if tokenizer is None:
-         tokenizer = LlamaTokenizer.from_pretrained(
-             "decapoda-research/llama-7b-hf",
-         )
+     print('Loading base model...')
+     model = transformers.LlamaForCausalLM.from_pretrained(
+         'decapoda-research/llama-7b-hf',
+         load_in_8bit=True,
+         torch_dtype=torch.float16,
+         device_map='auto'
+     )

+ def load_tokenizer():
+     global tokenizer
+     print('Loading tokenizer...')
+     tokenizer = transformers.LlamaTokenizer.from_pretrained(
+         'decapoda-research/llama-7b-hf',
+     )

+ def load_peft_model(model_name):
+     global model
+     print('Loading peft model ' + model_name + '...')
+     model = peft.PeftModel.from_pretrained(
+         model, model_name,
+         torch_dtype=torch.float16
+     )

- def reset_models():
+ def reset_model():
    global model
    global tokenizer
+     global current_peft_model

    del model
    del tokenizer

    model = None
    tokenizer = None
+     current_peft_model = None

def generate_text(
-     model_name,
+     peft_model,
    text,
    temperature,
    top_p,
    top_k,
-     repeat_penalty,
+     repetition_penalty,
    max_new_tokens,
    progress=gr.Progress(track_tqdm=True)
):
    global model
    global tokenizer
+     global current_peft_model

-     maybe_load_models()
+     if (peft_model == 'None'): peft_model = None

-     tokenizer.pad_token_id = 0
- 
-     if model_name and model_name != "None":
-         model = PeftModel.from_pretrained(
-             model, model_name,
-             torch_dtype=torch.float16
-         )
+     if (current_peft_model != peft_model):
+         if (current_peft_model is None):
+             if (model is None): load_base_model()
+         else:
+             reset_model()
+             load_base_model()
+             load_tokenizer()
+ 
+         current_peft_model = peft_model
+         if (peft_model is not None):
+             load_peft_model(peft_model)
+ 
+     if (model is None): load_base_model()
+     if (tokenizer is None): load_tokenizer()
+ 
+     assert model is not None
+     assert tokenizer is not None

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)

-     # llama_config = transformers.LlamaConfig()
-     # print(llama_config)
- 
-     stopping_criteria_list = transformers.StoppingCriteriaList()
-     generation_config = GenerationConfig(
-         # Whether to use greedy decoding. If set to False,
+     generation_config = transformers.GenerationConfig(
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
        do_sample=True,
- 
-         # Controls the 'temperature' of the softmax distribution during sampling.
-         # Higher values (e.g., 1.0) make the model generate more diverse and random outputs,
-         # while lower values (e.g., 0.1) make it more deterministic and
-         # focused on the highest probability tokens.
-         temperature=temperature,
- 
-         # Sets the nucleus sampling threshold. In nucleus sampling,
-         # only the tokens whose cumulative probability exceeds 'top_p' are considered
-         # for sampling. This technique helps to reduce the number of low probability
-         # tokens considered during sampling, which can lead to more diverse and coherent outputs.
-         top_p=top_p,
- 
-         # Sets the number of top tokens to consider during sampling.
-         # In top-k sampling, only the 'top_k' tokens with the highest probabilities
-         # are considered for sampling. This method can lead to more focused and coherent
-         # outputs by reducing the impact of low probability tokens.
-         top_k=top_k,
- 
-         # Applies a penalty to the probability of tokens that have already been generated,
-         # discouraging the model from repeating the same words or phrases. The penalty is
-         # applied by dividing the token probability by a factor based on the number of times
-         # the token has appeared in the generated text.
-         repeat_penalty=repeat_penalty,
- 
-         # Limits the maximum number of tokens generated in a single iteration.
-         # This can be useful to control the length of generated text, especially in tasks
-         # like text summarization or translation, where the output should not be excessively long.
-         max_new_tokens=max_new_tokens,
- 
-         # typical_p=1,
-         # stopping_criteria=stopping_criteria_list,
-         # eos_token_id=llama_config.eos_token_id,
-         # pad_token_id=llama_config.eos_token_id
+         num_beams=1,
    )

- 
-     with torch.no_grad():
-         generation_output = model.generate(
-             input_ids=input_ids,
-             attention_mask=torch.ones_like(input_ids),
-             generation_config=generation_config,
-             # return_dict_in_generate=True,
-             # output_scores=True,
-             # eos_token_id=[tokenizer.eos_token_id],
-             use_cache=True,
-         )[0].cuda()
- 
-     output_text = tokenizer.decode(generation_output)
-     return output_text.strip()
+     output = model.generate( # type: ignore
+         input_ids=input_ids,
+         attention_mask=torch.ones_like(input_ids),
+         generation_config=generation_config
+     )[0].cuda()
+ 
+     return tokenizer.decode(output, skip_special_tokens=True).strip()

def tokenize_and_train(
    training_text,
@@ -147,8 +119,11 @@ def tokenize_and_train(
    global model
    global tokenizer

-     reset_models()
-     maybe_load_models()
+     if (model is None): load_base_model()
+     if (tokenizer is None): load_tokenizer()
+ 
+     assert model is not None
+     assert tokenizer is not None

    tokenizer.pad_token_id = 0

@@ -156,6 +131,7 @@ def tokenize_and_train(
    print("Number of samples: " + str(len(paragraphs)))

    def tokenize(item):
+         assert tokenizer is not None
        result = tokenizer(
            item["text"],
            truncation=True,
@@ -171,12 +147,12 @@ def tokenize_and_train(
        return {"text": text}

    paragraphs = [to_dict(x) for x in paragraphs]
-     data = Dataset.from_list(paragraphs)
+     data = datasets.Dataset.from_list(paragraphs)
    data = data.shuffle().map(lambda x: tokenize(x))

-     model = prepare_model_for_int8_training(model)
+     model = peft.prepare_model_for_int8_training(model)

-     model = get_peft_model(model, LoraConfig(
+     model = peft.get_peft_model(model, peft.LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=["q_proj", "v_proj"],
@@ -261,22 +237,22 @@ def tokenize_and_train(
    )

    result = trainer.train(resume_from_checkpoint=False)
- 
    model.save_pretrained(output_dir)
- 
-     reset_models()
+     reset_model()

    return result

+ def random_hyphenated_word():
+     word_list = ['apple', 'banana', 'cherry', 'date', 'elderberry', 'fig']
+     word1 = random.choice(word_list)
+     word2 = random.choice(word_list)
+     return word1 + '-' + word2

- with gr.Blocks(
-     css="#refresh-button { max-width: 32px }",
-     title="Simple LLaMA Finetuner") as demo:
- 
+ def training_tab():
    with gr.Tab("Finetuning"):

        with gr.Column():
-             training_text = gr.Textbox(lines=12, label="Training Data", info="Each sequence must be separated by a double newline")
+             training_text = gr.Textbox(lines=12, label="Training Data", info="Each sequence must be separated by 2 blank lines")

            max_seq_length = gr.Slider(
                minimum=1, maximum=4096, value=512,
@@ -363,6 +339,7 @@ with gr.Blocks(

        abort_button.click(None, None, None, cancels=[train_progress])

+ def inference_tab():
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column():
@@ -380,13 +357,13 @@ with gr.Blocks(
            with gr.Column():
                # temperature, top_p, top_k, repeat_penalty, max_new_tokens
                temperature = gr.Slider(
-                     minimum=0, maximum=1.99, value=0.7, step=0.01,
+                     minimum=0, maximum=1.99, value=0.4, step=0.01,
                    label="Temperature",
                    info="Controls the 'temperature' of the softmax distribution during sampling. Higher values (e.g., 1.0) make the model generate more diverse and random outputs, while lower values (e.g., 0.1) make it more deterministic and focused on the highest probability tokens."
                )

                top_p = gr.Slider(
-                     minimum=0, maximum=1, value=0.2, step=0.01,
+                     minimum=0, maximum=1, value=0.3, step=0.01,
                    label="Top P",
                    info="Sets the nucleus sampling threshold. In nucleus sampling, only the tokens whose cumulative probability exceeds 'top_p' are considered for sampling. This technique helps to reduce the number of low probability tokens considered during sampling, which can lead to more diverse and coherent outputs."
                )
@@ -398,7 +375,7 @@ with gr.Blocks(
                )

                repeat_penalty = gr.Slider(
-                     minimum=0, maximum=1.5, value=0.8, step=0.01,
+                     minimum=0, maximum=2.5, value=1.0, step=0.01,
                    label="Repeat Penalty",
                    info="Applies a penalty to the probability of tokens that have already been generated, discouraging the model from repeating the same words or phrases. The penalty is applied by dividing the token probability by a factor based on the number of times the token has appeared in the generated text."
                )
@@ -413,12 +390,8 @@ with gr.Blocks(
                generate_btn = gr.Button(
                    "Generate", variant="primary", label="Generate",
                )
- 
-                 inference_abort_button = gr.Button(
-                     "Abort", label="Abort",
-                 )

-                 inference_progress = generate_btn.click(
+                 generate_btn.click(
                    fn=generate_text,
                    inputs=[
                        lora_model,
@@ -432,10 +405,6 @@ with gr.Blocks(
                    outputs=inference_output,
                )

-                 lora_model.change(
-                     fn=reset_models
-                 )
- 
    def update_models_list():
        return gr.Dropdown.update(choices=["None"] + [
            d for d in os.listdir() if os.path.isdir(d) and d.startswith('lora-')
@@ -447,11 +416,15 @@ with gr.Blocks(
        outputs=lora_model,
    )

- 
+ with gr.Blocks(
+     css="#refresh-button { max-width: 32px }",
+     title="Simple LLaMA Finetuner") as demo:
+     training_tab()
+     inference_tab()

- if __name__ == "__main__":
+ if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Simple LLaMA Finetuner")
    parser.add_argument("-s", "--share", action="store_true", help="Enable sharing of the Gradio interface")
    args = parser.parse_args()

-     demo.queue().launch(share=args.share)
+     demo.queue().launch(share=args.share)