File size: 6,427 Bytes
6bcba58
 
 
 
 
 
b783cda
 
 
 
6b9591f
 
36bda53
 
 
 
6bcba58
b783cda
6bcba58
 
 
5103369
6bcba58
 
38cf703
 
aa08e22
6bcba58
 
6b02e11
 
aa08e22
38cf703
aa08e22
6b02e11
454b0bf
 
38cf703
 
 
454b0bf
 
 
8861375
e17f0b6
 
 
 
 
 
 
 
 
 
 
 
 
36bda53
 
 
 
 
 
 
b783cda
 
b9c7951
 
 
 
6bcba58
1c3c739
 
8739bcf
1c3c739
 
ba89683
b188633
 
ba89683
b188633
 
ba89683
b188633
35b74b4
8739bcf
 
8e6bacf
8739bcf
242b368
8739bcf
ba89683
b188633
 
242b368
 
b188633
 
 
 
 
 
242b368
b188633
 
 
 
 
 
 
 
 
 
6bcba58
 
 
31c147a
6bcba58
e17f0b6
6bcba58
96ac3aa
b9c7951
6bcba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38cf703
 
 
 
 
6bcba58
 
 
 
6b02e11
 
6bcba58
aa08e22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import os
import spaces
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import transformers
import torch
from peft import PeftModel, PeftConfig
import os
from transformers import (
    BitsAndBytesConfig,
    pipeline,
)

# Hugging Face access token, read (not set) from the HF_TOKEN environment
# variable; both names are None when the variable is absent.
HF_TOKEN = os.environ.get("HF_TOKEN")
access_token = os.environ.get("HF_TOKEN")


# HTML header rendered at the top of the page.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">PhysicianAI</h1>

