import gc
import threading

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

DEFAULT_MODEL_SMALL = "vandijklab/C2S-Scale-Gemma-2-2B"
DEFAULT_MODEL_LARGE = "vandijklab/C2S-Scale-Gemma-2-27B"

# Single-slot cache so that switching models releases the previous one first.
MODEL_CACHE = {"id": None, "tokenizer": None, "model": None}


def vram_gb():
    """Return the total VRAM of GPU 0 in GiB, or 0.0 when no GPU is available."""
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        return props.total_memory / (1024**3)
    return 0.0


def build_prompt(gene_list, species="Homo sapiens"):
    """Build the cell-type prompt from a comma/newline-separated gene string or a list of genes."""
    if isinstance(gene_list, str):
        raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
        genes = ", ".join(raw)
    else:
        genes = ", ".join(gene_list)
    return (
        f"The following is a list of gene names ordered by descending expression level "
        f"in a {species} cell. Your task is to give the cell type which this cell belongs "
        f"to based on its gene expression.\n"
        f"Cell sentence: {genes}.\n"
        f"The cell type corresponding to these genes is:"
    )


def unload():
    """Drop the cached model/tokenizer and free GPU memory."""
    MODEL_CACHE["id"] = None
    MODEL_CACHE["tokenizer"] = None
    MODEL_CACHE["model"] = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_model(model_id, quantization):
    """
    Lazily load the model. For the 27B checkpoint an A100 80GB is recommended.
    quantization: 'none' or '8bit' (8-bit requires bitsandbytes and a GPU).
    """
    if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
        return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]

    unload()
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
    kwargs = dict(torch_dtype=dtype, device_map=device_map, low_cpu_mem_usage=True)

    if quantization == "8bit" and torch.cuda.is_available():
        try:
            import bitsandbytes as bnb  # noqa: F401
            kwargs.update(dict(load_in_8bit=True))
        except Exception:
            # bitsandbytes not available: fall back to unquantized loading
            pass

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    mdl = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()

    MODEL_CACHE["id"] = model_id
    MODEL_CACHE["tokenizer"] = tok
    MODEL_CACHE["model"] = mdl
    return tok, mdl


def infer(model_id, species, species_custom, genes_text, prompt_manual,
          max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
    # Effective species: use the custom value only when "Custom…" is selected and filled in
    species_eff = species_custom.strip() if (species == "Custom…" and species_custom.strip()) else species

    # Simple VRAM check with guidance for the 27B model
    mem = vram_gb()
    warn = ""
    if "27B" in model_id:
        if mem < 60 and quantization != "8bit":
            warn = (
                f"⚠️ Detected ~{mem:.1f} GB of VRAM. The 27B model is best run on an A100 80GB, "
                f"or try 8-bit quantization (a T4 may still not be enough)."
            )

    tok, mdl = load_model(model_id, quantization)

    # Prompt: use the manual prompt if provided; otherwise build one from the gene list
    if prompt_manual and str(prompt_manual).strip():
        prompt = str(prompt_manual).strip()
    else:
        prompt = build_prompt(genes_text, species=species_eff)

    inputs = tok(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(mdl.device) for k, v in inputs.items()}

    # skip_prompt=True so only newly generated tokens are streamed, not the echoed prompt
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        eos_token_id=tok.eos_token_id,
        streamer=streamer,
    )

    # Run generation in a background thread and stream tokens back to the UI
    output_text = ""

    def _gen():
        with torch.no_grad():
            mdl.generate(**gen_kwargs)

    thread = threading.Thread(target=_gen)
    thread.start()
    for new_text in streamer:
        output_text += new_text
        yield (warn, output_text)
    thread.join()


with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
    gr.Markdown(
        """
        # C2S-Scale (Gemma-2) for single-cell biology

        Infers the **cell type** from a *cell sentence* (genes ordered by expression level).

        **Models**:
        - `vandijklab/C2S-Scale-Gemma-2-2B` (lightweight; CPU or GPU)
        - `vandijklab/C2S-Scale-Gemma-2-27B` (heavy; ideally A100 80GB)

        **Note:** The *Effective prompt* field is editable. If you leave it empty, the app builds one automatically.
        """
    )

    with gr.Row():
        model_id = gr.Dropdown(
            choices=[DEFAULT_MODEL_SMALL, DEFAULT_MODEL_LARGE],
            value=DEFAULT_MODEL_SMALL,
            label="Model"
        )
        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (optional, GPU only)")

    species = gr.Dropdown(["Homo sapiens", "Mus musculus", "Danio rerio", "Custom…"],
                          value="Homo sapiens", label="Species")
    species_custom = gr.Textbox(value="", label="Species (if you chose Custom…)", visible=False)

    def _toggle_species(choice):
        return gr.update(visible=(choice == "Custom…"))

    species.change(_toggle_species, species, species_custom)

    example_genes = "MALAT1, RPLP0, RPL13A, ACTB, RPS27A, PTPRC, CD3D, CD3E, CCR7, IL7R, LTB, TRAC, CD27, CD4, CCR6, CXCR5"
    genes_text = gr.Textbox(value=example_genes, lines=6,
                            label="Cell sentence (genes ordered by descending expression ↓)")

    with gr.Accordion("Generation parameters", open=False):
        max_new_tokens = gr.Slider(8, 256, value=64, step=1, label="max_new_tokens")
        # Minimum of 0.05 because temperature=0 is invalid when do_sample=True
        temperature = gr.Slider(0.05, 2.0, value=0.7, step=0.05, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="top_p")
        top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
        repetition_penalty = gr.Slider(0.8, 1.5, value=1.05, step=0.01, label="repetition_penalty")

    # Effective prompt (editable by the user)
    prompt_box = gr.Textbox(label="Effective prompt (optional; leave empty to auto-generate)",
                            lines=8, interactive=True)

    warn_box = gr.Markdown("")
    output_box = gr.Textbox(label="Model output (streamed)")

    run_btn = gr.Button("🚀 Infer cell type")
    run_btn.click(
        fn=infer,
        inputs=[model_id, species, species_custom, genes_text, prompt_box,
                max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
        outputs=[warn_box, output_box]
    )

if __name__ == "__main__":
    demo.launch()
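
# --- Optional: programmatic usage sketch (kept commented out; not part of the app) ---
# A minimal way to reuse the helpers above without the Gradio UI, assuming the 2B
# checkpoint fits on the available device. Greedy decoding and the short gene list
# below are illustrative choices, not the app's defaults.
#
#   tok, mdl = load_model(DEFAULT_MODEL_SMALL, quantization="none")
#   prompt = build_prompt("MALAT1, CD3D, CD3E, IL7R", species="Homo sapiens")
#   enc = tok(prompt, return_tensors="pt").to(mdl.device)
#   out = mdl.generate(**enc, max_new_tokens=32, do_sample=False)
#   # Decode only the newly generated tokens (everything after the prompt)
#   print(tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True))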