import re
import spacy
from nltk.tokenize import sent_tokenize, word_tokenize
import nltk
nltk.download('punkt_tab')
#import coreferee
import copy
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_distances
from collections import defaultdict
import numpy as np
#from mtdna_classifier import infer_fromQAModel
# 1. SENTENCE-BERT MODEL
# Step 1: Preprocess the text
def normalize_text(text):
    # Normalize various range separators (arrows, dashes, "to") to "-"
    # ("to" is word-bounded so words that merely contain "to" are left untouched)
    text = re.sub(r'\s*(→+|⇒+|--+>|=>|->|-->|\bto\b|–|—|−|➔)\s*', '-', text, flags=re.IGNORECASE)
    # Fix GEN10GEN30 → GEN10-GEN30
    text = re.sub(r'\b([a-zA-Z]+)(\d+)(\1)(\d+)\b', r'\1\2-\1\4', text)
    # Fix GEN10-30 → GEN10-GEN30
    text = re.sub(r'\b([a-zA-Z]+)(\d+)-(\d+)\b', r'\1\2-\1\3', text)
    return text
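# Illustrative behaviour of normalize_text (hypothetical inputs, shown as a quick
# sanity check rather than part of the pipeline):
#   normalize_text("GEN10GEN30")    -> "GEN10-GEN30"
#   normalize_text("GEN10-30")      -> "GEN10-GEN30"
#   normalize_text("BRU1 to BRU20") -> "BRU1-BRU20"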
def preprocess_text(text):
    normalized = normalize_text(text)
    sentences = sent_tokenize(normalized)
    return [re.sub(r"[^a-zA-Z0-9\s\-]", "", s).strip() for s in sentences]
# Before step 2, check the NLP model cache to avoid loading the model multiple times
# Global model cache
_spacy_models = {}
def get_spacy_model(model_name, add_coreferee=False):
    global _spacy_models
    if model_name not in _spacy_models:
        nlp = spacy.load(model_name)
        if add_coreferee and "coreferee" not in nlp.pipe_names:
            nlp.add_pipe("coreferee")
        _spacy_models[model_name] = nlp
    return _spacy_models[model_name]
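# Minimal check of the caching behaviour (assumes en_core_web_sm is installed):
#   nlp_a = get_spacy_model("en_core_web_sm")
#   nlp_b = get_spacy_model("en_core_web_sm")
#   assert nlp_a is nlp_b  # same instance, spacy.load() runs only once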
# Step 2: NER to Extract Locations and Sample Names
def extract_entities(text, sample_id=None):
    nlp = get_spacy_model("en_core_web_sm")
    doc = nlp(text)
    # Filter entities by GPE, but exclude things that match sample ID format
    gpe_candidates = [ent.text for ent in doc.ents if ent.label_ == "GPE"]
    # Remove entries that match sample ID patterns like XXX123 or similar
    gpe_filtered = [gpe for gpe in gpe_candidates if not re.fullmatch(r'[A-Z]{2,5}\d{2,4}', gpe.strip())]
    # Optional: further filter known invalid patterns (e.g., strings shorter than 3 chars or numeric only)
    gpe_filtered = [gpe for gpe in gpe_filtered if len(gpe) > 2 and not gpe.strip().isdigit()]
    if sample_id is None:
        return list(set(gpe_filtered)), []
    else:
        sample_prefix = re.match(r'[A-Z]+', sample_id).group()
        samples = re.findall(rf'{sample_prefix}\d+', text)
        return list(set(gpe_filtered)), list(set(samples))
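# Illustrative call (hypothetical sentence; en_core_web_sm typically tags "Brunei" as GPE):
#   extract_entities("Sample BRU18 was collected in Brunei.", sample_id="BRU18")
#   -> (["Brunei"], ["BRU18"])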
# Step 3: Build a Soft Matching Layer
# Handle patterns like "BRU1–BRU20" and identify BRU18 as part of it.
def is_sample_in_range(sample_id, sentence):
    # Match prefix up to digits
    sample_prefix_match = re.match(r'^([A-Z0-9]+?)(?=\d+$)', sample_id)
    sample_number_match = re.search(r'(\d+)$', sample_id)
    if not sample_prefix_match or not sample_number_match:
        return False
    sample_prefix = sample_prefix_match.group(1)
    sample_number = int(sample_number_match.group(1))
    sentence = normalize_text(sentence)
    # Case 1: Full prefix on both sides
    pattern1 = rf'{sample_prefix}(\d+)\s*-\s*{sample_prefix}(\d+)'
    for match in re.findall(pattern1, sentence):
        start, end = int(match[0]), int(match[1])
        if start <= sample_number <= end:
            return True
    # Case 2: Prefix only on first number
    pattern2 = rf'{sample_prefix}(\d+)\s*-\s*(\d+)'
    for match in re.findall(pattern2, sentence):
        start, end = int(match[0]), int(match[1])
        if start <= sample_number <= end:
            return True
    return False
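# Illustrative range checks (hypothetical sample IDs, checked against the regexes above):
#   is_sample_in_range("BRU18", "Samples BRU1-BRU20 were genotyped")  -> True
#   is_sample_in_range("BRU25", "Samples BRU1-BRU20 were genotyped")  -> False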
# Step 4: Use coreferee to merge sentences that share the same coreference
# (still disabled here because of package conflicts)
# ========== HEURISTIC GROUP → LOCATION MAPPERS ==========
# === Generalized version to replace your old extract_sample_to_group_general ===
# === Generalized version to replace your old extract_group_to_location_general ===
def extract_population_locations(text):
    text = normalize_text(text)
    pattern = r'([A-Za-z ,\-]+)\n([A-Z]+\d*)\n([A-Za-z ,\-]+)\n([A-Za-z ,\-]+)'
    pop_to_location = {}
    for match in re.finditer(pattern, text, flags=re.IGNORECASE):
        _, pop_code, region, country = match.groups()
        pop_to_location[pop_code.upper()] = f"{region.strip()}\n{country.strip()}"
    return pop_to_location
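# The pattern above expects a flattened four-line table row:
# population name, population code, region, country.
# Illustrative call with a hypothetical table fragment:
#   extract_population_locations("Brunei Dusun\nBRU\nBorneo\nBrunei Darussalam")
#   -> {"BRU": "Borneo\nBrunei Darussalam"}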
def extract_sample_ranges(text):
    text = normalize_text(text)
    # Updated pattern to handle punctuation and line breaks
    pattern = r'\b([A-Z0-9]+\d+)[–\-]([A-Z0-9]+\d+)[,:\.\s]*([A-Z0-9]+\d+)\b'
    sample_to_pop = {}
    for match in re.finditer(pattern, text, flags=re.IGNORECASE):
        start_id, end_id, pop_code = match.groups()
        start_prefix = re.match(r'^([A-Z0-9]+?)(?=\d+$)', start_id, re.IGNORECASE).group(1).upper()
        end_prefix = re.match(r'^([A-Z0-9]+?)(?=\d+$)', end_id, re.IGNORECASE).group(1).upper()
        if start_prefix != end_prefix:
            continue
        start_num = int(re.search(r'(\d+)$', start_id).group())
        end_num = int(re.search(r'(\d+)$', end_id).group())
        for i in range(start_num, end_num + 1):
            # Note: keys are zero-padded to three digits (e.g. BRU001), so lookups
            # with un-padded IDs such as "BRU18" will not find a match here
            sample_id = f"{start_prefix}{i:03d}"
            sample_to_pop[sample_id] = pop_code.upper()
    return sample_to_pop
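# Illustrative call (hypothetical range line; the population code must end in a digit
# to satisfy the third capture group):
#   extract_sample_ranges("BRU001-BRU020: POP1")
#   -> {"BRU001": "POP1", "BRU002": "POP1", ..., "BRU020": "POP1"}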
def filter_context_for_sample(sample_id, full_text, window_size=2):
    # Normalize and tokenize
    full_text = normalize_text(full_text)
    sentences = sent_tokenize(full_text)
    # Step 1: Find indices with direct mention or range match
    match_indices = [
        i for i, s in enumerate(sentences)
        if sample_id in s or is_sample_in_range(sample_id, s)
    ]
    # Step 2: Get sample → group mapping from full text
    sample_to_group = extract_sample_ranges(full_text)
    group_id = sample_to_group.get(sample_id)
    # Step 3: Find group-related sentences
    group_indices = []
    if group_id:
        for i, s in enumerate(sentences):
            if group_id in s:
                group_indices.append(i)
    # Step 4: Collect sentences within window
    selected_indices = set()
    if len(match_indices + group_indices) > 0:
        for i in match_indices + group_indices:
            start = max(0, i - window_size)
            end = min(len(sentences), i + window_size + 1)
            selected_indices.update(range(start, end))
        filtered_sentences = [sentences[i] for i in sorted(selected_indices)]
        return " ".join(filtered_sentences)
    return full_text
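# Usage sketch (hypothetical text): keep only sentences within window_size sentences
# of any line that mentions BRU18 directly, covers it in a range, or names its group;
# falls back to the full text when nothing matches:
#   context = filter_context_for_sample("BRU18", full_text, window_size=2)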
# Coreference merge placeholder: coreferee is disabled above, so for now this
# simply returns the preprocessed sentences
def mergeCorefSen(text):
    sen = preprocess_text(text)
    return sen
# Before step 5 and below, check the transformer cache to avoid reloading the model
# Global SBERT model cache
_sbert_models = {}
def get_sbert_model(model_name="all-MiniLM-L6-v2"):
    global _sbert_models
    if model_name not in _sbert_models:
        _sbert_models[model_name] = SentenceTransformer(model_name)
    return _sbert_models[model_name]
# Step 5: Sentence-BERT retriever → find top paragraphs related to keyword.
def find_top_para(sample_id, text, top_k=5):
    '''Use sentence transformers to embed the sentence that mentions the sample and
    compare it to sentences that mention locations.'''
    sentences = mergeCorefSen(text)
    model = get_sbert_model("all-mpnet-base-v2")
    embeddings = model.encode(sentences, convert_to_tensor=True)
    # Find the sentence that best matches the sample_id
    sample_matches = [s for s in sentences if sample_id in s or is_sample_in_range(sample_id, s)]
    if not sample_matches:
        return [], "No context found for sample"
    sample_embedding = model.encode(sample_matches[0], convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(sample_embedding, embeddings)[0]
    # Get top-k most similar sentence indices
    top_indices = cos_scores.argsort(descending=True)[:top_k]
    return top_indices, sentences
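# Usage sketch (hypothetical passage; downloads the all-mpnet-base-v2 weights on
# first call, so this is illustrative rather than a unit test):
#   top_indices, sentences = find_top_para("BRU18", article_text, top_k=5)
#   if len(top_indices) > 0:
#       print(sentences[int(top_indices[0])])  # best-matching sentence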
# Step 6: DBSCAN to cluster groups of similar paragraphs.
def clusterPara(tokens):
    # Load Sentence-BERT model
    sbert_model = get_sbert_model("all-mpnet-base-v2")
    sentence_embeddings = sbert_model.encode(tokens)
    # Compute cosine distance matrix
    distance_matrix = cosine_distances(sentence_embeddings)
    # DBSCAN clustering
    clustering_model = DBSCAN(eps=0.3, min_samples=1, metric="precomputed")
    cluster_labels = clustering_model.fit_predict(distance_matrix)
    # Group sentences by cluster
    clusters = defaultdict(list)
    cluster_embeddings = defaultdict(list)
    sentence_to_cluster = {}
    for i, label in enumerate(cluster_labels):
        clusters[label].append(tokens[i])
        cluster_embeddings[label].append(sentence_embeddings[i])
        sentence_to_cluster[tokens[i]] = label
    # Compute cluster centroids
    centroids = {
        label: np.mean(embs, axis=0)
        for label, embs in cluster_embeddings.items()
    }
    return clusters, sentence_to_cluster, centroids
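# Usage sketch, chaining steps 5 and 6 with the ranking helper defined below
# (top_indices and sentences come from find_top_para above):
#   clusters, sentence_to_cluster, centroids = clusterPara(sentences)
#   target = sentences[int(top_indices[0])]
#   ranked = rankSenFromCluster(clusters, sentence_to_cluster, centroids, target)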
def rankSenFromCluster(clusters, sentence_to_cluster, centroids, target_sentence):
    target_cluster = sentence_to_cluster[target_sentence]
    target_centroid = centroids[target_cluster]
    sen_rank = []
    sen_order = list(sentence_to_cluster.keys())
    # Compute distances to other cluster centroids
    dists = []
    for label, centroid in centroids.items():
        dist = cosine_distances([target_centroid], [centroid])[0][0]
        dists.append((label, dist))
    dists.sort(key=lambda x: x[1])  # sort by proximity
    for d in dists:
        cluster = clusters[d[0]]
        for sen in cluster:
            if sen != target_sentence:
                sen_rank.append(sen_order.index(sen))
    return sen_rank
# Step 7: Final Inference Wrapper
def infer_location_for_sample(sample_id, context_text):
    # Go through each of the top sentences in order
    top_indices, sentences = find_top_para(sample_id, context_text, top_k=5)
    if len(top_indices) == 0 or sentences == "No context found for sample":
        return "No clear location found in top matches"
    clusters, sentence_to_cluster, centroids = clusterPara(sentences)
    topRankSen_DBSCAN = []
    mostTopSen = ""
    locations = ""
    i = 0
    while i < len(top_indices):
        # Firstly, start with the top-ranked Sentence-BERT result
        idx = int(top_indices[i])
        best_sentence = sentences[idx]
        if i == 0:
            mostTopSen = best_sentence
        locations, _ = extract_entities(best_sentence, sample_id)
        if locations:
            return locations
        # If no location, then look for sample overlap in the same DBSCAN cluster
        # Compute distances to other cluster centroids
        if len(topRankSen_DBSCAN) == 0 and mostTopSen:
            topRankSen_DBSCAN = rankSenFromCluster(clusters, sentence_to_cluster, centroids, mostTopSen)
        if i >= len(topRankSen_DBSCAN):
            break
        idx_DBSCAN = topRankSen_DBSCAN[i]
        best_sentence_DBSCAN = sentences[idx_DBSCAN]
        locations, _ = extract_entities(best_sentence_DBSCAN, sample_id)
        if locations:
            return locations
        # If still nothing, backtrack to the next-best Sentence-BERT sentence
        # (e.g. the 2nd-ranked one) and repeat steps 1 and 2 until candidates run out
        i += 1
    # Last resort: LLM (e.g. ChatGPT, DeepSeek, etc.)
    #if len(locations) == 0:
    return "No clear location found in top matches"