from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
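
# Note: 4-bit loading requires the bitsandbytes package, and device_map="auto"
# requires accelerate, in addition to transformers and torch.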

# Model name and configuration
model_name = "ruslanmv/Medical-Llama3-8B"
device_map = "auto"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
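
# NF4 ("normal float 4") stores the weights in 4 bits while computing in
# float16, cutting the ~16 GB float16 footprint of an 8B-parameter model to
# roughly 5-6 GB of GPU memory (approximate; varies with context and overhead).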

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Set pad_token_id to eos_token_id if None
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Define a ChatML-style chat template. apply_chat_template expects a Jinja2
# template over `messages`; a plain {system}/{user} format string would be
# emitted verbatim instead of being filled in.
chat_template = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)
tokenizer.chat_template = chat_template
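
# With this template, the messages below render to (illustrative sketch):
#
#   <|im_start|>system
#   {system prompt}<|im_end|>
#   <|im_start|>user
#   {question}<|im_end|>
#   <|im_start|>assistant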

# Function to generate a response
def askme(question):
    sys_message = (
        "You are an AI Medical Assistant trained on a vast dataset of health "
        "information. Please be thorough and provide an informative answer. "
        "If you don't know the answer to a specific medical inquiry, advise "
        "seeking professional help."
    )
    # Structure messages for the chat
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]
    
    # Apply the chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate a response; the returned sequence includes the prompt tokens
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)

    # Decode only the newly generated tokens, skipping the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
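
# The generate call above uses the model's default decoding settings. For
# more varied answers, sampling can be enabled explicitly (the values below
# are illustrative, not tuned for this model):
#
#   outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True,
#                            temperature=0.7, top_p=0.9)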

# Example usage
question = """
I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, 
increased sensitivity to cold, and dry, itchy skin. 
Could these symptoms be related to hypothyroidism? 
If so, what steps should I take to get a proper diagnosis and discuss treatment options?
"""
print(askme(question))
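
# A minimal interactive loop sketch (assumes a terminal session; Ctrl+C exits):
#
# while True:
#     try:
#         print(askme(input("\nQuestion: ")))
#     except KeyboardInterrupt:
#         break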