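"""Generate embeddings for affiliation strings with a pre-trained encoder.

Reads train/test parquet files of affiliation names, encodes them in batches
(with periodic checkpointing), writes the embeddings back to parquet, and can
optionally push the resulting dataset to the Hugging Face Hub.
"""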
import os
import sys
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import warnings

import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings('ignore')

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('embedding_generation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class AffiliationEmbedder:
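    """Batched embedding generation for affiliation strings.

    Args:
        model_path: Path to the pre-trained encoder directory.
        device: Torch device string ('cuda'/'cpu'); auto-detected when None.
        batch_size: Number of texts encoded per forward pass.
        max_length: Maximum token length before truncation.
        use_fp16: Run the model in half precision on GPU.
    """
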
    def __init__(
        self,
        model_path: str = "./affiliation-clustering-0.3b",
        device: Optional[str] = None,
        batch_size: int = 32,
        max_length: int = 512,
        use_fp16: bool = False
    ):
        self.model_path = model_path
        self.batch_size = batch_size
        self.max_length = max_length
        self.use_fp16 = use_fp16

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")
        if self.device.type == 'cuda':
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        self._load_model()

    def _load_model(self):
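        """Load the tokenizer and model, move the model to the target device."""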
        logger.info(f"Loading model from {self.model_path}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            self.model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            self.model = self.model.to(self.device)

            if self.use_fp16 and self.device.type == 'cuda':
                self.model = self.model.half()
                logger.info("Using FP16 mixed precision")

            self.model.eval()

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def encode_batch(self, texts: List[str]) -> np.ndarray:
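        """Encode a batch of texts into L2-normalized embeddings (numpy array)."""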
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)

        # Prefer the model's pooled output; otherwise fall back to
        # attention-mask-weighted mean pooling over the token embeddings.
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            embeddings = outputs.pooler_output
        else:
            token_embeddings = outputs.last_hidden_state
            attention_mask = encoded['attention_mask'].unsqueeze(-1)
            masked_embeddings = token_embeddings * attention_mask
            embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1)

        # L2-normalize so downstream cosine similarity reduces to a dot product.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        embeddings = embeddings.cpu().numpy()

        # Cast back to FP32 so the stored embeddings do not depend on inference precision.
        if self.use_fp16:
            embeddings = embeddings.astype(np.float32)

        return embeddings

    def process_dataset(
        self,
        data_path: str,
        output_path: str,
        checkpoint_interval: int = 1000
    ) -> None:
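        """Embed every affiliation in a parquet file and write the result to disk.

        Resumes from an existing checkpoint file if one is found and saves
        intermediate checkpoints roughly every `checkpoint_interval` rows.
        """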
        logger.info(f"Processing dataset: {data_path}")

        df = pd.read_parquet(data_path)
        logger.info(f"Loaded {len(df)} samples")

        checkpoint_path = output_path.replace('.parquet', '_checkpoint.parquet')
        start_idx = 0
        processed_rows = []

        if os.path.exists(checkpoint_path):
            logger.info(f"Found checkpoint at {checkpoint_path}")
            checkpoint_df = pd.read_parquet(checkpoint_path)
            # Carry the already-embedded rows forward so they end up in the
            # final output (and in any checkpoint saved later in this run).
            processed_rows = checkpoint_df.to_dict('records')
            start_idx = len(processed_rows)
            logger.info(f"Resuming from index {start_idx}")

        total_batches = (len(df) - start_idx + self.batch_size - 1) // self.batch_size
        rows_since_checkpoint = 0

        with tqdm(total=total_batches, desc="Generating embeddings") as pbar:
            for i in range(start_idx, len(df), self.batch_size):
                batch_df = df.iloc[i:i + self.batch_size]
                texts = batch_df['affiliation_name'].tolist()

                try:
                    batch_embeddings = self.encode_batch(texts)

                    for j, embedding in enumerate(batch_embeddings):
                        row_idx = i + j
                        row_data = df.iloc[row_idx].to_dict()
                        row_data['embedding'] = embedding
                        processed_rows.append(row_data)

                    # Checkpoint roughly every `checkpoint_interval` rows; a strict
                    # modulo check would rarely fire when the interval is not a
                    # multiple of the batch size.
                    rows_since_checkpoint += len(batch_embeddings)
                    if rows_since_checkpoint >= checkpoint_interval:
                        self._save_checkpoint(processed_rows, checkpoint_path)
                        logger.info(f"Checkpoint saved at {len(processed_rows)} samples")
                        rows_since_checkpoint = 0

                    pbar.update(1)

                except Exception as e:
                    logger.error(f"Error processing batch at index {i}: {e}")
                    if processed_rows:
                        self._save_checkpoint(processed_rows, checkpoint_path)
                    raise

        result_df = pd.DataFrame(processed_rows)

        logger.info(f"Saving embeddings to {output_path}")
        result_df.to_parquet(output_path, compression='snappy')

        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
            logger.info("Checkpoint file removed")

        logger.info(f"Successfully generated embeddings for {len(result_df)} samples")

        embedding_dim = len(result_df['embedding'].iloc[0])
        logger.info(f"Embedding dimension: {embedding_dim}")
        logger.info(f"Output file size: {os.path.getsize(output_path) / 1e6:.2f} MB")

    def _save_checkpoint(self, processed_rows: List[Dict], checkpoint_path: str):
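        """Persist the rows processed so far to a checkpoint parquet file."""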
        checkpoint_df = pd.DataFrame(processed_rows)
        checkpoint_df.to_parquet(checkpoint_path, compression='snappy')

def main():
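    """Parse arguments, embed the train/test splits, and optionally push to the Hub."""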
    parser = argparse.ArgumentParser(
        description="Generate embeddings for affiliation strings"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./affiliation-clustering-0.3b",
        help="Path to the pre-trained model directory"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample",
        help="Directory containing the input parquet files"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample-embeddings",
        help="Directory to save the output embeddings"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length for tokenization"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (cuda/cpu, auto-detect if not specified)"
    )
    parser.add_argument(
        "--use-fp16",
        action="store_true",
        help="Use FP16 mixed precision for faster processing"
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=1000,
        help="Save a checkpoint roughly every N processed rows"
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the resulting dataset to the Hugging Face Hub"
    )
    parser.add_argument(
        "--hub-dataset-id",
        type=str,
        default=None,
        help="Hugging Face Hub dataset ID (required if --push-to-hub is set)"
    )

    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    embedder = AffiliationEmbedder(
        model_path=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_length=args.max_length,
        use_fp16=args.use_fp16
    )

    data_dir = Path(args.data_dir)
    try:
        train_file = sorted(data_dir.glob("*_train.parquet"))[0]
        test_file = sorted(data_dir.glob("*_test.parquet"))[0]
    except IndexError:
        logger.error(f"Could not find *_train.parquet and *_test.parquet files in {data_dir}")
        sys.exit(1)

    train_output = output_dir / "train_embeddings.parquet"
    test_output = output_dir / "test_embeddings.parquet"

    logger.info("Processing training dataset...")
    embedder.process_dataset(
        str(train_file),
        str(train_output),
        checkpoint_interval=args.checkpoint_interval
    )

    logger.info("Processing test dataset...")
    embedder.process_dataset(
        str(test_file),
        str(test_output),
        checkpoint_interval=args.checkpoint_interval
    )

    if args.push_to_hub:
        if not args.hub_dataset_id:
            logger.error("--hub-dataset-id is required when --push-to-hub is set")
            sys.exit(1)

        logger.info(f"Pushing dataset to Hugging Face Hub: {args.hub_dataset_id}")

        try:
            from huggingface_hub import login

            token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
            if token:
                login(token=token)
                logger.info("Authenticated with Hugging Face Hub using token")
            else:
                logger.info("No HF token found in environment, attempting to use existing credentials")

            logger.info("Loading generated embeddings...")
            train_df = pd.read_parquet(train_output)
            test_df = pd.read_parquet(test_output)

            logger.info(f"Train dataset: {len(train_df)} samples")
            logger.info(f"Test dataset: {len(test_df)} samples")

            logger.info("Creating dataset dictionary...")
            dataset_dict = DatasetDict({
                'train': Dataset.from_pandas(train_df),
                'test': Dataset.from_pandas(test_df)
            })

            logger.info(f"Pushing to hub: {args.hub_dataset_id}")
            dataset_dict.push_to_hub(
                args.hub_dataset_id,
                private=False,
                commit_message="Add affiliation embeddings generated with affiliation-clustering-0.3b model"
            )
            logger.info(f"Dataset successfully pushed to {args.hub_dataset_id}")
            logger.info(f"View at: https://huggingface.co/datasets/{args.hub_dataset_id}")

        except ImportError as e:
            logger.error(f"Failed to import required libraries: {e}")
            logger.error("Make sure huggingface_hub and datasets are installed")
            sys.exit(1)
        except Exception as e:
            logger.error(f"Failed to push dataset to hub: {e}")
            logger.error(f"Error type: {type(e).__name__}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            sys.exit(1)

    logger.info("Embedding generation completed successfully!")


if __name__ == "__main__":
    main()