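"""RunPod serverless handler for a PDF-based fine-tuning pipeline.

The handler uploads the input PDF to Azure Blob Storage, downloads it into a
temporary working directory, writes the pipeline parameters to a JSON file,
and then runs Finetuning_Pipeline.py and VLLM_evaluation.py as subprocesses.
"""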
import logging
import runpod
import os
import shutil
import uuid
import json
import time
import subprocess
from typing import Dict, Any
from azure.storage.blob import BlobServiceClient
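# Third-party dependencies (assumed PyPI package names): runpod, azure-storage-blob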
# Configure logging to print to the console
logging.basicConfig(
    level=logging.DEBUG,  # DEBUG captures more detailed logs
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()  # Stream handler prints to the console
    ]
)
def get_azure_connection_string():
    """Get the Azure Storage connection string from an environment variable."""
    # The environment variable name below is an assumption; set it to the
    # storage account's connection string before deploying.
    conn_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
    if not conn_string:
        raise ValueError("Azure Storage connection string not found in environment variables")
    return conn_string
def upload_file(file_path: str) -> str:
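    """Upload a local file to the 'saasdev' container and return the generated blob name."""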
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"The specified file does not exist: {file_path}")
    container_name = "saasdev"
    connection_string = get_azure_connection_string()
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)
    # Generate a unique blob name using a UUID
    blob_name = f"{uuid.uuid4()}.pdf"
    with open(file_path, 'rb') as file:
        blob_client = container_client.get_blob_client(blob_name)
        blob_client.upload_blob(file)
    logging.info(f"File uploaded to blob: {blob_name}")
    return blob_name
def download_blob(blob_name: str, download_file_path: str) -> None:
    """Download a file from Azure Blob Storage"""
    container_name = "saasdev"
    connection_string = get_azure_connection_string()
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container_client = blob_service_client.get_container_client(container_name)
    blob_client = container_client.get_blob_client(blob_name)
    os.makedirs(os.path.dirname(download_file_path), exist_ok=True)
    with open(download_file_path, "wb") as download_file:
        download_stream = blob_client.download_blob()
        download_file.write(download_stream.readall())
    logging.info(f"Blob '{blob_name}' downloaded to '{download_file_path}'")
def clean_directory(directory: str) -> None:
    """Clean up a directory by removing all files and subdirectories"""
    if os.path.exists(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.remove(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                logging.error(f'Failed to delete {file_path}. Reason: {e}')
def handler(job: Dict[str, Any]) -> Dict[str, Any]:
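    """RunPod serverless handler: validate the job input, round-trip the PDF through
    Azure Blob Storage, write the pipeline parameters to JSON, and run the
    fine-tuning and evaluation pipeline."""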
    start_time = time.time()
    logging.info("Handler function started")
    job_input = job.get('input', {})
    required_fields = ['pdf_file', 'system_prompt', 'model_name', 'max_step', 'learning_rate', 'epochs']
    missing_fields = [field for field in required_fields if field not in job_input]
    if missing_fields:
        return {
            "status": "error",
            "error": f"Missing required fields: {', '.join(missing_fields)}"
        }
    work_dir = os.path.abspath(f"/tmp/work_{str(uuid.uuid4())}")
    try:
        os.makedirs(work_dir, exist_ok=True)
        logging.info(f"Working directory created: {work_dir}")
        # Upload the PDF to Blob Storage
        pdf_path = job_input['pdf_file']
        generated_blob_name = upload_file(pdf_path)
        logging.info(f"PDF uploaded with blob name: {generated_blob_name}")
        # Download the uploaded PDF using the internally generated blob name
        downloaded_path = os.path.join(work_dir, "Downloaded_PDF.pdf")
        download_blob(generated_blob_name, downloaded_path)
        logging.info(f"PDF downloaded to: {downloaded_path}")
        # Save the pipeline input as JSON
        pipeline_input_path = os.path.join(work_dir, "pipeline_input.json")
        pipeline_input = {
            "pdf_file": downloaded_path,
            "system_prompt": job_input['system_prompt'],
            "model_name": job_input['model_name'],
            "max_step": job_input['max_step'],
            "learning_rate": job_input['learning_rate'],
            "epochs": job_input['epochs']
        }
        with open(pipeline_input_path, 'w') as f:
            json.dump(pipeline_input, f)
        # Run fine-tuning and evaluation
        return run_pipeline_and_evaluate(pipeline_input_path, job_input['model_name'], start_time)
    except Exception as e:
        error_message = f"Job failed after {time.time() - start_time:.2f} seconds: {str(e)}"
        logging.error(error_message)
        return {
            "status": "error",
            "error": error_message
        }
    finally:
        try:
            clean_directory(work_dir)
            os.rmdir(work_dir)
        except Exception as e:
            logging.error(f"Failed to clean up working directory: {str(e)}")
def run_pipeline_and_evaluate(pipeline_input_path: str, model_name: str, start_time: float) -> Dict[str, Any]:
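    """Run Finetuning_Pipeline.py and VLLM_evaluation.py as subprocesses and return
    a result dict containing the parsed evaluation scores."""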
    try:
        # Suppress logging output
        logging.getLogger().setLevel(logging.ERROR)
        # Read the pipeline input file
        with open(pipeline_input_path, 'r') as f:
            pipeline_input = json.load(f)
        # Convert the input to a JSON string for passing as an argument
        pipeline_input_str = json.dumps(pipeline_input)
        # Run the fine-tuning pipeline with the JSON string as an argument
        # logging.info(f"Running pipeline with input: {pipeline_input_str[:100]}...")
        finetuning_result = subprocess.run(
            ['python3', 'Finetuning_Pipeline.py', pipeline_input_str],
            capture_output=True,
            text=True,
            check=True
        )
        logging.info("Fine-tuning completed successfully")
        # Run evaluation
        evaluation_input = json.dumps({"model_name": model_name})
        result = subprocess.run(
            ['python3', 'VLLM_evaluation.py', evaluation_input],
            capture_output=True,
            text=True,
            check=True
        )
        try:
            # Extract the JSON part from stdout
            output_lines = result.stdout.splitlines()
            for line in reversed(output_lines):
                try:
                    evaluation_results = json.loads(line)
                    if "average_semantic_score" in evaluation_results and "average_bleu_score" in evaluation_results:
                        break
                except json.JSONDecodeError:
                    continue
            else:
                # If no valid JSON is found, fall back to the raw output
                evaluation_results = {"raw_output": result.stdout}
        except Exception as e:
            evaluation_results = {"error": f"Failed to process evaluation output: {str(e)}"}
        # Print only the JSON part to stdout for capturing in Gradio
        print(json.dumps({
            "status": "success",
            "model_name": f"PharynxAI/{model_name}",
            "processing_time": time.time() - start_time,
            "evaluation_results": evaluation_results
        }))
        return {
            "status": "success",
            "model_name": f"PharynxAI/{model_name}",
            "processing_time": time.time() - start_time,
            "evaluation_results": evaluation_results
        }
    except subprocess.CalledProcessError as e:
        error_message = f"Pipeline process failed: {e.stderr}"
        logging.error(error_message)
        return {
            "status": "error",
            "error": error_message,
            # "stdout": e.stdout,
            # "stderr": e.stderr
        }
    except Exception as e:
        error_message = f"Pipeline execution failed: {str(e)}"
        logging.error(error_message)
        return {
            "status": "error",
            "error": error_message
        }
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    runpod.serverless.start({"handler": handler})
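# Example job payload for local testing (illustrative values only; the actual model
# name and hyperparameters depend on the fine-tuning setup):
#
# {
#     "input": {
#         "pdf_file": "/path/to/document.pdf",
#         "system_prompt": "You are a helpful assistant.",
#         "model_name": "my-model",
#         "max_step": 100,
#         "learning_rate": 2e-4,
#         "epochs": 3
#     }
# }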