Spaces:
Runtime error
Runtime error
ricardo-lsantos
commited on
Commit
•
b760fd0
1
Parent(s):
f4ed0cd
Commented torch_directml
Browse files- AI/question_answering.py +6 -6
- AI/sentiment_analysis.py +6 -6
- AI/summarization.py +6 -6
- AI/text_generation.py +6 -6
- AI/zero_shot_classification.py +6 -6
- requirements.txt +0 -1
AI/question_answering.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
-
import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
-
elif DEVICE == "directml":
|
17 |
-
|
18 |
-
|
19 |
return device
|
20 |
|
21 |
def loadGenerator(device):
|
@@ -33,5 +33,5 @@ def clearCache(DEVICE, generator):
|
|
33 |
generator.tokenizer.save_pretrained("cache")
|
34 |
generator.model.save_pretrained("cache")
|
35 |
del generator
|
36 |
-
if DEVICE == "directml":
|
37 |
-
|
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
+
# import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
+
# elif DEVICE == "directml":
|
17 |
+
# device = torch_directml.device()
|
18 |
+
# dtype = torch.float16
|
19 |
return device
|
20 |
|
21 |
def loadGenerator(device):
|
|
|
33 |
generator.tokenizer.save_pretrained("cache")
|
34 |
generator.model.save_pretrained("cache")
|
35 |
del generator
|
36 |
+
# if DEVICE == "directml":
|
37 |
+
# torch_directml.empty_cache()
|
AI/sentiment_analysis.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
-
import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
-
elif DEVICE == "directml":
|
17 |
-
|
18 |
-
|
19 |
return device
|
20 |
|
21 |
def loadClassifier(device):
|
@@ -30,5 +30,5 @@ def clearCache(DEVICE, classifier):
|
|
30 |
classifier.tokenizer.save_pretrained("cache")
|
31 |
classifier.model.save_pretrained("cache")
|
32 |
del classifier
|
33 |
-
if DEVICE == "directml":
|
34 |
-
|
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
+
# import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
+
# elif DEVICE == "directml":
|
17 |
+
# device = torch_directml.device()
|
18 |
+
# dtype = torch.float16
|
19 |
return device
|
20 |
|
21 |
def loadClassifier(device):
|
|
|
30 |
classifier.tokenizer.save_pretrained("cache")
|
31 |
classifier.model.save_pretrained("cache")
|
32 |
del classifier
|
33 |
+
# if DEVICE == "directml":
|
34 |
+
# torch_directml.empty_cache()
|
AI/summarization.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
-
import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
-
elif DEVICE == "directml":
|
17 |
-
|
18 |
-
|
19 |
return device
|
20 |
|
21 |
def loadSummarizer(device):
|
@@ -30,5 +30,5 @@ def clearCache(DEVICE, summarizer):
|
|
30 |
summarizer.tokenizer.save_pretrained("cache")
|
31 |
summarizer.model.save_pretrained("cache")
|
32 |
del summarizer
|
33 |
-
if DEVICE == "directml":
|
34 |
-
|
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
+
# import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
+
# elif DEVICE == "directml":
|
17 |
+
# device = torch_directml.device()
|
18 |
+
# dtype = torch.float16
|
19 |
return device
|
20 |
|
21 |
def loadSummarizer(device):
|
|
|
30 |
summarizer.tokenizer.save_pretrained("cache")
|
31 |
summarizer.model.save_pretrained("cache")
|
32 |
del summarizer
|
33 |
+
# if DEVICE == "directml":
|
34 |
+
# torch_directml.empty_cache()
|
AI/text_generation.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
-
import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
-
elif DEVICE == "directml":
|
17 |
-
|
18 |
-
|
19 |
return device
|
20 |
|
21 |
def loadGenerator(device):
|
@@ -30,5 +30,5 @@ def clearCache(DEVICE, generator):
|
|
30 |
generator.tokenizer.save_pretrained("cache")
|
31 |
generator.model.save_pretrained("cache")
|
32 |
del generator
|
33 |
-
if DEVICE == "directml":
|
34 |
-
|
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
+
# import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
+
# elif DEVICE == "directml":
|
17 |
+
# device = torch_directml.device()
|
18 |
+
# dtype = torch.float16
|
19 |
return device
|
20 |
|
21 |
def loadGenerator(device):
|
|
|
30 |
generator.tokenizer.save_pretrained("cache")
|
31 |
generator.model.save_pretrained("cache")
|
32 |
del generator
|
33 |
+
# if DEVICE == "directml":
|
34 |
+
# torch_directml.empty_cache()
|
AI/zero_shot_classification.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
-
import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
-
elif DEVICE == "directml":
|
17 |
-
|
18 |
-
|
19 |
return device
|
20 |
|
21 |
def loadGenerator(device):
|
@@ -30,5 +30,5 @@ def clearCache(DEVICE, generator):
|
|
30 |
generator.tokenizer.save_pretrained("cache")
|
31 |
generator.model.save_pretrained("cache")
|
32 |
del generator
|
33 |
-
if DEVICE == "directml":
|
34 |
-
|
|
|
2 |
# Creation date: 2024-01-10
|
3 |
|
4 |
import torch
|
5 |
+
# import torch_directml
|
6 |
from transformers import pipeline
|
7 |
|
8 |
def getDevice(DEVICE):
|
|
|
13 |
elif DEVICE == "cuda":
|
14 |
device = torch.device("cuda")
|
15 |
dtype = torch.float16
|
16 |
+
# elif DEVICE == "directml":
|
17 |
+
# device = torch_directml.device()
|
18 |
+
# dtype = torch.float16
|
19 |
return device
|
20 |
|
21 |
def loadGenerator(device):
|
|
|
30 |
generator.tokenizer.save_pretrained("cache")
|
31 |
generator.model.save_pretrained("cache")
|
32 |
del generator
|
33 |
+
# if DEVICE == "directml":
|
34 |
+
# torch_directml.empty_cache()
|
requirements.txt
CHANGED
@@ -1,3 +1,2 @@
|
|
1 |
torch
|
2 |
-
torch_directml
|
3 |
transformers
|
|
|
1 |
torch
|
|
|
2 |
transformers
|