caliex committed on
Commit
2263c30
1 Parent(s): 789a19a

Upload test-st.py

Browse files
Files changed (1) hide show
  1. test-st.py +60 -0
test-st.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit chat UI around a BioGPT causal LM fine-tuned for medical Q&A
# (ChatDoctor-style prompts).
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


# Fine-tuned checkpoint for generation; the tokenizer is taken from the base
# "microsoft/BioGPT-Large" model — presumably the fine-tune kept the base
# vocabulary; TODO confirm against the model card.
model_id = "Narrativaai/BioGPT-Large-finetuned-chatdoctor"
tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
model = AutoModelForCausalLM.from_pretrained(model_id)
11
def answer_question(prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs):
    """Generate a model answer for *prompt*.

    Args:
        prompt: Full prompt text; the model is expected to emit its reply
            after a " Response:" marker in the decoded output.
        temperature: Softmax temperature for sampling.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        num_beams: Beam-search width.
        **kwargs: Extra options forwarded verbatim to ``GenerationConfig``.

    Returns:
        The text following the first " Response:" marker in the decoded
        output, or the whole decoded output if the marker is absent.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cpu")
    attention_mask = inputs["attention_mask"].to("cpu")
    generation_config = GenerationConfig(
        temperature=temperature, top_p=top_p, top_k=top_k, num_beams=num_beams, **kwargs
    )
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=512,
            eos_token_id=tokenizer.eos_token_id,
        )
    sequence = generation_output.sequences[0]
    output = tokenizer.decode(sequence, skip_special_tokens=True)
    # Robust extraction: the original `output.split(" Response:")[1]` raised
    # IndexError whenever the model failed to emit the marker; fall back to
    # returning the full decoded text in that case.
    _, marker, answer = output.partition(" Response:")
    return answer if marker else output
31
+
32
+
33
# --- Page chrome -----------------------------------------------------------
st.set_page_config(page_title="Medical Chat Bot", page_icon=":ambulance:", layout="wide")
st.title("Medical Chat Bot")
st.caption("Talk your way to better health")

# Static markdown fragments: sidebar help text and CSS tweaks injected below.
with open("ui/sidebar.md", "r") as fh:
    sidebar_content = fh.read()
with open("ui/styles.md", "r") as fh:
    styles_content = fh.read()

# Render the sidebar help, then inject the style fragment into the main page.
st.sidebar.markdown(sidebar_content)
st.write(styles_content, unsafe_allow_html=True)

st.write("Please enter your question below:")

# Free-form question from the user.
user_input = st.text_input("You: ")

if user_input:
    # Build the Input/Response prompt the fine-tuned model expects and
    # generate the reply before rendering anything else.
    bot_response = answer_question(f"Input: {user_input}\nResponse:")
    st.write("")
    st.write("Bot:", bot_response)