Unggi committed on
Commit
af60965
1 Parent(s): daf919b

add model_file

Browse files
Files changed (2) hide show
  1. bart_demo_gradio.py +9 -9
  2. kobart-model-logical.pth +3 -0
bart_demo_gradio.py CHANGED
@@ -3,19 +3,21 @@ import gradio as gr
3
  import torch
4
  import transformers
5
 
 
 
6
  # saved_model
7
- def load_model(model_path, config):
8
  saved_data = torch.load(
9
  model_path,
10
- map_location="cpu" if config.gpu_id < 0 else "cuda:%d" % config.gpu_id
11
  )
12
 
13
  bart_best = saved_data["model"]
14
  train_config = saved_data["config"]
15
- tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(config.pretrained_model_name)
16
 
17
  ## Load weights.
18
- model = transformers.BartForConditionalGeneration.from_pretrained(config.pretrained_model_name)
19
  model.load_state_dict(bart_best)
20
 
21
  return model, tokenizer
@@ -23,13 +25,10 @@ def load_model(model_path, config):
23
 
24
  # main
25
  def inference(prompt):
26
-
27
- config = define_argparser()
28
- model_path = config.model_fpath
29
 
30
  model, tokenizer = load_model(
31
- model_path=model_path,
32
- config=config
33
  )
34
 
35
  input_ids = tokenizer.encode(prompt)
@@ -40,6 +39,7 @@ def inference(prompt):
40
 
41
  return output
42
 
 
43
  demo = gr.Interface(
44
  fn=inference,
45
  inputs="text",
 
3
  import torch
4
  import transformers
5
 
6
+
7
+
8
  # saved_model
9
+ def load_model(model_path):
10
  saved_data = torch.load(
11
  model_path,
12
+ map_location="cpu"
13
  )
14
 
15
  bart_best = saved_data["model"]
16
  train_config = saved_data["config"]
17
+ tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
18
 
19
  ## Load weights.
20
+ model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
21
  model.load_state_dict(bart_best)
22
 
23
  return model, tokenizer
 
25
 
26
  # main
27
  def inference(prompt):
28
+ model_path = "./kobart-model-logical.pth"
 
 
29
 
30
  model, tokenizer = load_model(
31
+ model_path=model_path
 
32
  )
33
 
34
  input_ids = tokenizer.encode(prompt)
 
39
 
40
  return output
41
 
42
+
43
  demo = gr.Interface(
44
  fn=inference,
45
  inputs="text",
kobart-model-logical.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f108a09b663ed85f66e384b51922867ed9fab15b3b7de3a20c1f7389e4ffdeb4
3
+ size 496665407