|
from uagents import Agent, Context, Protocol |
|
from uagents.setup import fund_agent_if_low |
|
from datetime import datetime |
|
from uuid import uuid4 |
|
import os |
|
import uvicorn |
|
|
|
from uagents_core.contrib.protocols.chat import ( |
|
ChatAcknowledgement, |
|
ChatMessage, |
|
EndSessionContent, |
|
StartSessionContent, |
|
TextContent, |
|
chat_protocol_spec, |
|
) |
|
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
from PIL import Image |
|
import torch |
|
import io |
|
import requests |
|
|
|
try: |
|
from hyperon import MeTTa, GroundedAtom, E |
|
METTA_AVAILABLE = True |
|
except ImportError: |
|
METTA_AVAILABLE = False |
|
|
|
# Agent identity is read from the environment. os.getenv returns None for an
# unset variable, so either value may be None at this point.
# NOTE(review): with AGENT_SEED unset, seed=None is passed to Agent — confirm
# whether that yields a non-persistent (per-restart) agent address.
AGENT_NAME = os.getenv("AGENT_NAME")
AGENT_SEED = os.getenv("AGENT_SEED")

# Constructor settings gathered in one place; the port/endpoint pair points at
# the public hf.space deployment.
_agent_settings = {
    "name": AGENT_NAME,
    "seed": AGENT_SEED,
    "port": 7860,
    "mailbox": True,
    "endpoint": ["https://nexusbert-medbridge.hf.space/submit"],
}
agent = Agent(**_agent_settings)

# Executed at import time. Presumably tops up the agent wallet when its balance
# is low (per the helper's name) — a network side effect of importing this module.
_wallet_address = agent.wallet.address()
fund_agent_if_low(_wallet_address)
|
|
|
# Probe for a GPU once; both models below are placed according to this flag.
_has_cuda = torch.cuda.is_available()

# --- Text model: BioGPT-Large behind a text-generation pipeline. ---
text_model_name = "microsoft/BioGPT-Large"
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_model = AutoModelForCausalLM.from_pretrained(text_model_name)
# pipeline() takes a device index: 0 = first CUDA device, -1 = CPU.
text_pipeline = pipeline(
    "text-generation",
    model=text_model,
    tokenizer=text_tokenizer,
    device=0 if _has_cuda else -1,
)

# --- Vision model: ViT image classifier plus its preprocessing. ---
# NOTE(review): ViTFeatureExtractor is deprecated upstream in favour of
# ViTImageProcessor; consider migrating when transformers is upgraded.
vision_model_name = "google/vit-base-patch16-224"
vision_feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
vision_model = ViTForImageClassification.from_pretrained(vision_model_name)
if _has_cuda:
    vision_model = vision_model.to('cuda')
|
|
|
# Build the in-memory MeTTa knowledge graph. This happens only when hyperon
# imported successfully AND the operator opted in with ENABLE_METTA=true.
# In every other case `metta` is None, and callers must check that rather
# than METTA_AVAILABLE, which only says the package is importable.
if METTA_AVAILABLE and os.getenv("ENABLE_METTA", "").lower() == "true":
    try:
        try:
            metta = MeTTa(stdlib=False)
        except TypeError:
            # Presumably this hyperon version's constructor has no `stdlib`
            # keyword; fall back to the default constructor.
            metta = MeTTa()

        # Facts about conditions, plus query helpers. Non-`!` expressions
        # passed to metta.run() are added to the &self space, where the
        # get-*/find-* functions below match against them.
        medical_knowledge = """
            (symptom diabetes frequent-urination)
            (symptom diabetes increased-thirst)
            (symptom diabetes fatigue)
            (symptom diabetes blurred-vision)
            (symptom diabetes slow-healing)
            (symptom hypertension headache)
            (symptom hypertension dizziness)
            (symptom hypertension chest-pain)
            (symptom hypertension shortness-of-breath)
            (symptom heart-disease chest-pain)
            (symptom heart-disease fatigue)
            (symptom heart-disease shortness-of-breath)
            (symptom heart-disease irregular-heartbeat)
            (treats metformin diabetes)
            (treats insulin diabetes)
            (treats lifestyle-modification diabetes)
            (treats beta-blockers hypertension)
            (treats ace-inhibitors hypertension)
            (treats diuretics hypertension)
            (risk-factor diabetes obesity)
            (risk-factor diabetes sedentary-lifestyle)
            (risk-factor diabetes family-history)
            (risk-factor hypertension high-sodium-diet)
            (risk-factor hypertension stress)
            (risk-factor hypertension obesity)
            (diagnoses blood-glucose-test diabetes)
            (diagnoses hba1c-test diabetes)
            (diagnoses blood-pressure-test hypertension)
            (diagnoses ecg heart-disease)
            (diagnoses chest-xray heart-disease)
            (imaging-modality xray chest)
            (imaging-modality ct-scan brain)
            (imaging-modality mri spine)
            (imaging-modality ultrasound abdomen)
            (= (get-symptoms $condition)
               (match &self (symptom $condition $symptom) $symptom))
            (= (get-treatments $condition)
               (match &self (treats $treatment $condition) $treatment))
            (= (get-risk-factors $condition)
               (match &self (risk-factor $condition $factor) $factor))
            (= (get-diagnostic-tests $condition)
               (match &self (diagnoses $test $condition) $test))
            (= (find-condition-by-symptom $symptom)
               (match &self (symptom $condition $symptom) $condition))
        """
        metta.run(medical_knowledge)
    except Exception as exc:
        # Keep the agent usable without the knowledge graph, but surface why it
        # is off instead of silently swallowing the failure.
        print(f"MeTTa knowledge graph disabled: failed to initialise ({exc!r})")
        metta = None
