Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -295,6 +295,10 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
|
|
295 |
accelerator = Accelerator()
|
296 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
297 |
|
|
|
|
|
|
|
|
|
298 |
'''
|
299 |
# inference
|
300 |
# Path to the saved LoRA model
|
@@ -323,14 +327,13 @@ with torch.no_grad():
|
|
323 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
324 |
predictions = torch.argmax(logits, dim=2)
|
325 |
|
326 |
-
|
327 |
# Define labels
|
328 |
id2label = {
|
329 |
0: "No binding site",
|
330 |
1: "Binding site"
|
331 |
}
|
332 |
|
333 |
-
'''
|
334 |
# Print the predicted labels for each token
|
335 |
for token, prediction in zip(tokens, predictions[0].numpy()):
|
336 |
if token not in ['<pad>', '<cls>', '<eos>']:
|
|
|
295 |
accelerator = Accelerator()
|
296 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
297 |
|
298 |
+
# Define labels and model
|
299 |
+
id2label = {0: "No binding site", 1: "Binding site"}
|
300 |
+
label2id = {v: k for k, v in id2label.items()}
|
301 |
+
|
302 |
'''
|
303 |
# inference
|
304 |
# Path to the saved LoRA model
|
|
|
327 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
328 |
predictions = torch.argmax(logits, dim=2)
|
329 |
|
330 |
+
|
331 |
# Define labels
|
332 |
id2label = {
|
333 |
0: "No binding site",
|
334 |
1: "Binding site"
|
335 |
}
|
336 |
|
|
|
337 |
# Print the predicted labels for each token
|
338 |
for token, prediction in zip(tokens, predictions[0].numpy()):
|
339 |
if token not in ['<pad>', '<cls>', '<eos>']:
|