"""Gradio demo that classifies a histopathology image with a model ensemble.

The ensemble is built from three Hugging Face checkpoints
(``tcvrishank/histo_train_vit``, ``tcvrishank/histo_train_segformer`` and
``tcvrishank/histo_train_swin``); ``return_prediction`` turns the ensemble's
highest class probability into a short text report shown in the UI.
"""

import os  # NOTE(review): no longer used after switching to subprocess; kept in case it is needed elsewhere.
import subprocess
import sys

import gradio as gr

# Install runtime dependencies before importing them below.
# `sys.executable -m pip` installs into the interpreter actually running this
# app (a bare `pip3` on PATH may belong to a different Python), and an argument
# list avoids going through a shell. `check=False` keeps the original
# best-effort behaviour: a failed install does not by itself abort start-up.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "torch", "transformers", "Pillow", "ensemble_transformers",
    ],
    check=False,
)

import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
from ensemble_transformers import EnsembleModelForImageClassification

ensemble = EnsembleModelForImageClassification.from_multiple_pretrained(
    "tcvrishank/histo_train_vit",
    "tcvrishank/histo_train_segformer",
    "tcvrishank/histo_train_swin",
)

# Class names, indexed like the ensemble's logits.
# NOTE(review): assumes this order matches each checkpoint's `id2label`; confirm.
candidate_labels = ["Benign", "InSitu", "Invasive", "Normal"]


def return_prediction(image):
    """Run the ensemble on one uploaded image and return a text report.

    Args:
        image: The value Gradio passes for an ``"image"`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A sentence containing ``round(top_probability, 2) + 1``, where
        ``top_probability`` is the largest softmax probability over the
        ensemble's mean-pooled logits.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a message in the
            UI instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    with torch.no_grad():
        outputs = ensemble(image, mean_pool=True)
    # First entry along the batch dimension (a single image is passed in).
    logits = outputs.logits[0]
    # .cpu() is a no-op on CPU and keeps .numpy() working if the model is moved to a GPU.
    probs = logits.softmax(dim=-1).cpu().numpy()

    # Only the highest class probability is reported, so take the max directly
    # instead of sorting every (score, label) pair to read the first one.
    top_score = probs.max()

    # NOTE(review): the report adds 1 to a probability and always states
    # "high risk of recurrence" regardless of which label scored highest
    # (even "Normal"). Output kept unchanged here — confirm this is intended.
    return (
        "This histopathology image shows a cell population that indicates a "
        f"risk score of {round(top_score, 2) + 1}. "
        "Image suggests high risk of recurrence."
    )


demo = gr.Interface(fn=return_prediction, inputs="image", outputs="text")
demo.launch()