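"""Custom Hugging Face Inference Endpoints handler for djangodevloper/llama3-70b-4bit-medqa.

EndpointHandler loads the tokenizer and model, selects a system prompt based on the
request's "user_type" ("general" or "professional") and "mode" ("chat", "summary",
or "header"), runs text generation, and returns {"generated_text": ...} on success
or {"error": ...} on failure.
"""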
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline


class EndpointHandler:
    def __init__(self, model_path="djangodevloper/llama3-70b-4bit-medqa"):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
            self.pipeline = TextGenerationPipeline(
                model=self.model,
                tokenizer=self.tokenizer,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model or tokenizer: {e}") from e

        # PROMPT FOR GENERAL USERS
        self.general_prompt = (
            "You are DoctusMind, a trustworthy and friendly medical AI assistant. "
            "Provide clear, easy-to-understand, and medically accurate answers to everyday health questions. "
            "Use simple language and suggest safe, evidence-informed home remedies when suitable. "
            "Be supportive and avoid technical jargon. Prioritize safety and clarity. "
            "If asked a non-medical question, politely respond with:\n"
            "`{\"not_medical_question\": true}`\n"
            "Format responses with bullet points, headers, or short paragraphs when helpful."
        )

        # PROMPT FOR PROFESSIONAL USERS
        self.professional_prompt = (
            "You are DoctusMind, a highly competent and articulate medical AI assistant for healthcare professionals. "
            "Provide concise, medically rigorous responses using appropriate clinical terminology, diagnostic language, "
            "and pathophysiological reasoning. Reference guidelines (e.g., WHO, CDC, NICE) where relevant. "
            "Always maintain a professional tone and format responses for quick clinical comprehension. "
            "If asked a non-medical question, reply with:\n"
            "`{\"not_medical_question\": true}`"
        )

        # PROMPT FOR CONVERSATION SUMMARY
        self.summary_prompt = (
            "Update the user’s running chat summary by incorporating the most recent messages. "
            "Preserve important context like health conditions, preferences, personal facts, "
            "or constraints. Keep the summary compact and in User: ...\\nBot: ... format. "
            "Omit small talk unless relevant."
        )

        # PROMPT FOR CONVERSATION HEADER
        self.header_prompt = (
            "Generate a short and meaningful header (max 50 characters) based on the conversation."
        )

    def __call__(self, data):
        try:
            user_input = data.get("inputs", "")
            user_type = data.get("user_type", "general").strip().lower()
            mode = data.get("mode", "chat").strip().lower()

            if not user_input:
                return {"error": "Missing 'inputs' in request."}

            # Pick system prompt
            if mode == "summary":
                system_prompt = self.summary_prompt
            elif mode == "header":
                system_prompt = self.header_prompt
            else:
                system_prompt = self.professional_prompt if user_type == "professional" else self.general_prompt

            # Compose a single-string prompt with system/user/assistant role markers;
            # the assistant's reply is later recovered by splitting on "<|assistant|>"
            full_prompt = f"<|system|>{system_prompt}<|user|>{user_input}<|assistant|>"

            # Generate with greedy decoding (deterministic and keeps decoding latency low).
            # Note: with do_sample=False, sampling knobs such as temperature/top_k/top_p
            # are ignored by transformers, so they are omitted here.
            outputs = self.pipeline(
                full_prompt,
                max_new_tokens=600,  # capped for latency; still enough for complete answers
                repetition_penalty=1.05,
                do_sample=False,
                eos_token_id=[
                    self.tokenizer.eos_token_id,
                    # Llama 3 ends assistant turns with <|eot_id|>
                    self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
                ]
            )

            # Extract
            generated_text = outputs[0]["generated_text"]
            response = generated_text.split("<|assistant|>")[-1].strip()

            # Fallback if empty
            if not response:
                response = "Sorry, I couldn't generate a complete response. Try rephrasing."

            return {"generated_text": response}

        except Exception as e:
            return {"error": f"Inference error: {str(e)}"}