arampacha committed on
Commit
6b904c9
1 Parent(s): 7e44eab

reorganize model loading

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -4,14 +4,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  # model_name = "flax-community/gpt-code-clippy-1.3B-apps-alldata"
5
  model_name = "flax-community/gpt-code-clippy-125M-apps-alldata"
6
 
7
- @st.cache(allow_output_mutation=True)
8
  def get_model():
9
- return AutoModelForCausalLM.from_pretrained(model_name)
10
-
11
- def get_tokenizer():
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  tokenizer.pad_token = tokenizer.eos_token
14
- return tokenizer
 
15
 
16
  def format_input(question, starter_code=""):
17
  answer_type = "\nUse Call-Based format\n" if starter_code else \
@@ -30,6 +29,7 @@ def generate_solution(model, tokenizer, question, starter_code="", temperature=1
30
  do_sample=True,
31
  top_p=0.95,
32
  pad_token_id=tokenizer.pad_token_id,
 
33
  early_stopping=True,
34
  temperature=temperature,
35
  num_beams=int(num_beams),
@@ -105,8 +105,7 @@ def run():
105
  )
106
 
107
  # main body
108
- model = get_model()
109
- tokenizer = get_tokenizer()
110
 
111
  question = st.text_input(
112
  "Problem: ",
@@ -121,7 +120,7 @@ def run():
121
  submit_button = st.button("Solve")
122
 
123
  if submit_button:
124
-
125
  output = generate_solution(model, tokenizer, question, starter_code, temperature, num_beams)
126
  st.code(output, language="python")
127
 
 
4
  # model_name = "flax-community/gpt-code-clippy-1.3B-apps-alldata"
5
  model_name = "flax-community/gpt-code-clippy-125M-apps-alldata"
6
 
7
+ @st.cache(allow_output_mutation=True, max_entries=1)
8
  def get_model():
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  tokenizer.pad_token = tokenizer.eos_token
12
+ return (model, tokenizer)
13
+
14
 
15
  def format_input(question, starter_code=""):
16
  answer_type = "\nUse Call-Based format\n" if starter_code else \
 
29
  do_sample=True,
30
  top_p=0.95,
31
  pad_token_id=tokenizer.pad_token_id,
32
+ eos_token_id=tokenizer.eos_token_id,
33
  early_stopping=True,
34
  temperature=temperature,
35
  num_beams=int(num_beams),
 
105
  )
106
 
107
  # main body
108
+ model, tokenizer = get_model()
 
109
 
110
  question = st.text_input(
111
  "Problem: ",
 
120
  submit_button = st.button("Solve")
121
 
122
  if submit_button:
123
+ st.text("Solution:")
124
  output = generate_solution(model, tokenizer, question, starter_code, temperature, num_beams)
125
  st.code(output, language="python")
126