"""Gradio demo that predicts the difficulty of an MSA file with PyPythia."""

import bz2
import shutil
import tempfile
from pathlib import Path

import gradio as gr
import pypythia.msa
import pypythia.prediction
import pypythia.predictor
import pypythia.raxmlng


def get_default_raxmlng():
    """Return the path to the bundled RAxML-NG binary, extracting it if needed.

    On first use the binary is decompressed from
    ``raxml-ng-v<version>-linux-64.bz2`` (located next to this file) into the
    user's home directory; later calls reuse the already-extracted file.

    Returns:
        Path to the executable RAxML-NG binary.
    """
    version = "1.1.0"
    uncompressed_raxmlng = Path.home() / f"raxml-ng-v{version}-linux-64"
    if not uncompressed_raxmlng.exists():
        compressed_raxmlng = (
            Path(__file__).parent / f"raxml-ng-v{version}-linux-64.bz2"
        )
        # Extract to a sibling ".part" file and rename it into place only once
        # it is complete, so an interrupted extraction cannot leave behind a
        # truncated binary that the exists() check above would then accept.
        partial_raxmlng = uncompressed_raxmlng.with_name(
            uncompressed_raxmlng.name + ".part"
        )
        with bz2.BZ2File(compressed_raxmlng) as bz, partial_raxmlng.open(
            "wb"
        ) as rax:
            shutil.copyfileobj(bz, rax)
        # The mode must be an octal literal: decimal 755 is 0o1363 (sticky
        # bit set, owner not readable), not the intended rwxr-xr-x.
        partial_raxmlng.chmod(0o755)
        partial_raxmlng.replace(uncompressed_raxmlng)
    return uncompressed_raxmlng


def _copy_upload(uploaded_file, destination):
    """Copy the bytes of an uploaded file into the open binary *destination*.

    Accepts either a binary file-like object (what this app originally
    expected from ``gr.File``) or a filesystem path (str / os.PathLike).
    NOTE(review): newer Gradio versions pass a path for ``gr.File``; confirm
    against the pinned Gradio version.
    """
    if hasattr(uploaded_file, "read"):
        # File-like upload: rewind in case it was already read from.
        uploaded_file.seek(0)
        shutil.copyfileobj(uploaded_file, destination)
    else:
        with open(uploaded_file, "rb") as source:
            shutil.copyfileobj(source, destination)


def predict_difficulty(uploaded_file):
    """Predict the difficulty of the uploaded MSA with PyPythia.

    Args:
        uploaded_file: The value of the ``gr.File`` input, either a binary
            file-like object or a path to the uploaded file.

    Returns:
        A ``(difficulty, msa_features)`` tuple: the value returned by the
        PyPythia predictor, and the features computed from the MSA that were
        passed to it.

    Raises:
        ValueError: If no file was uploaded.
    """
    if uploaded_file is None:
        raise ValueError("Please upload an MSA file (.phy or .msa).")

    predictor_file = (
        Path(pypythia.__file__).parent / "predictors" / "predictor_lgb_v1.0.0.pckl"
    )
    # Close the predictor file after loading instead of leaking one handle per
    # request. NOTE(review): assumes DifficultyPredictor fully loads the
    # pickled model in its constructor; confirm against the pinned PyPythia.
    # The pickle is the one shipped inside the installed package, not user
    # input.
    with predictor_file.open("rb") as predictor_stream:
        predictor = pypythia.predictor.DifficultyPredictor(predictor_stream)

    # Prefer a raxml-ng found on PATH; otherwise fall back to the bundled one.
    raxmlng = pypythia.raxmlng.RAxMLNG(
        shutil.which("raxml-ng") or get_default_raxmlng()
    )

    # MSA() is given the temporary file's *name*, so the copy must be flushed
    # and kept alive until the features have been computed.
    with tempfile.NamedTemporaryFile() as msa_file:
        _copy_upload(uploaded_file, msa_file)
        msa_file.flush()
        msa = pypythia.msa.MSA(msa_file.name)
        msa_features = pypythia.prediction.get_all_features(raxmlng, msa)

    difficulty = predictor.predict(msa_features)
    return difficulty, msa_features


# Web UI: one MSA file in; the predicted difficulty and the feature dict out.
pythia_demo = gr.Interface(
    predict_difficulty,
    gr.File(label="MSA file (.phy or .msa)"),
    [
        gr.Number(label="Difficulty", precision=5),
        gr.JSON(label="Features used for prediction"),
    ],
)

pythia_demo.launch()