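"""Streamlit app for querying the Writer/Palmyra-Med-70B-32k medical model.

Loads the tokenizer and model once (cached across reruns), collects a medical
question from a text area, and generates an answer with a fixed system prompt.
Run with `streamlit run <this_file>.py`.
"""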
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub identifier of the model to load
model_id = "Writer/Palmyra-Med-70B-32k"

# Cache the tokenizer and model across reruns; st.cache_resource replaces the
# deprecated st.cache(allow_output_mutation=True).
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="flash_attention_2",  # requires the flash-attn package
    )
    return tokenizer, model

tokenizer, model = load_model()

# Define Streamlit app
st.title("Medical Query Model")

st.write(
    "You are interacting with a highly knowledgeable medical model. Enter your medical question below:"
)

user_input = st.text_area("Your Question")

if st.button("Get Response"):
    if user_input:
        # Prepare input for the model
        messages = [
            {
                "role": "system",
                "content": "You are a highly knowledgeable and experienced expert in the healthcare and biomedical field, possessing extensive medical knowledge and practical expertise.",
            },
            {
                "role": "user",
                "content": user_input,
            },
        ]

        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)  # move the prompt tokens to the model's device

        gen_conf = {
            "max_new_tokens": 256,
            # If the model uses a separate end-of-turn special token, pass a list of
            # ids here, e.g. [tokenizer.eos_token_id, <end-of-turn token id>].
            "eos_token_id": tokenizer.eos_token_id,
            # Greedy decoding; temperature/top_p only take effect when do_sample=True.
            "do_sample": False,
            "temperature": 0.0,
            "top_p": 0.9,
        }

        # Generate response
        with torch.no_grad():
            output_id = model.generate(input_ids, **gen_conf)

        output_text = tokenizer.decode(output_id[0][input_ids.shape[1]:], skip_special_tokens=True)

        st.write("Response:")
        st.write(output_text)
    else:
        st.warning("Please enter a question.")