# medbridge/app.py
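"""MedBridge AI agent.

A uAgents-based chat agent that combines BioGPT-Large for medical text
generation, a ViT image classifier for basic image analysis, and an optional
MeTTa (hyperon) knowledge graph for symptom/treatment/risk-factor lookups.
"""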
from uagents import Agent, Context, Protocol
from uagents.setup import fund_agent_if_low
from datetime import datetime
from uuid import uuid4
import os
import uvicorn
from uagents_core.contrib.protocols.chat import (
ChatAcknowledgement,
ChatMessage,
EndSessionContent,
StartSessionContent,
TextContent,
chat_protocol_spec,
)
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import torch
import io
import requests
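# MeTTa (hyperon) is optional: if it is not installed, the agent falls back to
# BioGPT-only answers and reports the knowledge graph as disabled.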
try:
from hyperon import MeTTa, GroundedAtom, E
METTA_AVAILABLE = True
except ImportError:
METTA_AVAILABLE = False
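# Agent identity is configured through environment variables (e.g. Space secrets).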
AGENT_NAME = os.getenv("AGENT_NAME")
AGENT_SEED = os.getenv("AGENT_SEED")
agent = Agent(
name=AGENT_NAME,
seed=AGENT_SEED,
port=7860,
    mailbox=True,
endpoint=["https://nexusbert-medbridge.hf.space/submit"],
)
fund_agent_if_low(agent.wallet.address())
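# BioGPT-Large: causal language model used for free-text medical answers.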
text_model_name = "microsoft/BioGPT-Large"
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_model = AutoModelForCausalLM.from_pretrained(text_model_name)
text_pipeline = pipeline(
"text-generation",
model=text_model,
tokenizer=text_tokenizer,
device=0 if torch.cuda.is_available() else -1
)
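# ViT image classifier (ImageNet-pretrained, general-purpose rather than a
# dedicated medical model). Note: newer transformers releases prefer
# ViTImageProcessor over the deprecated ViTFeatureExtractor.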
vision_model_name = "google/vit-base-patch16-224"
vision_feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
vision_model = ViTForImageClassification.from_pretrained(vision_model_name)
if torch.cuda.is_available():
vision_model = vision_model.to('cuda')
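# Optional MeTTa knowledge graph: a small in-memory medical KB plus query rules,
# enabled only when hyperon is installed and ENABLE_METTA=true.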
if METTA_AVAILABLE and os.getenv("ENABLE_METTA", "").lower() == "true":
# Initialize MeTTa without loading stdlib to avoid permission issues in restricted environments
try:
try:
metta = MeTTa(stdlib=False) # preferred if supported by current hyperon version
except TypeError:
# Fallback for older versions without the stdlib flag
metta = MeTTa()
medical_knowledge = """
(symptom diabetes frequent-urination)
(symptom diabetes increased-thirst)
(symptom diabetes fatigue)
(symptom diabetes blurred-vision)
(symptom diabetes slow-healing)
(symptom hypertension headache)
(symptom hypertension dizziness)
(symptom hypertension chest-pain)
(symptom hypertension shortness-of-breath)
(symptom heart-disease chest-pain)
(symptom heart-disease fatigue)
(symptom heart-disease shortness-of-breath)
(symptom heart-disease irregular-heartbeat)
(treats metformin diabetes)
(treats insulin diabetes)
(treats lifestyle-modification diabetes)
(treats beta-blockers hypertension)
(treats ace-inhibitors hypertension)
(treats diuretics hypertension)
(risk-factor diabetes obesity)
(risk-factor diabetes sedentary-lifestyle)
(risk-factor diabetes family-history)
(risk-factor hypertension high-sodium-diet)
(risk-factor hypertension stress)
(risk-factor hypertension obesity)
(diagnoses blood-glucose-test diabetes)
(diagnoses hba1c-test diabetes)
(diagnoses blood-pressure-test hypertension)
(diagnoses ecg heart-disease)
(diagnoses chest-xray heart-disease)
(imaging-modality xray chest)
(imaging-modality ct-scan brain)
(imaging-modality mri spine)
(imaging-modality ultrasound abdomen)
(= (get-symptoms $condition)
(match &self (symptom $condition $symptom) $symptom))
(= (get-treatments $condition)
(match &self (treats $treatment $condition) $treatment))
(= (get-risk-factors $condition)
(match &self (risk-factor $condition $factor) $factor))
(= (get-diagnostic-tests $condition)
(match &self (diagnoses $test $condition) $test))
(= (find-condition-by-symptom $symptom)
(match &self (symptom $condition $symptom) $condition))
"""
metta.run(medical_knowledge)
except Exception:
# If MeTTa initialization or knowledge load fails, disable MeTTa gracefully
metta = None
else:
metta = None
chat_proto = Protocol(spec=chat_protocol_spec)
sessions = {}
def create_text_chat(text: str) -> ChatMessage:
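    """Wrap plain text in a ChatMessage for the chat protocol."""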
content = [TextContent(type="text", text=text)]
return ChatMessage(
timestamp=datetime.utcnow(),
msg_id=uuid4(),
content=content,
)
def query_metta_knowledge(query_type: str, entity: str) -> str:
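    """Run a predefined MeTTa query (symptoms, treatments, etc.) for an entity."""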
if not METTA_AVAILABLE or metta is None:
return "MeTTa knowledge graph not available."
try:
query_map = {
"symptoms": f"!(get-symptoms {entity})",
"treatments": f"!(get-treatments {entity})",
"risk-factors": f"!(get-risk-factors {entity})",
"diagnostic-tests": f"!(get-diagnostic-tests {entity})",
"find-condition": f"!(find-condition-by-symptom {entity})"
}
if query_type not in query_map:
return "Unknown query type."
result = metta.run(query_map[query_type])
if result:
items = [str(item) for item in result[0] if str(item) != '()']
if items:
return ", ".join(items)
else:
return f"No {query_type} found in knowledge base for {entity}."
else:
            return "No results found."
except Exception as e:
return f"Error querying MeTTa: {str(e)}"
def analyze_with_metta_and_biogpt(query: str) -> str:
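    """Combine a MeTTa knowledge-graph lookup (when available) with a BioGPT answer."""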
query_lower = query.lower()
metta_response = ""
conditions = ["diabetes", "hypertension", "heart-disease"]
query_types = {
"symptom": "symptoms",
"treatment": "treatments",
"treat": "treatments",
"risk": "risk-factors",
"diagnos": "diagnostic-tests",
"test": "diagnostic-tests"
}
detected_condition = None
detected_query_type = None
for condition in conditions:
if condition.replace("-", " ") in query_lower or condition in query_lower:
detected_condition = condition
break
for keyword, q_type in query_types.items():
if keyword in query_lower:
detected_query_type = q_type
break
    if detected_condition and detected_query_type and METTA_AVAILABLE and metta is not None:
metta_result = query_metta_knowledge(detected_query_type, detected_condition)
metta_response = f"\n\nKnowledge Graph Insight ({detected_query_type} for {detected_condition}):\n{metta_result}\n"
biogpt_response = analyze_medical_text(query)
return metta_response + "\nBioGPT Analysis:\n" + biogpt_response
def analyze_medical_text(query: str, max_length: int = 200) -> str:
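    """Generate a free-text continuation of the query with the BioGPT pipeline."""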
try:
result = text_pipeline(
query,
max_length=max_length,
num_return_sequences=1,
pad_token_id=text_tokenizer.eos_token_id
)
return result[0]['generated_text']
except Exception as e:
return f"Error in medical text analysis: {str(e)}"
def analyze_medical_image(image_data: bytes) -> str:
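    """Classify an image with the ViT model and report the top label and confidence."""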
try:
image = Image.open(io.BytesIO(image_data))
if image.mode != 'RGB':
image = image.convert('RGB')
inputs = vision_feature_extractor(images=image, return_tensors="pt")
if torch.cuda.is_available():
inputs = {k: v.to('cuda') for k, v in inputs.items()}
        with torch.no_grad():
            outputs = vision_model(**inputs)
        logits = outputs.logits
predicted_class = logits.argmax(-1).item()
probs = torch.nn.functional.softmax(logits, dim=-1)
confidence = probs[0][predicted_class].item()
label = vision_model.config.id2label[predicted_class]
return f"Medical Image Analysis (ViT):\n- Prediction: {label}\n- Confidence: {confidence:.2%}\n- Class ID: {predicted_class}"
except Exception as e:
return f"Error in medical image analysis: {str(e)}"
def process_medical_query(text: str) -> str:
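    """Route a text query: point users at the IMAGE: syntax or run the hybrid analysis."""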
text_lower = text.lower()
if any(keyword in text_lower for keyword in ['image', 'scan', 'x-ray', 'xray', 'mri', 'ct scan', 'picture']):
return "To analyze a medical image, provide an image URL starting with 'IMAGE:' followed by the URL."
response = analyze_with_metta_and_biogpt(text)
disclaimer = "\n\nDisclaimer: This is AI-generated information combining knowledge graph reasoning (MeTTa) and language models (BioGPT) for educational purposes. Always consult with qualified healthcare professionals for medical advice."
return "MedBridge AI - Hybrid Analysis\n" + response + disclaimer
@chat_proto.on_message(ChatMessage)
async def handle_message(ctx: Context, sender: str, msg: ChatMessage):
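    """Acknowledge each incoming ChatMessage and handle its content items."""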
ctx.logger.info(f"Received message from {sender}")
await ctx.send(
sender,
ChatAcknowledgement(
timestamp=datetime.utcnow(),
acknowledged_msg_id=msg.msg_id
)
)
for item in msg.content:
if isinstance(item, StartSessionContent):
ctx.logger.info(f"Session started with {sender}")
session_id = str(uuid4())
sessions[sender] = {
"session_id": session_id,
"started_at": datetime.utcnow(),
"query_count": 0
}
            metta_status = "Enabled" if (METTA_AVAILABLE and metta is not None) else "Disabled"
welcome_msg = (
"Welcome to MedBridge AI\n"
"Powered by SingularityNET MeTTa, BioGPT-Large, Vision Transformer, Fetch.ai uAgents\n"
f"MeTTa Knowledge Graph: {metta_status}\n\n"
"How can I assist you today?\n"
"Ask medical questions or send IMAGE:your-image-url"
)
response_message = create_text_chat(welcome_msg)
await ctx.send(sender, response_message)
elif isinstance(item, TextContent):
ctx.logger.info(f"Text from {sender}: {item.text[:100]}...")
if sender in sessions:
sessions[sender]["query_count"] += 1
try:
if item.text.startswith("IMAGE:"):
image_url = item.text[6:].strip()
                    response = requests.get(image_url, timeout=10)
                    response.raise_for_status()
                    result = analyze_medical_image(response.content)
response_message = create_text_chat(result)
await ctx.send(sender, response_message)
else:
result = process_medical_query(item.text)
response_message = create_text_chat(result)
await ctx.send(sender, response_message)
except Exception as e:
error_msg = create_text_chat(
f"Error processing your request: {str(e)}"
)
await ctx.send(sender, error_msg)
elif isinstance(item, EndSessionContent):
ctx.logger.info(f"Session ended with {sender}")
if sender in sessions:
query_count = sessions[sender]["query_count"]
goodbye_msg = create_text_chat(
f"Thank you for using MedBridge AI. "
f"I processed {query_count} queries during this session."
)
await ctx.send(sender, goodbye_msg)
del sessions[sender]
else:
ctx.logger.info(f"Unexpected content type from {sender}")
@chat_proto.on_message(ChatAcknowledgement)
async def handle_acknowledgement(ctx: Context, sender: str, msg: ChatAcknowledgement):
ctx.logger.info(f"Acknowledgement from {sender} for message {msg.acknowledged_msg_id}")
@chat_proto.on_interval(period=300.0)
async def cleanup_sessions(ctx: Context):
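    """Drop sessions that have been open for more than an hour."""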
current_time = datetime.utcnow()
expired = []
for sender, session_data in sessions.items():
age = (current_time - session_data["started_at"]).total_seconds()
if age > 3600:
expired.append(sender)
for sender in expired:
del sessions[sender]
if expired:
ctx.logger.info(f"Cleaned up {len(expired)} expired sessions")
@chat_proto.on_interval(period=120.0)
async def log_agent_status(ctx: Context):
ctx.logger.info(f"MedBridge AI Status - Active Sessions: {len(sessions)}")
@agent.on_event("startup")
async def agent_startup(ctx: Context):
ctx.logger.info("=" * 70)
ctx.logger.info("MedBridge AI Agent Starting")
ctx.logger.info("=" * 70)
ctx.logger.info(f"Agent Name: {agent.name}")
ctx.logger.info(f"Agent Address: {agent.address}")
    ctx.logger.info("ASI Alliance Technologies:")
    ctx.logger.info(" • Fetch.ai uAgents Framework: Enabled")
    ctx.logger.info(f" • SingularityNET MeTTa: {'Enabled' if METTA_AVAILABLE else 'Not installed'}")
    ctx.logger.info("AI Models:")
    ctx.logger.info(f" • Text Model: {text_model_name}")
    ctx.logger.info(f" • Vision Model: {vision_model_name}")
    ctx.logger.info(f" • Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}")
    ctx.logger.info("Protocols:")
    ctx.logger.info(" • ASI:One Compatible: Enabled (Chat Protocol)")
ctx.logger.info("=" * 70)
ctx.logger.info("MedBridge AI Agent Ready")
ctx.logger.info("=" * 70)
agent.include(chat_proto, publish_manifest=True)
print(f"Your agent's address is: {agent.address}")
if __name__ == "__main__":
agent.run()
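# Example (hypothetical, not part of this app): a minimal client uAgent that asks
# MedBridge a question over the same chat protocol. The target address below is a
# placeholder; use the address this agent prints at startup.
#
#   from datetime import datetime
#   from uuid import uuid4
#   from uagents import Agent, Context
#   from uagents_core.contrib.protocols.chat import ChatMessage, TextContent
#
#   MEDBRIDGE_ADDRESS = "agent1q..."  # placeholder address
#   client = Agent(name="medbridge-client", seed="client seed phrase", mailbox=True)
#
#   @client.on_event("startup")
#   async def ask(ctx: Context):
#       # Send a single text question to MedBridge.
#       await ctx.send(
#           MEDBRIDGE_ADDRESS,
#           ChatMessage(
#               timestamp=datetime.utcnow(),
#               msg_id=uuid4(),
#               content=[TextContent(type="text", text="What are the symptoms of diabetes?")],
#           ),
#       )
#
#   if __name__ == "__main__":
#       client.run()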