Update app.py
app.py
CHANGED
@@ -17,7 +17,6 @@ def vram_gb():
 
 def build_prompt(gene_list, species="Homo sapiens"):
     if isinstance(gene_list, str):
-        # allow a list separated by commas/spaces/newlines
         raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
         genes = ", ".join(raw)
     else:
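The removed comment claimed the list could be separated by commas, spaces, or newlines, but the expression only splits on commas and newlines; whitespace-only separators survive as a single token, which may be why the comment was dropped. A standalone check of that parsing (the helper name and sample genes are illustrative):

def normalize_genes(gene_list: str) -> str:
    # fold newlines into commas, strip blanks, re-join
    raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
    return ", ".join(raw)

print(normalize_genes("CD3D, CD3E\nTRAC"))  # -> "CD3D, CD3E, TRAC"
print(normalize_genes("CD3D CD3E"))         # -> "CD3D CD3E" (spaces alone are not split)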
@@ -41,7 +40,7 @@ def unload():
 def load_model(model_id, quantization):
     """
     Lazily loads the model. For 27B an A100 80GB is recommended.
-    quantization: 'none' or '8bit' (requires bitsandbytes).
+    quantization: 'none' or '8bit' (requires bitsandbytes when a GPU is present).
     """
     if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
         return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]
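The id check at the top of load_model is a lazy-load cache: a module-level dict keyed by model id, so switching models in the UI only reloads when the id changes. A minimal sketch of the pattern; the MODEL_CACHE layout beyond the keys visible in the hunk is an assumption:

MODEL_CACHE = {"id": None, "tokenizer": None, "model": None}

def get_or_load(model_id, loader):
    # return the cached pair when the requested id is already loaded
    if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
        return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]
    tok, mdl = loader(model_id)  # expensive path, runs once per model id
    MODEL_CACHE.update(id=model_id, tokenizer=tok, model=mdl)
    return tok, mdl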
@@ -53,7 +52,7 @@ def load_model(model_id, quantization):
 
     kwargs = dict(torch_dtype=dtype, device_map=device_map, low_cpu_mem_usage=True)
 
-    if quantization == "8bit":
+    if quantization == "8bit" and torch.cuda.is_available():
         try:
             import bitsandbytes as bnb  # noqa: F401
             kwargs.update(dict(load_in_8bit=True))
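Gating on torch.cuda.is_available() matters because bitsandbytes 8-bit kernels are GPU-only, so requesting load_in_8bit on a CPU host fails at load time. A sketch of how these kwargs flow into from_pretrained; the dtype and device_map values are assumptions (their assignments fall outside the hunk), and newer transformers releases prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True) over the bare flag:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load(model_id: str, quantization: str = "none"):
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32  # assumed
    kwargs = dict(torch_dtype=dtype, device_map="auto", low_cpu_mem_usage=True)
    if quantization == "8bit" and torch.cuda.is_available():
        try:
            import bitsandbytes  # noqa: F401  -- fail fast if not installed
            kwargs["load_in_8bit"] = True
        except ImportError:
            pass  # fall back to full precision
    tok = AutoTokenizer.from_pretrained(model_id)
    mdl = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    return tok, mdl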
@@ -69,7 +68,11 @@ def load_model(model_id, quantization):
     MODEL_CACHE["model"] = mdl
     return tok, mdl
 
-def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
+def infer(model_id, species, species_custom, genes_text, prompt_manual,
+          max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
+    # effective species
+    species_eff = species_custom.strip() if (species == "Custom…" and species_custom.strip()) else species
+
     # simple VRAM check with guidance for 27B
     mem = vram_gb()
     warn = ""
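vram_gb() is defined outside the hunks shown. A plausible implementation under that caveat, reporting the total memory of the first CUDA device:

import torch

def vram_gb() -> float:
    # total memory of GPU 0 in GiB; 0.0 on CPU-only hosts
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.get_device_properties(0).total_memory / 1024**3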
@@ -77,11 +80,17 @@ def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
     if mem < 60 and quantization != "8bit":
         warn = (
             f"⚠️ ~{mem:.1f}GB VRAM detected. For 27B an A100 80GB is recommended "
-            f"or …"
+            f"or try 8-bit (a T4 may not be enough)."
         )
 
     tok, mdl = load_model(model_id, quantization)
-    prompt = build_prompt(genes_text, species=species)
+
+    # prompt: use the manual one if provided; otherwise build it
+    if prompt_manual and str(prompt_manual).strip():
+        prompt = str(prompt_manual).strip()
+    else:
+        prompt = build_prompt(genes_text, species=species_eff)
+
     inputs = tok(prompt, return_tensors="pt")
     if torch.cuda.is_available():
         inputs = {k: v.to(mdl.device) for k, v in inputs.items()}
@@ -110,7 +119,7 @@ def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
     thread.start()
     for new_text in streamer:
         output_text += new_text
-        yield (warn, prompt, output_text)
+        yield (warn, output_text)
     thread.join()
 
 with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
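The thread/streamer pair above is the standard transformers streaming pattern: generate() blocks, so it runs on a worker thread while TextIteratorStreamer yields decoded fragments to the consuming loop. A self-contained sketch of the pattern (names are illustrative):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(tok, mdl, prompt: str, **gen_kwargs):
    inputs = tok(prompt, return_tensors="pt").to(mdl.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=mdl.generate, kwargs=dict(**inputs, streamer=streamer, **gen_kwargs))
    thread.start()
    text = ""
    for fragment in streamer:  # blocks until generate() pushes new tokens
        text += fragment
        yield text
    thread.join()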
@@ -118,19 +127,22 @@ with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
         """
         # C2S-Scale (Gemma-2) for single-cell biology
         Infers **cell type** from a *cell sentence* (genes ranked by expression).
-        - Models: `vandijklab/C2S-Scale-Gemma-2-2B` (light), `vandijklab/C2S-Scale-Gemma-2-27B` (heavy).
-        - Select a GPU in the Space's Settings for better performance.
 
-        **…
+        **Models**:
+        - `vandijklab/C2S-Scale-Gemma-2-2B` (light; CPU or GPU)
+        - `vandijklab/C2S-Scale-Gemma-2-27B` (heavy; ideally A100 80GB)
+
+        **Note:** the *Effective prompt* field is editable. If you leave it empty, the app builds one automatically.
         """
     )
+
     with gr.Row():
         model_id = gr.Dropdown(
             choices=[DEFAULT_MODEL_SMALL, DEFAULT_MODEL_LARGE],
             value=DEFAULT_MODEL_SMALL,
             label="Model"
         )
-        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (…")
+        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (GPU optional)")
         species = gr.Dropdown(["Homo sapiens", "Mus musculus", "Danio rerio", "Custom…"], value="Homo sapiens", label="Species")
         species_custom = gr.Textbox(value="", label="Species (if you chose Custom…)", visible=False)
 
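species_custom stays hidden here (visible=False); the handler that reveals it when "Custom…" is selected lies outside the hunks shown. A hypothetical sketch of that wiring, assuming the usual Gradio pattern and reusing the diff's component names:

# hypothetical: show the custom-species textbox only for "Custom…"
def _toggle_custom(sp):
    return gr.update(visible=(sp == "Custom…"))

species.change(fn=_toggle_custom, inputs=[species], outputs=[species_custom])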
@@ -148,20 +160,19 @@ with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
         top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
         repetition_penalty = gr.Slider(0.8, 1.5, value=1.05, step=0.01, label="repetition_penalty")
 
+    # EFFECTIVE PROMPT (user-editable)
+    prompt_box = gr.Textbox(label="Effective prompt (optional; leave empty to autogenerate)", lines=8, interactive=True)
+
     warn_box = gr.Markdown("")
-    prompt_box = gr.Textbox(label="Effective prompt", lines=8)
     output_box = gr.Textbox(label="Model output (stream)")
 
-    def _species_value(sp, custom):
-        return custom if sp == "Custom…" and custom.strip() else sp
-
     run_btn = gr.Button("🚀 Infer cell type")
+
     run_btn.click(
-        fn=…
-        …
-        …
-        …
-        outputs=[warn_box, prompt_box, output_box]
+        fn=infer,
+        inputs=[model_id, species, species_custom, genes_text, prompt_box,
+                max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
+        outputs=[warn_box, output_box]
     )
 
 if __name__ == "__main__":
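Since infer is a generator yielding (warn, output_text) tuples, the click wiring streams both outputs as tokens arrive. A minimal, self-contained sketch of that contract; the components and the fake generator below are illustrative, not the Space's code:

import time
import gradio as gr

def fake_infer(prompt):
    text = ""
    for token in prompt.split():
        text += token + " "
        time.sleep(0.1)
        yield "", text  # (warning markdown, streamed output)

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    warn = gr.Markdown("")
    out = gr.Textbox(label="Output (stream)")
    gr.Button("Run").click(fn=fake_infer, inputs=[inp], outputs=[warn, out])

if __name__ == "__main__":
    demo.launch()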