mipatov commited on
Commit
a575d46
1 Parent(s): a567626

get model fix

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -7,14 +7,14 @@ import re
7
  from PIL import Image
8
 
9
 
10
- def get_model_gpt(model_name):
11
- tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
12
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
13
  model.eval()
14
  return model, tokenizer
15
 
16
- def get_model_t5(model_name):
17
- tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
18
  model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
19
  model.eval()
20
  return model, tokenizer
 
7
  from PIL import Image
8
 
9
 
10
+ def get_model_gpt(model_name,tokenizer_name):
11
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained(tokenizer_name)
12
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
13
  model.eval()
14
  return model, tokenizer
15
 
16
+ def get_model_t5(model_name,tokenizer_name):
17
+ tokenizer = transformers.T5Tokenizer.from_pretrained(tokenizer_name)
18
  model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
19
  model.eval()
20
  return model, tokenizer