"""Interactive command-line chat with a locally stored ChatDoctor (LLaMA) model."""

import os
import gc

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
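
# ---- Model location and generation settings ----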
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
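
# ---- Device selection and model loading ----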
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

# Load weights in float16; device_map="auto" lets Accelerate place layers automatically.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")
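
# Stopping criterion: halt generation as soon as the model starts writing the next "Patient" turn.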
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                # Single-token stop sequence: check only the last generated token.
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            else:
                # Multi-token stop sequence: compare against the trailing window of tokens.
                if len(input_ids[0]) >= len(stop_id_seq):
                    if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                        return True
        return False
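
# Running conversation transcript; each "Patient:" / "ChatDoctor:" turn is appended in order.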
history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]
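
# Build the prompt from the full history, generate one reply, and append it to the history.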
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    history.append(human_invitation + user_input)

    # The prompt is the whole conversation so far, ending with an open "ChatDoctor: " turn.
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Stop generating as soon as the model begins a new patient turn.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens; slicing the decoded string by len(prompt)
    # can misalign because the tokenizer does not round-trip the prompt character-for-character.
    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

    # Trim anything generated past the start of the next patient turn.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    history.append(doctor_invitation + response)

    # Release the generation tensors before the next turn.
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response
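
# ---- Interactive chat loop ----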
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")

    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("ChatDoctor: Take care! Goodbye")
                break

            # Ignore empty input lines.
            if not user_input:
                continue

            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")

        except KeyboardInterrupt:
            # Ctrl+C ends the session cleanly.
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.\n")