</div>
'''

# Footer attribution. The app loads meta-llama/Meta-Llama-3-8B, so the credit
# names that model (the previous text credited Mistral 7B).
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""

# HTML shown inside the empty chat window before the first message.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://www.thesmartcityjournal.com/images/Imagenes-Articulos/2023_09_Septiembre/AI_in_healthcare.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;  "> 
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">PhysicianAI</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">I am a Medical ChatBot...</p>
</div>
"""


# Page-level CSS: centers headings and styles the element whose id is
# "duplicate-button".
css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

# Load the weights in 4-bit NF4 with float16 compute and no double quantization.
compute_dtype = torch.float16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

# Quantized base model, then the fine-tuned PEFT adapter layered on top of it.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    use_auth_token=access_token,
    quantization_config=quant_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    "physician-ai/llama3-8b-finetuned",
    use_auth_token=access_token,
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    "physician-ai/llama3-8b-finetuned",
    use_auth_token=access_token,
)

# NOTE(review): this pipeline is not referenced anywhere else in the visible
# file; it is kept so the module-level name stays available.
text_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=4096,
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.15,
)

# Token ids that should end generation: the tokenizer's EOS plus Llama 3's
# end-of-turn marker. The previous code looked up an empty string here, which
# cannot name a vocabulary token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

def generate_response(input_ids, generate_kwargs, output_queue):
    """Run ``model.generate`` and push the outcome onto ``output_queue``.

    Args:
        input_ids: Not used by this function; the generation inputs are taken
            from ``generate_kwargs``.
        generate_kwargs: Keyword arguments forwarded to ``model.generate``.
        output_queue: List-like object; receives the generation output on
            success, or ``None`` if any exception was raised.
    """
    try:
        output_queue.append(model.generate(**generate_kwargs))
    except Exception as err:
        # Best effort: report the failure and signal it to the caller with None.
        print(f"Error during generation: {err}")
        output_queue.append(None)

def _build_chat_messages(system_prompt, message, history):
    """Return the conversation as role-tagged messages in chronological order."""
    messages = [{"role": "system", "content": system_prompt}]
    for pair in history or []:
        # History entries are treated as (user, assistant) pairs, as the
        # original flattening did; blank/None entries are skipped defensively.
        for role, text in zip(("user", "assistant"), pair):
            if text:
                messages.append({"role": role, "content": text})
    messages.append({"role": "user", "content": message})
    return messages


def _render_prompt(messages):
    """Render messages into one prompt string.

    Returns a ``(prompt, add_special_tokens)`` tuple. When the tokenizer ships
    a chat template the rendered text is tokenized without adding special
    tokens again; otherwise a plain labelled transcript is produced and the
    tokenizer adds its usual special tokens.
    """
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return prompt, False
    # NOTE(review): fallback transcript format is a generic choice; confirm it
    # matches the format the adapter was fine-tuned on.
    labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    lines = [f"{labels[m['role']]}: {m['content']}" for m in messages]
    lines.append("Assistant:")
    return "\n".join(lines), True


@spaces.GPU(duration=120)
def chat_llama3_8b(message, history, temperature=0.95, max_new_tokens=4096):
    """Generate the assistant's reply to ``message`` given the chat ``history``.

    The system prompt, earlier turns and the new message are combined into a
    single prompt (oldest first) and only the newly generated tokens are
    decoded, so the prompt is not echoed back to the user.

    Args:
        message: The latest user message.
        history: Earlier turns as (user, assistant) pairs.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated reply text, or a fixed error message if generation failed.
    """
    # System prompt
    system_prompt = "You are an advanced medical assistant trained on extensive medical literature, including the highly authoritative Harrison Principles of Internal Medicine. Your role is to provide detailed and accurate medical information, diagnosis assistance, and guidance based on evidence and clinical best practices. Answer questions clearly, with a focus on providing practical and applicable medical advice. When relevant, cite specific chapters or sections from Harrison s Principles of Internal Medicine to support your responses."

    # Tokenize the whole conversation as ONE sequence. Passing a list of
    # strings to the tokenizer would create a padded batch of unrelated rows.
    messages = _build_chat_messages(system_prompt, message, history)
    prompt, add_special_tokens = _render_prompt(messages)
    encoded = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=add_special_tokens
    ).to(model.device)
    input_ids = encoded["input_ids"]
    prompt_length = input_ids.shape[1]

    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": encoded["attention_mask"],
        "max_new_tokens": int(max_new_tokens),
        "num_return_sequences": 1,
    }
    # Temperature only has an effect when sampling is enabled, and sampling
    # requires a strictly positive temperature (the UI slider allows 0).
    if float(temperature) > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = float(temperature)
    else:
        generate_kwargs["do_sample"] = False

    # Stop on any known terminator id; ids the tokenizer could not resolve
    # (None) are dropped.
    stop_ids = [token_id for token_id in terminators if token_id is not None]
    if stop_ids:
        generate_kwargs["eos_token_id"] = stop_ids

    # generate_response reports failure by appending None instead of raising.
    output_queue = []
    generate_response(input_ids, generate_kwargs, output_queue)

    output = output_queue[0] if output_queue else None
    if output is None:
        return "An error occurred during text generation."

    # Decode only the continuation, not the prompt tokens that precede it.
    new_tokens = output[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        

# Gradio block
# The chat display is created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:

    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Finetuned LLAMA 3 8B Model", elem_id="duplicate-button")

    # Extra generation controls. They are created with render=False and passed
    # to ChatInterface, which receives them as additional inputs to the chat
    # function (temperature first, then max new tokens).
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    max_new_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    # Sample prompts offered to the user; each inner list is one example.
    example_prompts = [
        ["I've been experiencing persistent headaches, nausea, and sensitivity to light. What could be causing this?"],
        ["Based on my diagnosis of type 2 diabetes, what are the recommended treatment options? Should I consider medication, lifestyle changes, or both?"],
        ["I'm currently taking lisinopril for hypertension and atorvastatin for high cholesterol. Are there any potential interactions or side effects I should be aware of if I start taking ibuprofen for occasional pain relief?"],
        ["I'm in my early 40s and have a family history of heart disease. What are some preventive measures I can take to lower my risk, besides regular exercise and a healthy diet?"],
        ["Can you provide information on rheumatoid arthritis, including common symptoms, diagnostic tests, and available treatment options?"],
    ]

    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_new_tokens_slider],
        examples=example_prompts,
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()