Spaces:
Running
on
Zero
Running
on
Zero
added more models
Browse files- app.py +24 -6
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -5,6 +5,8 @@ from transformers import (
|
|
| 5 |
Gemma3ForConditionalGeneration,
|
| 6 |
TextIteratorStreamer,
|
| 7 |
Gemma3Processor,
|
|
|
|
|
|
|
| 8 |
)
|
| 9 |
import spaces
|
| 10 |
import tempfile
|
|
@@ -20,12 +22,28 @@ dotenv_path = find_dotenv()
|
|
| 20 |
|
| 21 |
load_dotenv(dotenv_path)
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
input_processor = Gemma3Processor.from_pretrained(
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
torch_dtype=torch.bfloat16,
|
| 30 |
device_map="auto",
|
| 31 |
attn_implementation="eager",
|
|
@@ -157,7 +175,7 @@ def run(
|
|
| 157 |
tokenize=True,
|
| 158 |
return_dict=True,
|
| 159 |
return_tensors="pt",
|
| 160 |
-
).to(device=
|
| 161 |
|
| 162 |
streamer = TextIteratorStreamer(
|
| 163 |
input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
|
|
@@ -172,7 +190,7 @@ def run(
|
|
| 172 |
repetition_penalty=repetition_penalty,
|
| 173 |
do_sample=True,
|
| 174 |
)
|
| 175 |
-
t = Thread(target=
|
| 176 |
t.start()
|
| 177 |
|
| 178 |
output = ""
|
|
|
|
| 5 |
Gemma3ForConditionalGeneration,
|
| 6 |
TextIteratorStreamer,
|
| 7 |
Gemma3Processor,
|
| 8 |
+
Gemma3nForConditionalGeneration,
|
| 9 |
+
Gemma3ForCausalLM
|
| 10 |
)
|
| 11 |
import spaces
|
| 12 |
import tempfile
|
|
|
|
| 22 |
|
| 23 |
load_dotenv(dotenv_path)
|
| 24 |
|
| 25 |
+
model_27_id = os.getenv("MODEL_27_ID", "google/gemma-3-4b-it")
|
| 26 |
+
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-4b-it")
|
| 27 |
+
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3-4b-it")
|
| 28 |
|
| 29 |
+
input_processor = Gemma3Processor.from_pretrained(model_27_id)
|
| 30 |
|
| 31 |
+
model_27 = Gemma3ForConditionalGeneration.from_pretrained(
|
| 32 |
+
model_27_id,
|
| 33 |
+
torch_dtype=torch.bfloat16,
|
| 34 |
+
device_map="auto",
|
| 35 |
+
attn_implementation="eager",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
model_12 = Gemma3ForCausalLM.from_pretrained(
|
| 39 |
+
model_12_id,
|
| 40 |
+
torch_dtype=torch.bfloat16,
|
| 41 |
+
device_map="auto",
|
| 42 |
+
attn_implementation="eager",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
|
| 46 |
+
model_3n_id,
|
| 47 |
torch_dtype=torch.bfloat16,
|
| 48 |
device_map="auto",
|
| 49 |
attn_implementation="eager",
|
|
|
|
| 175 |
tokenize=True,
|
| 176 |
return_dict=True,
|
| 177 |
return_tensors="pt",
|
| 178 |
+
).to(device=model_27.device, dtype=torch.bfloat16)
|
| 179 |
|
| 180 |
streamer = TextIteratorStreamer(
|
| 181 |
input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
|
|
|
|
| 190 |
repetition_penalty=repetition_penalty,
|
| 191 |
do_sample=True,
|
| 192 |
)
|
| 193 |
+
t = Thread(target=model_27.generate, kwargs=generate_kwargs)
|
| 194 |
t.start()
|
| 195 |
|
| 196 |
output = ""
|
requirements.txt
CHANGED
|
@@ -6,4 +6,5 @@ accelerate
|
|
| 6 |
pytest
|
| 7 |
loguru
|
| 8 |
python-dotenv
|
| 9 |
-
opencv-python
|
|
|
|
|
|
| 6 |
pytest
|
| 7 |
loguru
|
| 8 |
python-dotenv
|
| 9 |
+
opencv-python
|
| 10 |
+
timm
|