else:
    metta = None
|
|
|
# The shared chat protocol; every chat handler below registers on it, and it is
# attached to the agent at the bottom of the file.
chat_proto = Protocol(spec=chat_protocol_spec)

# In-memory session table keyed by sender address. Each value holds
# "session_id", "started_at" (naive UTC datetime) and "query_count".
sessions: dict[str, dict] = {}
|
|
|
def create_text_chat(text: str) -> ChatMessage:
    """Wrap *text* in a single-item ChatMessage.

    The message gets a fresh random ``msg_id`` and a naive UTC timestamp
    (``datetime.utcnow()``).
    """
    return ChatMessage(
        timestamp=datetime.utcnow(),
        msg_id=uuid4(),
        content=[TextContent(type="text", text=text)],
    )
|
|
|
def query_metta_knowledge(query_type: str, entity: str) -> str:
    """Run one of the predefined knowledge-graph queries against the MeTTa space.

    Args:
        query_type: one of "symptoms", "treatments", "risk-factors",
            "diagnostic-tests" or "find-condition".
        entity: a condition name, or a symptom for "find-condition". It is
            normalised to the knowledge base's lowercase, hyphenated symbol
            form, so "Heart Disease" becomes "heart-disease".

    Returns:
        The matching atoms joined with ", ", or a human-readable message when
        the input is invalid, MeTTa is unavailable, nothing matched, or the
        query failed. This function never raises.
    """
    import re

    query_templates = {
        "symptoms": "!(get-symptoms {})",
        "treatments": "!(get-treatments {})",
        "risk-factors": "!(get-risk-factors {})",
        "diagnostic-tests": "!(get-diagnostic-tests {})",
        "find-condition": "!(find-condition-by-symptom {})",
    }
    if query_type not in query_templates:
        return "Unknown query type."

    # `entity` is spliced into MeTTa source, so restrict it to a bare symbol.
    # Otherwise input like "x) !(..." could close the call and inject code.
    symbol = "-".join(entity.strip().lower().split())
    if not re.fullmatch(r"[a-z0-9][a-z0-9_-]*", symbol):
        return f"Invalid entity for knowledge-graph query: {entity!r}"

    if not METTA_AVAILABLE or metta is None:
        return "MeTTa knowledge graph not available."

    try:
        result = metta.run(query_templates[query_type].format(symbol))
        if not result:
            return "No results found."
        # result[0] holds the values produced by the single `!` expression.
        items = [str(item) for item in result[0] if str(item) != '()']
        if items:
            return ", ".join(items)
        return f"No {query_type} found in knowledge base for {entity}."
    except Exception as e:
        return f"Error querying MeTTa: {str(e)}"
