Tabish009 committed
Commit 31484bc
1 Parent(s): 35c3fef

Update app.py

Files changed (1)
  1. app.py +23 -26
app.py CHANGED
@@ -1,36 +1,33 @@
 import streamlit as st
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Load the Biomistral 7b model and tokenizer
-model_name = "biomistral/Biomistral-7b"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+@st.cache_resource
+def load_model_and_tokenizer():
+    model_name_or_path = "m42-health/med42-70b"
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+    return model, tokenizer
 
-# Define the text generation function
-def generate_text(prompt, max_length=500, num_return_sequences=1, temperature=0.7):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt")
-    output = model.generate(
-        input_ids,
-        max_length=max_length,
-        num_return_sequences=num_return_sequences,
-        temperature=temperature,
-        pad_token_id=tokenizer.eos_token_id,
-    )
-    generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
-    return generated_text
+def generate_response(model, tokenizer, prompt):
+    prompt_template = f'''
+    <|system|>: You are a helpful medical assistant created by M42 Health in the UAE.
+    <|prompter|>:{prompt}
+    <|assistant|>:
+    '''
+    input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.to(model.device)
+    output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, max_new_tokens=512)
+    response = tokenizer.decode(output[0], skip_special_tokens=True)
+    return response
 
-# Streamlit app
 def main():
-    st.title("Doctor Chatbot (Powered by Biomistral 7b)")
-    st.write("Welcome to the Doctor Chatbot. Please describe your symptoms or ask a medical question, and I'll provide a response.")
+    st.title("M42 Health Medical Assistant")
+    model, tokenizer = load_model_and_tokenizer()
 
-    user_input = st.text_area("Enter your symptoms or question:")
-
-    if user_input:
+    prompt = st.text_area("Enter your medical query:")
+    if st.button("Submit"):
         with st.spinner("Generating response..."):
-            generated_text = generate_text(user_input)
-            st.write(generated_text[0])
+            response = generate_response(model, tokenizer, prompt)
+            st.write(response)
 
 if __name__ == "__main__":
-    main()
-
+    main()
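
The key change in the added code is @st.cache_resource on load_model_and_tokenizer(). Streamlit re-executes the whole script on every widget interaction, so an undecorated from_pretrained call would reload the med42-70b weights on every Submit click; the decorator runs the loader once per process and hands the cached objects back on later reruns. A minimal sketch of the pattern, independent of transformers (the returned dict is just a stand-in for a large model):

import streamlit as st

@st.cache_resource  # body runs once per process; later reruns reuse the returned object
def load_expensive_resource():
    # stand-in for AutoModelForCausalLM.from_pretrained(...)
    return {"weights": "loaded"}

resource = load_expensive_resource()  # fast on every rerun after the first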
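
generate_response() wraps the user text in med42's chat markup (<|system|>, <|prompter|>, <|assistant|>) before sampling. To smoke-test that flow without downloading a 70B checkpoint, the same calls can be pointed at a tiny stand-in model; sshleifer/tiny-gpt2 below is an arbitrary small checkpoint chosen only for size, not a model that understands this prompt format:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"  # stand-in; the app uses m42-health/med42-70b
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "What are common causes of headache?"
prompt_template = f'''
<|system|>: You are a helpful medical assistant created by M42 Health in the UAE.
<|prompter|>:{prompt}
<|assistant|>:
'''
input_ids = tokenizer(prompt_template, return_tensors="pt").input_ids.to(model.device)
output = model.generate(
    inputs=input_ids,
    temperature=0.7,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,  # tiny-gpt2 defines no pad token
    max_new_tokens=32,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The app itself still launches the same way, with streamlit run app.py.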