Commit
•
62e031b
1
Parent(s):
81c29d0
refactor
Browse files
app.py
CHANGED
@@ -6,7 +6,10 @@ import transformers
|
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
logging.basicConfig(level=logging.INFO)
|
9 |
-
|
|
|
|
|
|
|
10 |
|
11 |
if "googleflan" == os.environ.get("MODEL"):
|
12 |
model = "google/flan-t5-small"
|
@@ -22,8 +25,8 @@ elif "llama" == os.environ.get("MODEL"):
|
|
22 |
pipeline = transformers.pipeline(
|
23 |
"text-generation",
|
24 |
model=model,
|
25 |
-
|
26 |
-
torch_dtype="auto",
|
27 |
low_cpu_mem_usage=True,
|
28 |
device_map="auto",
|
29 |
token=os.environ.get("HF_TOKEN"),
|
|
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
logging.basicConfig(level=logging.INFO)
|
9 |
+
if torch.cuda.is_available():
|
10 |
+
logging.info("Running on GPU")
|
11 |
+
else:
|
12 |
+
logging.info("Running on CPU")
|
13 |
|
14 |
if "googleflan" == os.environ.get("MODEL"):
|
15 |
model = "google/flan-t5-small"
|
|
|
25 |
pipeline = transformers.pipeline(
|
26 |
"text-generation",
|
27 |
model=model,
|
28 |
+
torch_dtype=torch.float16,
|
29 |
+
# torch_dtype="auto",
|
30 |
low_cpu_mem_usage=True,
|
31 |
device_map="auto",
|
32 |
token=os.environ.get("HF_TOKEN"),
|