napoles3d committed (verified)
Commit b5f99ea · Parent: f4d00b3

Update app.py

Files changed (1):
  1. app.py +31 -20
app.py CHANGED
@@ -17,7 +17,6 @@ def vram_gb():
 
 def build_prompt(gene_list, species="Homo sapiens"):
     if isinstance(gene_list, str):
-        # allow a list separated by commas/spaces/newlines
         raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
         genes = ", ".join(raw)
     else:
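For reference, the surviving parsing line accepts comma- and newline-separated input; note that plain spaces do not act as separators, since the code only splits on commas after replacing newlines. A quick sketch of its behavior (the gene names are arbitrary examples):

gene_list = "CD3D, CD3E\nTRAC, IL7R"
raw = [g.strip() for g in gene_list.replace("\n", ",").split(",") if g.strip()]
print(", ".join(raw))  # -> "CD3D, CD3E, TRAC, IL7R"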
@@ -41,7 +40,7 @@ def unload():
 def load_model(model_id, quantization):
     """
     Lazily loads the model. An A100 80GB is recommended for 27B.
-    quantization: 'none' or '8bit' (requires bitsandbytes).
+    quantization: 'none' or '8bit' (requires bitsandbytes if a GPU is available).
     """
     if MODEL_CACHE["id"] == model_id and MODEL_CACHE["model"] is not None:
         return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"]
@@ -53,7 +52,7 @@ def load_model(model_id, quantization):
 
     kwargs = dict(torch_dtype=dtype, device_map=device_map, low_cpu_mem_usage=True)
 
-    if quantization == "8bit":
+    if quantization == "8bit" and torch.cuda.is_available():
         try:
             import bitsandbytes as bnb  # noqa: F401
             kwargs.update(dict(load_in_8bit=True))
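A side note on the 8-bit path: newer transformers releases deprecate passing load_in_8bit directly to from_pretrained in favor of an explicit quantization config. A minimal equivalent sketch, assuming bitsandbytes and a CUDA GPU are available and model_id is defined as above:

# Equivalent 8-bit load via an explicit quantization config (newer transformers API).
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

mdl = AutoModelForCausalLM.from_pretrained(
    model_id,  # e.g. "vandijklab/C2S-Scale-Gemma-2-2B"
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)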
@@ -69,7 +68,11 @@ def load_model(model_id, quantization):
     MODEL_CACHE["model"] = mdl
     return tok, mdl
 
-def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
+def infer(model_id, species, species_custom, genes_text, prompt_manual,
+          max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization):
+    # effective species
+    species_eff = species_custom.strip() if (species == "Custom…" and species_custom.strip()) else species
+
     # simple VRAM check, with guidance for 27B
     mem = vram_gb()
     warn = ""
@@ -77,11 +80,17 @@ def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top
     if mem < 60 and quantization != "8bit":
         warn = (
             f"⚠️ Detected VRAM ~{mem:.1f}GB. For 27B an A100 80GB is recommended "
-            f"or using 8-bit (which may still be insufficient on a T4)."
+            f"or trying 8-bit (a T4 may still not be enough)."
         )
 
     tok, mdl = load_model(model_id, quantization)
-    prompt = build_prompt(genes_text, species=species)
+
+    # prompt: use the manual one if provided; otherwise build it
+    if prompt_manual and str(prompt_manual).strip():
+        prompt = str(prompt_manual).strip()
+    else:
+        prompt = build_prompt(genes_text, species=species_eff)
+
     inputs = tok(prompt, return_tensors="pt")
     if torch.cuda.is_available():
         inputs = {k: v.to(mdl.device) for k, v in inputs.items()}
@@ -110,7 +119,7 @@ def infer(model_id, species, genes_text, max_new_tokens, temperature, top_p, top
     thread.start()
     for new_text in streamer:
         output_text += new_text
-        yield (warn, prompt, output_text)
+        yield (warn, output_text)
     thread.join()
 
 with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
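For context, the loop above is the standard TextIteratorStreamer pattern: generate() runs in a background thread while the streamer is consumed as an iterator. A self-contained sketch with a stand-in model:

# Streaming-generation pattern used by infer() (sketch; "gpt2" is a stand-in model).
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
mdl = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=mdl.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32))
thread.start()
output_text = ""
for new_text in streamer:  # pieces of text arrive as they are generated
    output_text += new_text
thread.join()
print(output_text)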
@@ -118,19 +127,22 @@ with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
     """
     # C2S-Scale (Gemma-2) for single-cell biology
     Infers the **cell type** from a *cell sentence* (genes ranked by expression).
-    - Models: `vandijklab/C2S-Scale-Gemma-2-2B` (light), `vandijklab/C2S-Scale-Gemma-2-27B` (heavy).
-    - Select a GPU in the Space Settings for better performance.
 
-    **Note:** 27B requires a large GPU (ideally an A100 80GB). On a T4 it may not load even with 8-bit.
+    **Models**:
+    - `vandijklab/C2S-Scale-Gemma-2-2B` (light; CPU or GPU)
+    - `vandijklab/C2S-Scale-Gemma-2-27B` (heavy; ideally an A100 80GB)
+
+    **Note:** The *Effective prompt* field is editable. If you leave it empty, the app will generate one automatically.
     """
     )
+
     with gr.Row():
         model_id = gr.Dropdown(
             choices=[DEFAULT_MODEL_SMALL, DEFAULT_MODEL_LARGE],
             value=DEFAULT_MODEL_SMALL,
             label="Model"
         )
-        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (experimental)")
+        quantization = gr.Radio(["none", "8bit"], value="none", label="Quantization (GPU optional)")
         species = gr.Dropdown(["Homo sapiens", "Mus musculus", "Danio rerio", "Custom…"], value="Homo sapiens", label="Species")
         species_custom = gr.Textbox(value="", label="Species (if you chose Custom…)", visible=False)
 
@@ -148,20 +160,19 @@ with gr.Blocks(title="C2S-Scale (Gemma-2) — Single-cell Biology") as demo:
         top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
         repetition_penalty = gr.Slider(0.8, 1.5, value=1.05, step=0.01, label="repetition_penalty")
 
+    # EFFECTIVE PROMPT (user-editable)
+    prompt_box = gr.Textbox(label="Effective prompt (optional; leave empty to auto-generate)", lines=8, interactive=True)
+
     warn_box = gr.Markdown("")
-    prompt_box = gr.Textbox(label="Effective prompt", lines=8)
     output_box = gr.Textbox(label="Model output (stream)")
 
-    def _species_value(sp, custom):
-        return custom if sp == "Custom…" and custom.strip() else sp
-
     run_btn = gr.Button("🚀 Infer cell type")
+
     run_btn.click(
-        fn=lambda mid, sp, spc, genes, mx, temp, tp, tk, rp, q: infer(
-            mid, _species_value(sp, spc), genes, mx, temp, tp, tk, rp, q
-        ),
-        inputs=[model_id, species, species_custom, genes_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
-        outputs=[warn_box, prompt_box, output_box]
+        fn=infer,
+        inputs=[model_id, species, species_custom, genes_text, prompt_box,
+                max_new_tokens, temperature, top_p, top_k, repetition_penalty, quantization],
+        outputs=[warn_box, output_box]
     )
 
 if __name__ == "__main__":
 
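Worth noting about the new wiring: infer is now passed to click() directly, and because it is a generator, Gradio streams each yield of (warn, output_text) into warn_box and output_box live. A minimal sketch of that mechanism, with hypothetical names:

# Gradio streams generator callbacks: each yield updates the bound outputs.
import gradio as gr

def ticker(n):
    text = ""
    for i in range(int(n)):
        text += f"{i} "
        yield text  # each yield refreshes the output textbox

with gr.Blocks() as demo:
    n = gr.Number(value=5, label="steps")
    out = gr.Textbox(label="stream")
    gr.Button("Run").click(fn=ticker, inputs=[n], outputs=[out])

demo.launch()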