# Spaces: Sleeping
# %%
import os
from typing import Any, Dict, List, Tuple

import requests
from sqlalchemy import create_engine, text
def get_all_diseases_name(engine) -> List[str]:
    """Return the second column of every row in ``Test.EntityEmbeddings``.

    Args:
        engine: Object exposing ``connect()`` (a SQLAlchemy engine in the
            script entry point of this file).

    Returns:
        A flat list with the value at column index 1 of each row, skipping
        rows where that value is the literal string ``"nan"``.
        NOTE(review): index 1 is presumably the ``label`` column - confirm
        against the table definition.
    """
    # Plain string: the statement has no placeholders, so no f-string needed.
    sql = """
        SELECT * FROM Test.EntityEmbeddings
    """
    with engine.connect() as conn, conn.begin():
        data = conn.execute(text(sql)).fetchall()
    return [row[1] for row in data if row[1] != "nan"]
def get_uri_from_name(engine, name: str) -> str:
    """Look up an entity by label and return the last path segment of its URI.

    Args:
        engine: Object exposing ``connect()`` (a SQLAlchemy engine in the
            script entry point of this file).
        name: Exact value to match against the ``label`` column.

    Returns:
        The text after the final ``/`` in the ``uri`` of the first matching row.

    Raises:
        IndexError: If no row has the given label.
    """
    # Bind the label instead of interpolating it: a label containing a quote
    # (e.g. "Crohn's disease") would otherwise break or alter the statement.
    sql = """
        SELECT uri FROM Test.EntityEmbeddings
        WHERE label = :name
    """
    with engine.connect() as conn, conn.begin():
        data = conn.execute(text(sql), {"name": name}).fetchall()
    if not data:
        # Same exception type the bare data[0] lookup raised, with context.
        raise IndexError(f"no entity with label {name!r} in Test.EntityEmbeddings")
    return data[0][0].split('/')[-1]
def get_most_similar_diseases_from_uri(engine, original_disease_uri: str, threshold: float = 0.8) -> List[Tuple[str, str, float]]:
    """Return up to 10 entities whose embedding is closest to the given one.

    The reference entity is the row whose ``uri`` equals
    ``http://identifiers.org/medgen/<original_disease_uri>``.

    Args:
        engine: Object exposing ``connect()`` (a SQLAlchemy engine in the
            script entry point of this file).
        original_disease_uri: Identifier appended to the medgen URI prefix.
        threshold: Only rows whose ``VECTOR_COSINE`` value is strictly greater
            than this are returned.

    Returns:
        A list of ``(id, label, score)`` tuples ordered by score descending,
        where ``id`` is the last path segment of the other entity's URI and
        ``score`` is the ``VECTOR_COSINE`` value (aliased ``distance`` in the
        query). Rows whose label is the literal string ``"nan"`` are dropped.
    """
    # Values are bound rather than formatted into the SQL so that quotes or
    # other special characters in the identifier cannot change the statement.
    sql = """
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,
        VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
        AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
        AND e1.uri != e2.uri
        ORDER BY distance DESC
    """
    params = {
        "uri": f"http://identifiers.org/medgen/{original_disease_uri}",
        "threshold": threshold,
    }
    with engine.connect() as conn, conn.begin():
        data = conn.execute(text(sql), params).fetchall()
    # Columns: 0=uri1, 1=uri2, 2=label1, 3=label2, 4=distance.
    return [(row[1].split('/')[-1], row[3], row[4]) for row in data if row[3] != "nan"]
def get_uri_from_name(engine, name: str) -> str:
    """Look up an entity by label and return the last path segment of its URI.

    Args:
        engine: Object exposing ``connect()`` (a SQLAlchemy engine in the
            script entry point of this file).
        name: Exact value to match against the ``label`` column.

    Returns:
        The text after the final ``/`` in the ``uri`` of the first matching row.

    Raises:
        IndexError: If no row has the given label.
    """
    # Bind the label instead of interpolating it: a label containing a quote
    # (e.g. "Crohn's disease") would otherwise break or alter the statement.
    sql = """
        SELECT uri FROM Test.EntityEmbeddings
        WHERE label = :name
    """
    with engine.connect() as conn, conn.begin():
        data = conn.execute(text(sql), {"name": name}).fetchall()
    if not data:
        # Same exception type the bare data[0] lookup raised, with context.
        raise IndexError(f"no entity with label {name!r} in Test.EntityEmbeddings")
    return data[0][0].split('/')[-1]
