Unggi committed on
Commit
5e7dc6f
โ€ข
1 Parent(s): a6e8258

first commit

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ # make function using import pip to install torch
4
+ import pip
5
+ pip.main(['install', 'torch'])
6
+ pip.main(['install', 'transformers'])
7
+
8
+ import torch
9
+ import transformers
10
+
11
+
12
+ # saved_model
13
+ def load_model(model_path):
14
+ saved_data = torch.load(
15
+ model_path,
16
+ map_location="cpu"
17
+ )
18
+
19
+ bart_best = saved_data["model"]
20
+ train_config = saved_data["config"]
21
+ tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
22
+
23
+ ## Load weights.
24
+ model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
25
+ model.load_state_dict(bart_best)
26
+
27
+ return model, tokenizer
28
+
29
+
30
+ # main
31
+ def inference(prompt):
32
+ model_path = "./kobart-model-logical.pth"
33
+
34
+ model, tokenizer = load_model(
35
+ model_path=model_path
36
+ )
37
+
38
+ input_ids = tokenizer.encode(prompt)
39
+ input_ids = torch.tensor(input_ids)
40
+ input_ids = input_ids.unsqueeze(0)
41
+ output = model.generate(input_ids)
42
+ output = tokenizer.decode(output[0], skip_special_tokens=True)
43
+
44
+ return output
45
+
46
+
47
+ demo = gr.Interface(
48
+ fn=inference,
49
+ inputs="text",
50
+ outputs="text" #return ๊ฐ’
51
+ ).launch() # launch(share=True)๋ฅผ ์„ค์ •ํ•˜๋ฉด ์™ธ๋ถ€์—์„œ ์ ‘์† ๊ฐ€๋Šฅํ•œ ๋งํฌ๊ฐ€ ์ƒ์„ฑ๋จ
52
+
53
+ demo.launch()