|
|
|
def analyze_with_metta_and_biogpt(query: str) -> str:
    """Answer *query* with optional knowledge-graph facts followed by BioGPT text.

    If the query names a known condition and a recognised question type, and
    the MeTTa knowledge base is actually loaded, a "Knowledge Graph Insight"
    section is prepended. The BioGPT continuation is always included.

    Returns:
        ``<insight or "">`` + ``"\\nBioGPT Analysis:\\n"`` + BioGPT output.
    """
    import re

    query_lower = query.lower()
    conditions = ["diabetes", "hypertension", "heart-disease"]
    # Ordered: the first keyword that matches decides the query type.
    query_types = {
        "symptom": "symptoms",
        "treatment": "treatments",
        "treat": "treatments",
        "risk": "risk-factors",
        "diagnos": "diagnostic-tests",
        "test": "diagnostic-tests",
    }

    # Accept both "heart disease" and the KB spelling "heart-disease".
    detected_condition = next(
        (c for c in conditions if c.replace("-", " ") in query_lower or c in query_lower),
        None,
    )
    # Keywords must start a word, so "latest" does not count as "test".
    # Trailing letters are allowed: "tests", "diagnosis", "treating".
    detected_query_type = next(
        (q_type for keyword, q_type in query_types.items()
         if re.search(r"\b" + re.escape(keyword), query_lower)),
        None,
    )

    metta_response = ""
    # METTA_AVAILABLE only means hyperon imported. `metta` is None unless the
    # knowledge base was really loaded (ENABLE_METTA=true and setup succeeded).
    if detected_condition and detected_query_type and metta is not None:
        metta_result = query_metta_knowledge(detected_query_type, detected_condition)
        metta_response = f"\n\nKnowledge Graph Insight ({detected_query_type} for {detected_condition}):\n{metta_result}\n"

    biogpt_response = analyze_medical_text(query)
    return metta_response + "\nBioGPT Analysis:\n" + biogpt_response
|
|
|
def analyze_medical_text(query: str, max_length: int = 200) -> str:
    """Return BioGPT's generated text for *query*.

    Args:
        query: prompt passed straight to the text-generation pipeline.
        max_length: forwarded as the pipeline's ``max_length``. NOTE(review):
            in transformers this bounds prompt + generated tokens together,
            so a long prompt leaves little or no room for new text.
            ``max_new_tokens`` may be what is intended.

    Returns:
        The first generated sequence's ``generated_text``, or an
        "Error in medical text analysis: ..." message. This function never raises.
    """
    try:
        generations = text_pipeline(
            query,
            max_length=max_length,
            num_return_sequences=1,
            # Use EOS as padding so generation has a defined pad id.
            pad_token_id=text_tokenizer.eos_token_id,
        )
        first_generation = generations[0]
        return first_generation['generated_text']
    except Exception as exc:
        return f"Error in medical text analysis: {str(exc)}"
|
|
|
def analyze_medical_image(image_data: bytes) -> str:
    """Classify *image_data* with the ViT model and describe the top prediction.

    Args:
        image_data: encoded image bytes in any format PIL can open.

    Returns:
        A short report with the predicted label, its softmax confidence and
        the class id, or an "Error in medical image analysis: ..." message if
        decoding or inference fails. This function never raises.
    """
    # NOTE(review): google/vit-base-patch16-224 appears to be a general-purpose
    # image classifier rather than a medical model. Confirm its labels are
    # meaningful for this use.
    try:
        # Close the PIL image once preprocessing has consumed it.
        with Image.open(io.BytesIO(image_data)) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            inputs = vision_feature_extractor(images=image, return_tensors="pt")

        # Send tensors wherever the model lives instead of re-probing CUDA.
        inputs = {name: tensor.to(vision_model.device) for name, tensor in inputs.items()}

        # Pure inference: skip autograd graph construction to save memory and time.
        with torch.no_grad():
            logits = vision_model(**inputs).logits

        predicted_class = logits.argmax(-1).item()
        probs = torch.nn.functional.softmax(logits, dim=-1)
        confidence = probs[0][predicted_class].item()
        label = vision_model.config.id2label[predicted_class]
        return f"Medical Image Analysis (ViT):\n- Prediction: {label}\n- Confidence: {confidence:.2%}\n- Class ID: {predicted_class}"
    except Exception as e:
        return f"Error in medical image analysis: {str(e)}"
