maliozer commited on
Commit
dfa5f93
1 Parent(s): 52a4e98

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import BioGptTokenizer, BioGptForCausalLM
4
+
5
+ model_names = [
6
+ "BioGPT",
7
+ "BioGPT-Large",
8
+ "BioGPT-QA-PubMedQA-BioGPT",
9
+ "BioGPT-QA-PubMEDQA-BioGPT-Large",
10
+ "BioGPT-RE-BC5CDR",
11
+ "BioGPT-RE-DDI",
12
+ "BioGPT-RE-DTI",
13
+ "BioGPT-DC-HoC"
14
+ ]
15
+
16
+ def load_model(model_name):
17
+ model_name_map = {
18
+ "BioGPT":"microsoft/biogpt",
19
+ "BioGPT-QA-PubMedQA-BioGPT":"microsoft/BioGPT-Large-PubMedQA"
20
+ }
21
+
22
+ tokenizer = BioGptTokenizer.from_pretrained(model_name_map[model_name])
23
+ model = BioGptForCausalLM.from_pretrained(model_name_map[model_name])
24
+ return tokenizer, model
25
+
26
+
27
+ def get_beam_output(sentence, selected_model, min_len,max_len, n_beams):
28
+ tokenizer, model = load_model(selected_model)
29
+ inputs = tokenizer(sentence, return_tensors="pt")
30
+ with torch.no_grad():
31
+ beam_output = model.generate(**inputs,
32
+ min_length=100,
33
+ max_length=1024,
34
+ num_beams=n_beams,
35
+ early_stopping=True
36
+ )
37
+ output=tokenizer.decode(beam_output[0], skip_special_tokens=True)
38
+ return output
39
+
40
+ inputs = [
41
+ gr.inputs.Textbox(label="prompt", lines=5, default="Bicalutamide"),
42
+ gr.Dropdown(model_names, value="microsoft/biogpt", label="selected_model"),
43
+ gr.inputs.Slider(1, 500, 1, default=100, label="min_len"),
44
+ gr.inputs.Slider(1, 2048, 1, default=1024, label="max_len"),
45
+ gr.inputs.Slider(1, 10, 1, default=5, label="num_beams")
46
+ ]
47
+ outputs = gr.outputs.Textbox(label="output")
48
+
49
+ iface = gr.Interface(
50
+ fn=get_beam_output,
51
+ inputs=inputs,
52
+ outputs=outputs,
53
+ examples=[["Bicalutamide"], ["Janus kinase 3 (JAK-3)"], ["Apricitabine"], ["Xylazine"], ["Psoralen"], ["CP-673451"]]
54
+ )
55
+
56
+ iface.launch(debug=True, enable_queue=True)