Unggi committed
Commit a9f32fb • 1 Parent(s): af60965

change app

Files changed (2)
  1. app.py +46 -4
  2. bart_demo_gradio.py +0 -49
app.py CHANGED
@@ -1,7 +1,49 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()

+ import torch
+ import transformers
+
+
+
+ # saved_model
+ def load_model(model_path):
+     # load the fine-tuned checkpoint onto the CPU
+     saved_data = torch.load(
+         model_path,
+         map_location="cpu"
+     )
+
+     bart_best = saved_data["model"]
+     train_config = saved_data["config"]
+     tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
+
+     ## Load weights.
+     model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
+     model.load_state_dict(bart_best)
+
+     return model, tokenizer
+
+
+ # main
+ def inference(prompt):
+     model_path = "./kobart-model-logical.pth"
+
+     model, tokenizer = load_model(
+         model_path=model_path
+     )
+
+     input_ids = tokenizer.encode(prompt)
+     input_ids = torch.tensor(input_ids)
+     input_ids = input_ids.unsqueeze(0)
+     output = model.generate(input_ids)
+     output = tokenizer.decode(output[0], skip_special_tokens=True)
+
+     return output
+
+
+ demo = gr.Interface(
+     fn=inference,
+     inputs="text",
+     outputs="text"  # return value
+ )
+
+ # launch(share=True) creates a link that can be accessed externally
+ demo.launch(share=True)
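
For reference, load_model() assumes the checkpoint at ./kobart-model-logical.pth is a dict holding the fine-tuned weights under "model" and the training configuration under "config". A minimal sketch of how such a file could be produced; save_checkpoint() and train_config are hypothetical names for illustration, not part of this commit:

import torch

# Hypothetical helper: writes a checkpoint in the format load_model() expects.
def save_checkpoint(model, train_config, path="./kobart-model-logical.pth"):
    torch.save(
        {
            "model": model.state_dict(),  # read back as saved_data["model"]
            "config": train_config,       # read back as saved_data["config"]
        },
        path,
    )

# Example usage after fine-tuning the base model:
# model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
# save_checkpoint(model, train_config)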
bart_demo_gradio.py DELETED
@@ -1,49 +0,0 @@
- import gradio as gr
-
- import torch
- import transformers
-
-
-
- # saved_model
- def load_model(model_path):
-     saved_data = torch.load(
-         model_path,
-         map_location="cpu"
-     )
-
-     bart_best = saved_data["model"]
-     train_config = saved_data["config"]
-     tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
-
-     ## Load weights.
-     model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
-     model.load_state_dict(bart_best)
-
-     return model, tokenizer
-
-
- # main
- def inference(prompt):
-     model_path = "./kobart-model-logical.pth"
-
-     model, tokenizer = load_model(
-         model_path=model_path
-     )
-
-     input_ids = tokenizer.encode(prompt)
-     input_ids = torch.tensor(input_ids)
-     input_ids = input_ids.unsqueeze(0)
-     output = model.generate(input_ids)
-     output = tokenizer.decode(output[0], skip_special_tokens=True)
-
-     return output
-
-
- demo = gr.Interface(
-     fn=inference,
-     inputs="text",
-     outputs="text"  # return value
- ).launch(share=True)  # launch(share=True) creates a link that can be accessed externally
-
- demo.launch()
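
One follow-up worth noting (a sketch, not part of this commit): inference() in app.py reloads the checkpoint from disk on every request. Loading the model once at startup and disabling gradient tracking keeps each request cheap. The sketch below assumes load_model() as defined in app.py above:

import torch
import gradio as gr

# Sketch: load the model and tokenizer once at startup rather than per request.
# load_model is the function defined in app.py above.
MODEL_PATH = "./kobart-model-logical.pth"
model, tokenizer = load_model(model_path=MODEL_PATH)
model.eval()  # switch off dropout etc. for inference

def inference(prompt):
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    with torch.no_grad():  # generation needs no gradients
        output = model.generate(input_ids)
    return tokenizer.decode(output[0], skip_special_tokens=True)

demo = gr.Interface(fn=inference, inputs="text", outputs="text")
demo.launch(share=True)  # share=True exposes an externally reachable link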