napoles3d committed on
Commit aff296f · verified · 1 Parent(s): e18c8c9

Create app.py

Files changed (1): app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
+ import os
+ import gc
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+ DEFAULT_MODEL_SMALL = "vandijklab/C2S-Scale-Gemma-2-2B"
+ DEFAULT_MODEL_LARGE = "vandijklab/C2S-Scale-Gemma-2-27B"
+
+ MODEL_CACHE = {"id": None, "tokenizer": None, "model": None}
+
+ def vram_gb():
+     if torch.cuda.is_available():
+         props = torch.cuda.get_device_properties(0)
+         return props.total_memory / (1024**3)
+     return 0.0
+
+ def build_prompt(gene_list, species="Homo sapiens"):
+     if isinstance(gene_list, str):
+         # allow gene lists separated by commas and/or newlines
+         raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
+         genes = ", ".join(raw)
+     else:
+         genes = ", ".join(gene_list)
+     return (
+         f"The following is a list of gene names ordered by descending expression level "
+         f"in a {species} cell. Your task is to give the cell type which this cell belongs "
+         f"to based on its gene expression.\n"
+         f"Cell sentence: {genes}.\n"
+         f"The cell type corresponding to these genes is:"
+     )
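+ # Illustrative example (hypothetical input): build_prompt("CD3D\nCD3E, IL7R") normalizes the
+ # separators and ends with "Cell sentence: CD3D, CD3E, IL7R.\nThe cell type corresponding to these genes is:"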
+
+ def unload():
+     MODEL_CACHE["id"] = None
+     MODEL_CACHE["tokenizer"] = None
+     MODEL_CACHE["model"] = None
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+ def load_model(model_id, quantization):
+     """
+     Lazy model loading. For the 27B model an A100 80GB is recommended.
+     quantization: 'none' or '8bit' (requires bitsandbytes).
+     """
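+     # Example call (illustrative): tok, mdl = load_model(DEFAULT_MODEL_SMALL, quantization="none")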
+     if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
+         return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]
+
+     unload()
+
+     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+     device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
+
+     kwargs = dict(torch_dtype=dtype, device_map=device_map, low_cpu_mem_usage=True)
+
+     if quantization == "8bit":
+         try:
+             import bitsandbytes as bnb  # noqa: F401
+             from transformers import BitsAndBytesConfig
+             kwargs.update(dict(quantization_config=BitsAndBytesConfig(load_in_8bit=True)))
+         except Exception:
+             # if bitsandbytes is unavailable, fall back to unquantized loading
+             pass
+
+     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+     mdl = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
+
+     MODEL_CACHE["id"] = model_id
+     MODEL_CACHE["tokenizer"] = tok
+     MODEL_CACHE["model"] = mdl
+     return tok, mdl
+
+ def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
+     # simple VRAM check with guidance for the 27B model
+     mem = vram_gb()
+     warn = ""
+     if "27B" in model_id:
+         if mem < 60 and quantization != "8bit":
+             warn = (
+                 f"⚠️ Detected ~{mem:.1f}GB of VRAM. For 27B an A100 80GB is recommended, "
+                 f"or use 8-bit (which may still be insufficient on a T4)."
+             )
+
+     tok, mdl = load_model(model_id, quantization)
+     prompt = build_prompt(genes_text, species=species)
+     inputs = tok(prompt, return_tensors="pt")
+     if torch.cuda.is_available():
+         inputs = {k: v.to(mdl.device) for k, v in inputs.items()}
+
+     # skip_prompt=True so the stream contains only newly generated text
+     streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
+     gen_kwargs = dict(
+         **inputs,
+         max_new_tokens=int(max_new_tokens),
+         do_sample=True,
+         temperature=float(temperature),
+         top_p=float(top_p),
+         top_k=int(top_k),
+         repetition_penalty=float(repetition_penalty),
+         eos_token_id=tok.eos_token_id,
+         streamer=streamer,
+     )
+
+     # run generation in a background thread and stream partial output
+     import threading
+     output_text = ""
+     def _gen():
+         with torch.no_grad():
+             mdl.generate(**gen_kwargs)
+
+     thread = threading.Thread(target=_gen)
+     thread.start()
+     for new_text in streamer:
+         output_text += new_text
+         yield (warn, prompt, output_text)
+     thread.join()
+
+ with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
+     gr.Markdown(
+         """
+ # C2S-Scale (Gemma-2) for single-cell biology
+ Infers the **cell type** from a *cell sentence* (genes ordered by expression).
+ - Models: `vandijklab/C2S-Scale-Gemma-2-2B` (lightweight), `vandijklab/C2S-Scale-Gemma-2-27B` (heavy).
+ - Select a GPU in the Space Settings for better performance.
+
+ **Note:** 27B requires a large GPU (ideally an A100 80GB). On a T4 it may not load even with 8-bit.
+         """
+     )
+     with gr.Row():
+         model_id = gr.Dropdown(
+             choices=[DEFAULT_MODEL_SMALL, DEFAULT_MODEL_LARGE],
+             value=DEFAULT_MODEL_SMALL,
+             label="Model"
+         )
+         quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (experimental)")
+         species = gr.Dropdown(["Homo sapiens", "Mus musculus", "Danio rerio", "Custom…"], value="Homo sapiens", label="Species")
+         species_custom = gr.Textbox(value="", label="Species (if you chose Custom…)", visible=False)
+
+     def _toggle_species(choice):
+         return gr.update(visible=(choice == "Custom…"))
+     species.change(_toggle_species, species, species_custom)
+
+     example_genes = "MALAT1, RPLP0, RPL13A, ACTB, RPS27A, PTPRC, CD3D, CD3E, CCR7, IL7R, LTB, TRAC, CD27, CD4, CCR6, CXCR5"
+     genes_text = gr.Textbox(value=example_genes, lines=6, label="Cell sentence (gene list ordered by expression ↓)")
+
+     with gr.Accordion("Generation parameters", open=False):
+         max_new_tokens = gr.Slider(8, 256, value=64, step=1, label="max_new_tokens")
+         temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
+         top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="top_p")
+         top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
+         repetition_penalty = gr.Slider(0.8, 1.5, value=1.05, step=0.01, label="repetition_penalty")
+
+     warn_box = gr.Markdown("")
+     prompt_box = gr.Code(label="Effective prompt", language=None)
+     output_box = gr.Textbox(label="Model output (stream)")
+
+     def _species_value(sp, custom):
+         return custom if sp == "Custom…" and custom.strip() else sp
+
+     def _run(mid, sp, spc, genes, mx, temp, tp, tk, rp, q):
+         # generator wrapper so Gradio streams the updates yielded by infer()
+         yield from infer(mid, _species_value(sp, spc), genes, mx, temp, tp, tk, rp, q)
+
+     run_btn = gr.Button("🚀 Infer cell type")
+     run_btn.click(
+         fn=_run,
+         inputs=[model_id, species, species_custom, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
+         outputs=[warn_box, prompt_box, output_box]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
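
For a quick sanity check outside the Gradio UI, a minimal sketch, assuming app.py is importable as app, the dependencies above are installed, and the 2B model fits in memory; the gene list and parameter values are illustrative only:

import torch
from app import build_prompt, load_model, DEFAULT_MODEL_SMALL

# build a prompt from a short, hypothetical cell sentence
prompt = build_prompt("MALAT1, ACTB, CD3D, CD3E, IL7R", species="Homo sapiens")

# load the 2B model without quantization and run a short greedy generation
tok, mdl = load_model(DEFAULT_MODEL_SMALL, quantization="none")
inputs = tok(prompt, return_tensors="pt").to(mdl.device)
with torch.no_grad():
    out = mdl.generate(**inputs, max_new_tokens=32, do_sample=False)

# decode only the newly generated tokens (the predicted cell type)
print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))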