from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# Model name and configuration
model_name = "ruslanmv/Medical-Llama3-8B"
device_map = "auto"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
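# Note: NF4 4-bit quantization cuts the 8B model's weight memory to roughly a
# quarter of float16 (about 5-6 GB instead of ~16 GB). bitsandbytes and a CUDA
# GPU are required; without them, drop `quantization_config` below to load the
# model in full precision.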
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Set pad_token_id to eos_token_id if None
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Define a ChatML-style chat template. apply_chat_template renders this string
# with Jinja2, so it must loop over `messages`; plain {system}/{user}
# placeholders would be passed through as literal text and never substituted.
chat_template = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)
tokenizer.chat_template = chat_template
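# Optional sanity check: render a dummy conversation to confirm the template
# emits ChatML-formatted text (the example message here is illustrative).
print(tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}], tokenize=False, add_generation_prompt=True
))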
# Function to generate a response
def askme(question):
    sys_message = """
    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
    """
    # Structure messages for the chat
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]
    # Apply the chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Generate response
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)
    # Decode and clean up the response
    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "<|im_start|>assistant" in response_text:
        response_text = response_text.split("<|im_start|>assistant")[-1].strip()
    return response_text
# Example usage
question = """
I'm a 35-year-old male and for the past few months, I've been experiencing fatigue,
increased sensitivity to cold, and dry, itchy skin.
Could these symptoms be related to hypothyroidism?
If so, what steps should I take to get a proper diagnosis and discuss treatment options?
"""
print(askme(question))
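# A minimal streaming variant (sketch): transformers' TextStreamer prints
# tokens as they are generated instead of waiting for the full answer.
# `askme_streaming` and its shortened system prompt are illustrative additions,
# not part of the original script.
from transformers import TextStreamer

def askme_streaming(question):
    messages = [
        {"role": "system", "content": "You are an AI Medical Assistant."},
        {"role": "user", "content": question},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    model.generate(**inputs, max_new_tokens=100, use_cache=True, streamer=streamer)

askme_streaming(question)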