wangjin2000 commited on
Commit
1821c6c
·
verified ·
1 Parent(s): 5c47379

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -72,16 +72,16 @@ def compute_loss(model, inputs):
72
  return loss
73
 
74
  # Load the data from pickle files (replace with your local paths)
75
- with open("/datasets/AmelieSchreiber/binding_sites_random_split_by_family/train_sequences_chunked_by_family.pkl", "rb") as f:
76
  train_sequences = pickle.load(f)
77
 
78
- with open("/datasets/AmelieSchreiber/binding_sites_random_split_by_family/test_sequences_chunked_by_family.pkl", "rb") as f:
79
  test_sequences = pickle.load(f)
80
 
81
- with open("/datasets/AmelieSchreiber/binding_sites_random_split_by_family/train_labels_chunked_by_family.pkl", "rb") as f:
82
  train_labels = pickle.load(f)
83
 
84
- with open("/datasets/AmelieSchreiber/binding_sites_random_split_by_family/test_labels_chunked_by_family.pkl", "rb") as f:
85
  test_labels = pickle.load(f)
86
 
87
  # Tokenization
@@ -104,3 +104,11 @@ flat_train_labels = [label for sublist in train_labels for label in sublist]
104
  class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
105
  accelerator = Accelerator()
106
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
 
 
 
 
 
 
 
 
 
72
  return loss
73
 
74
  # Load the data from pickle files (replace with your local paths)
75
+ with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
76
  train_sequences = pickle.load(f)
77
 
78
+ with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
79
  test_sequences = pickle.load(f)
80
 
81
+ with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
82
  train_labels = pickle.load(f)
83
 
84
+ with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
85
  test_labels = pickle.load(f)
86
 
87
  # Tokenization
 
104
  class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
105
  accelerator = Accelerator()
106
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
107
+
108
+ dubug_result = class_weights
109
+ demo = gr.Blocks(title="DEMO FOR ESMBind")
110
+
111
+ with demo:
112
+ gr.Markdown("# DEMO FOR ESMBind")
113
+ gr.Textbox(dubug_result)
114
+ demo.launch()