|
|
|
def process_medical_query(text: str) -> str:
    """Route a free-text user query.

    Queries that mention imaging (image, scan, x-ray, MRI, ...) get a usage hint
    explaining the ``IMAGE:<url>`` syntax. All others go to the combined
    MeTTa + BioGPT analysis, wrapped in a header and a medical disclaimer.
    """
    lowered = text.lower()
    imaging_hints = ('image', 'scan', 'x-ray', 'xray', 'mri', 'ct scan', 'picture')
    mentions_imaging = any(hint in lowered for hint in imaging_hints)
    if mentions_imaging:
        return "To analyze a medical image, provide an image URL starting with 'IMAGE:' followed by the URL."

    analysis = analyze_with_metta_and_biogpt(text)
    disclaimer = (
        "\n\nDisclaimer: This is AI-generated information combining knowledge graph "
        "reasoning (MeTTa) and language models (BioGPT) for educational purposes. "
        "Always consult with qualified healthcare professionals for medical advice."
    )
    return "".join(("MedBridge AI - Hybrid Analysis\n", analysis, disclaimer))
|
|
|
# Upper bound on how much of a remote image is buffered in memory.
_MAX_IMAGE_BYTES = 10 * 1024 * 1024


def _fetch_image_bytes(image_url: str, max_bytes: int = _MAX_IMAGE_BYTES) -> bytes:
    """Download *image_url* and return its body, refusing bodies over *max_bytes*.

    Raises:
        ValueError: the URL is not http(s), or the body exceeds *max_bytes*.
        requests.RequestException: network failure, timeout, or non-2xx status.
    """
    # NOTE(review): the URL is user-supplied, so the agent will fetch from any
    # host it can reach (SSRF risk). Consider an allow-list if internal
    # services are reachable from this deployment.
    if not image_url.lower().startswith(("http://", "https://")):
        raise ValueError("image URL must start with http:// or https://")
    with requests.get(image_url, timeout=10, stream=True) as response:
        response.raise_for_status()
        chunks = []
        received = 0
        for chunk in response.iter_content(chunk_size=64 * 1024):
            received += len(chunk)
            if received > max_bytes:
                raise ValueError(f"image is larger than {max_bytes} bytes")
            chunks.append(chunk)
    return b"".join(chunks)


@chat_proto.on_message(ChatMessage)
async def handle_message(ctx: Context, sender: str, msg: ChatMessage):
    """Acknowledge an incoming chat message, then act on each content item.

    - StartSessionContent: open a session for *sender* and send a welcome.
    - TextContent: answer either an ``IMAGE:<url>`` classification request or
      a text query. Failures are reported back to the sender, not raised.
    - EndSessionContent: send a short summary and drop the session.
    """
    import asyncio

    ctx.logger.info(f"Received message from {sender}")
    await ctx.send(
        sender,
        ChatAcknowledgement(
            timestamp=datetime.utcnow(),
            acknowledged_msg_id=msg.msg_id
        )
    )

    for item in msg.content:
        if isinstance(item, StartSessionContent):
            ctx.logger.info(f"Session started with {sender}")
            sessions[sender] = {
                "session_id": str(uuid4()),
                "started_at": datetime.utcnow(),
                "query_count": 0
            }
            # Report whether the knowledge base is actually loaded, not merely
            # whether hyperon is installed.
            metta_status = "Enabled" if metta is not None else "Disabled"
            welcome_msg = (
                "Welcome to MedBridge AI\n"
                "Powered by SingularityNET MeTTa, BioGPT-Large, Vision Transformer, Fetch.ai uAgents\n"
                f"MeTTa Knowledge Graph: {metta_status}\n\n"
                "How can I assist you today?\n"
                "Ask medical questions or send IMAGE:your-image-url"
            )
            await ctx.send(sender, create_text_chat(welcome_msg))

        elif isinstance(item, TextContent):
            ctx.logger.info(f"Text from {sender}: {item.text[:100]}...")
            if sender in sessions:
                sessions[sender]["query_count"] += 1
            try:
                # The download and model inference are blocking. Run them in a
                # worker thread so the agent's event loop keeps serving other
                # messages and interval tasks.
                if item.text.startswith("IMAGE:"):
                    image_url = item.text[6:].strip()
                    image_bytes = await asyncio.to_thread(_fetch_image_bytes, image_url)
                    result = await asyncio.to_thread(analyze_medical_image, image_bytes)
                else:
                    result = await asyncio.to_thread(process_medical_query, item.text)
                await ctx.send(sender, create_text_chat(result))
            except Exception as e:
                error_msg = create_text_chat(
                    f"Error processing your request: {str(e)}"
                )
                await ctx.send(sender, error_msg)

        elif isinstance(item, EndSessionContent):
            ctx.logger.info(f"Session ended with {sender}")
            if sender in sessions:
                query_count = sessions[sender]["query_count"]
                goodbye_msg = create_text_chat(
                    f"Thank you for using MedBridge AI. "
                    f"I processed {query_count} queries during this session."
                )
                await ctx.send(sender, goodbye_msg)
                del sessions[sender]

        else:
            ctx.logger.info(f"Unexpected content type from {sender}")