def get_most_similar_diseases_from_uri(engine, original_disease_uri: str, threshold: float = 0.8) -> List[Tuple[str, str, float]]:
    """Return up to 10 entities whose embedding is closest to the given one.

    The reference entity is the row whose ``uri`` equals
    ``http://identifiers.org/medgen/<original_disease_uri>``.

    Args:
        engine: Object exposing ``connect()`` (a SQLAlchemy engine in the
            script entry point of this file).
        original_disease_uri: Identifier appended to the medgen URI prefix.
        threshold: Only rows whose ``VECTOR_COSINE`` value is strictly greater
            than this are returned.

    Returns:
        A list of ``(id, label, score)`` tuples ordered by score descending,
        where ``id`` is the last path segment of the other entity's URI and
        ``score`` is the ``VECTOR_COSINE`` value (aliased ``distance`` in the
        query). Rows whose label is the literal string ``"nan"`` are dropped.
    """
    # Values are bound rather than formatted into the SQL so that quotes or
    # other special characters in the identifier cannot change the statement.
    sql = """
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,
        VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
        AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
        AND e1.uri != e2.uri
        ORDER BY distance DESC
    """
    params = {
        "uri": f"http://identifiers.org/medgen/{original_disease_uri}",
        "threshold": threshold,
    }
    with engine.connect() as conn, conn.begin():
        data = conn.execute(text(sql), params).fetchall()
    # Columns: 0=uri1, 1=uri2, 2=label1, 3=label2, 4=distance.
    return [(row[1].split('/')[-1], row[3], row[4]) for row in data if row[3] != "nan"]
def get_clinical_record_info(clinical_record_id: str, timeout: float = 30.0) -> Dict[str, Any]:
    """Fetch one study from the ClinicalTrials.gov v2 API as parsed JSON.

    Args:
        clinical_record_id: Study identifier placed at the end of the request
            path (e.g. ``NCT00841061``).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    # Equivalent request:
    #   curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" \
    #        -H "accept: application/json"
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(request_url, headers={"accept": "application/json"}, timeout=timeout)
    # Surface HTTP errors explicitly instead of trying to decode an error page.
    response.raise_for_status()
    return response.json()
def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch several clinical records, one request per identifier.

    Args:
        clinical_record_ids: Identifiers passed one by one to
            ``get_clinical_record_info``.

    Returns:
        The fetched records in the same order as the input identifiers; an
        empty list when no identifiers are given.
    """
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]
def get_uris_of_similar_diseases(uri_list: List[str], engine=None) -> List[Tuple[str, str, float]]:
    """Return the pairwise ``VECTOR_COSINE`` value for every pair of given URIs.

    Args:
        uri_list: URIs to compare with each other.
        engine: Object exposing ``connect()``. When omitted, the module-level
            name ``engine`` is used, which is what this function relied on
            before the parameter existed.

    Returns:
        Rows of ``(uri1, uri2, distance)`` for every ordered pair of distinct
        URIs from ``uri_list`` found in ``Test.EntityEmbeddings``. Returns an
        empty list without querying when fewer than two distinct URIs are
        given, since the ``e1.uri != e2.uri`` condition can never match then.

    Raises:
        NameError: If ``engine`` is not passed and no module-level ``engine``
            exists.
    """
    # De-duplicate while keeping order; duplicates do not change an IN filter.
    uris = list(dict.fromkeys(uri_list))
    if len(uris) < 2:
        return []

    if engine is None:
        try:
            engine = globals()["engine"]
        except KeyError:
            raise NameError(
                "name 'engine' is not defined; pass engine=... explicitly"
            ) from None

    # One bound placeholder per URI. Formatting a Python tuple into the SQL
    # produced invalid syntax for short lists and allowed quote injection.
    placeholders = ", ".join(f":uri_{i}" for i in range(len(uris)))
    params = {f"uri_{i}": uri for i, uri in enumerate(uris)}
    sql = f"""
        SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN ({placeholders}) AND e2.uri IN ({placeholders}) AND e1.uri != e2.uri
    """
    with engine.connect() as conn, conn.begin():
        return conn.execute(text(sql), params).fetchall()
if __name__ == "__main__":
    # Demo entry point: query the embeddings table, then fetch one study from
    # the clinical-trials API, printing each result.
    # NOTE(review): credentials are hard-coded demo values; only the host is
    # read from the environment.
    username = 'demo'
    password = 'demo'
    hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
    port = '1972'
    namespace = 'USER'
    CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
    try:
        engine = create_engine(CONNECTION_STRING)
        # The engine must be passed explicitly; it is the first parameter.
        diseases = get_most_similar_diseases_from_uri(engine, 'C1843013')
        for disease in diseases:
            print(disease)
        # Kept inside the try block because it needs the engine created above.
        print(get_uri_from_name(engine, 'Alzheimer disease 3'))
    except Exception as e:
        # Top-level boundary of a demo script: report the failure and continue
        # with the part that does not need the database.
        print(e)
    clinical_record_info = get_clinical_records_by_ids(['NCT00841061'])
    print(clinical_record_info)