Spaces:
Sleeping
Sleeping
Commit ·
6b2967c
1
Parent(s): ce9adef
fix: updated backend to graphcodebert architecture
Browse files- backend/main.py +2 -2
backend/main.py
CHANGED
|
@@ -35,7 +35,7 @@ def load_resources():
|
|
| 35 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 36 |
|
| 37 |
# Load tokenizer
|
| 38 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/
|
| 39 |
|
| 40 |
# Load label encoder
|
| 41 |
if os.path.exists("label_encoder.pkl"):
|
|
@@ -44,7 +44,7 @@ def load_resources():
|
|
| 44 |
print("WARNING: label_encoder.pkl not found!")
|
| 45 |
|
| 46 |
# Load model
|
| 47 |
-
model = AutoModelForSequenceClassification.from_pretrained("microsoft/
|
| 48 |
if os.path.exists("best_model.pt"):
|
| 49 |
model.load_state_dict(torch.load("best_model.pt", map_location=device))
|
| 50 |
else:
|
|
|
|
| 35 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 36 |
|
| 37 |
# Load tokenizer
|
| 38 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
|
| 39 |
|
| 40 |
# Load label encoder
|
| 41 |
if os.path.exists("label_encoder.pkl"):
|
|
|
|
| 44 |
print("WARNING: label_encoder.pkl not found!")
|
| 45 |
|
| 46 |
# Load model
|
| 47 |
+
model = AutoModelForSequenceClassification.from_pretrained("microsoft/graphcodebert-base", num_labels=7)
|
| 48 |
if os.path.exists("best_model.pt"):
|
| 49 |
model.load_state_dict(torch.load("best_model.pt", map_location=device))
|
| 50 |
else:
|