|
|
|
@chat_proto.on_message(ChatAcknowledgement)
async def handle_acknowledgement(ctx: Context, sender: str, msg: ChatAcknowledgement):
    """Log that *sender* acknowledged one of our messages. No reply is sent."""
    acked_id = msg.acknowledged_msg_id
    ctx.logger.info(f"Acknowledgement from {sender} for message {acked_id}")
|
|
|
@chat_proto.on_interval(period=300.0)
async def cleanup_sessions(ctx: Context):
    """Every 300 s, forget sessions that started more than one hour (3600 s) ago.

    Age is measured from ``started_at``, not from the last query.
    """
    now = datetime.utcnow()
    # Collect first, delete afterwards: the dict must not change size while iterated.
    stale_senders = [
        address
        for address, info in sessions.items()
        if (now - info["started_at"]).total_seconds() > 3600
    ]
    for address in stale_senders:
        del sessions[address]
    if stale_senders:
        ctx.logger.info(f"Cleaned up {len(stale_senders)} expired sessions")
|
|
|
@chat_proto.on_interval(period=120.0)
async def log_agent_status(ctx: Context):
    """Heartbeat: every 120 s, log how many chat sessions are currently tracked."""
    active_count = len(sessions)
    ctx.logger.info(f"MedBridge AI Status - Active Sessions: {active_count}")
|
|
|
@agent.on_event("startup")
async def agent_startup(ctx: Context):
    """Log a startup banner: identity, enabled components, models and device."""
    # Distinguish "installed" from "loaded". `metta` is only non-None when
    # ENABLE_METTA=true was set and the knowledge base initialised successfully.
    if metta is not None:
        metta_state = "Enabled"
    elif not METTA_AVAILABLE:
        metta_state = "Not installed"
    elif os.getenv("ENABLE_METTA", "").lower() == "true":
        metta_state = "Installed but failed to load"
    else:
        metta_state = "Installed but disabled (set ENABLE_METTA=true)"

    device = 'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'
    rule = "=" * 70

    ctx.logger.info(rule)
    ctx.logger.info("MedBridge AI Agent Starting")
    ctx.logger.info(rule)
    ctx.logger.info(f"Agent Name: {agent.name}")
    ctx.logger.info(f"Agent Address: {agent.address}")
    ctx.logger.info("ASI Alliance Technologies:")
    ctx.logger.info(" • Fetch.ai uAgents Framework: Enabled")
    ctx.logger.info(f" • SingularityNET MeTTa: {metta_state}")
    ctx.logger.info("AI Models:")
    ctx.logger.info(f" • Text Model: {text_model_name}")
    ctx.logger.info(f" • Vision Model: {vision_model_name}")
    ctx.logger.info(f" • Device: {device}")
    ctx.logger.info("Protocols:")
    ctx.logger.info(" • ASI:One Compatible: Enabled (Chat Protocol)")
    ctx.logger.info(rule)
    ctx.logger.info("MedBridge AI Agent Ready")
    ctx.logger.info(rule)
|
|
|
# Attach the chat protocol to the agent and publish its manifest so that
# chat-protocol clients (presumably ASI:One) can discover it.
agent.include(chat_proto, publish_manifest=True)

# Printed at import time, so the address appears even when the module is imported.
_address_banner = "Your agent's address is: {}".format(agent.address)
print(_address_banner)


if __name__ == "__main__":
    # Blocks and runs the agent's event loop until interrupted.
    agent.run()