lxe commited on
Commit
d7cdbfb
1 Parent(s): 8dcad5e

Reload models better

Browse files
Files changed (1) hide show
  1. main.py +17 -5
main.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import argparse
 
3
  import torch
4
  import gradio as gr
5
  import transformers
@@ -12,6 +13,12 @@ model = None
12
  tokenizer = None
13
  peft_model = None
14
 
 
 
 
 
 
 
15
  def maybe_load_models():
16
  global model
17
  global tokenizer
@@ -29,8 +36,6 @@ def maybe_load_models():
29
  "decapoda-research/llama-7b-hf",
30
  )
31
 
32
- return model, tokenizer
33
-
34
  def reset_models():
35
  global model
36
  global tokenizer
@@ -51,7 +56,10 @@ def generate_text(
51
  max_new_tokens,
52
  progress=gr.Progress(track_tqdm=True)
53
  ):
54
- model, tokenizer = maybe_load_models()
 
 
 
55
 
56
  if model_name and model_name != "None":
57
  model = PeftModel.from_pretrained(
@@ -123,7 +131,11 @@ def tokenize_and_train(
123
  model_name,
124
  progress=gr.Progress(track_tqdm=True)
125
  ):
126
- model, tokenizer = maybe_load_models()
 
 
 
 
127
 
128
  tokenizer.pad_token_id = 0
129
 
@@ -302,7 +314,7 @@ with gr.Blocks(css="#refresh-button { max-width: 32px }") as demo:
302
 
303
  with gr.Column():
304
  model_name = gr.Textbox(
305
- lines=1, label="LoRA Model Name", value=""
306
  )
307
 
308
  with gr.Row():
 
1
  import os
2
  import argparse
3
+ import random
4
  import torch
5
  import gradio as gr
6
  import transformers
 
13
  tokenizer = None
14
  peft_model = None
15
 
16
+ def random_hyphenated_word():
17
+ word_list = ['apple', 'banana', 'cherry', 'date', 'elderberry', 'fig']
18
+ word1 = random.choice(word_list)
19
+ word2 = random.choice(word_list)
20
+ return word1 + '-' + word2
21
+
22
  def maybe_load_models():
23
  global model
24
  global tokenizer
 
36
  "decapoda-research/llama-7b-hf",
37
  )
38
 
 
 
39
  def reset_models():
40
  global model
41
  global tokenizer
 
56
  max_new_tokens,
57
  progress=gr.Progress(track_tqdm=True)
58
  ):
59
+ global model
60
+ global tokenizer
61
+
62
+ maybe_load_models()
63
 
64
  if model_name and model_name != "None":
65
  model = PeftModel.from_pretrained(
 
131
  model_name,
132
  progress=gr.Progress(track_tqdm=True)
133
  ):
134
+ global model
135
+ global tokenizer
136
+
137
+ reset_models()
138
+ maybe_load_models()
139
 
140
  tokenizer.pad_token_id = 0
141
 
 
314
 
315
  with gr.Column():
316
  model_name = gr.Textbox(
317
+ lines=1, label="LoRA Model Name", value=random_hyphenated_word()
318
  )
319
 
320
  with gr.Row():