import gradio as gr
import torch
import json
import threading
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
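# Note: 4-bit loading goes through the bitsandbytes package and device_map="auto"
# relies on accelerate; both typically need to be listed in the Space's requirements.txt.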
# ============================================================
# MODEL SETUP
# ============================================================
MODEL_ID = "augtoma/qCammel-13"

device = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit quantization (saves GPU memory). bitsandbytes requires a CUDA GPU,
# so fall back to an unquantized load when running on CPU-only hardware.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config if device == "cuda" else None,
    device_map="auto",
    trust_remote_code=True,
)

# LLaMA-style tokenizers ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# ============================================================
# MEMORY HANDLING
# ============================================================
MEMORY_FILE = "chat_memory.json"

# Load or initialize chat memory
if os.path.exists(MEMORY_FILE):
    with open(MEMORY_FILE, "r") as f:
        try:
            chat_memory = json.load(f)
        except json.JSONDecodeError:
            chat_memory = []
else:
    chat_memory = []


def save_memory(history):
    """Save chat history persistently."""
    with open(MEMORY_FILE, "w") as f:
        json.dump(history, f, indent=2)
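# chat_memory uses the same message schema as gr.Chatbot(type="messages"):
# a list of {"role": "user" | "assistant", "content": "..."} dicts.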
# ============================================================
# SYSTEM PROMPT (doctor personality)
# ============================================================
SYSTEM_PROMPT = (
    "You are Dr. Camel, a professional, empathetic, and helpful medical doctor. "
    "You will respond only when the patient speaks. "
    "Never start the conversation by yourself. "
    "Always reply as 'Doctor:' and never simulate the patient's responses. "
    "Your tone should be calm, supportive, and medically informative. "
    "If symptoms seem serious, politely suggest seeing a healthcare professional."
)
# ============================================================
# CONVERSATION PROMPT BUILDER
# ============================================================
def build_conversation_prompt(history):
    """Builds a memory-aware prompt (doctor only replies after patient)."""
    conversation = SYSTEM_PROMPT + "\n\n"
    for turn in history[-6:]:
        if turn["role"] == "user":
            conversation += f"Patient: {turn['content'].strip()}\n"
        elif turn["role"] == "assistant":
            conversation += f"Doctor: {turn['content'].strip()}\n"
    conversation += "Doctor:"
    return conversation
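# Only the last 6 turns are kept above so the prompt stays within the context
# window. The resulting prompt looks roughly like this (illustrative values only):
#
#   <SYSTEM_PROMPT>
#
#   Patient: I've had a mild headache since yesterday.
#   Doctor: I'm sorry to hear that. How severe is the pain on a scale of 1 to 10?
#   Patient: About a 4, and it gets worse in the evening.
#   Doctor: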
# ============================================================
# TEXT GENERATION (STREAMING)
# ============================================================
def generate_stream(history, max_new_tokens=512):
    prompt = build_conversation_prompt(history)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.05,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )

    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial
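# model.generate() runs in a background thread while TextIteratorStreamer yields
# decoded text chunks here, so the chat window can update as the reply is produced.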
# ============================================================
# RESPONSE LOGIC
# ============================================================
def respond(user_message, history):
    # respond() is a generator, so empty submissions must yield (not return) a value.
    if not user_message.strip():
        yield gr.update(), history
        return

    # Prevent the bot from talking first
    if len(history) == 0 and "Doctor" in user_message:
        yield gr.update(), history
        return

    history.append({"role": "user", "content": user_message})

    partial = ""
    for partial in generate_stream(history):
        # Stream the partial reply into the chat window without committing it yet.
        yield history + [{"role": "assistant", "content": partial}], history

    history.append({"role": "assistant", "content": partial})
    save_memory(history)
    yield history, history


def clear_chat():
    global chat_memory
    chat_memory = []
    save_memory(chat_memory)
    return [], []
# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks(title="🩺 Dr. Camel — Medical Chatbot", css=".footer {display:none;}") as demo:
    gr.Markdown(
        """
# 🩺 Dr. Camel — AI Medical Assistant
Ask about your symptoms or medical concerns, and Dr. Camel will respond with care and clarity.
*(For demo purposes only — not real medical advice.)*
"""
    )

    chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=520, value=chat_memory)

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Describe your symptoms or ask a question...", lines=2)
        clear = gr.Button("🧹 Clear Chat")

    state = gr.State(chat_memory)

    txt.submit(respond, [txt, state], [chatbot, state])
    clear.click(clear_chat, None, [chatbot, state])

    gr.Markdown(
        "### ⚠️ Disclaimer: This chatbot does not replace a real medical consultation. "
        "Always seek professional medical help for health emergencies."
    )
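# Queueing is what lets the generator-based respond() stream partial replies
# into the Chatbot instead of waiting for the full answer.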
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)