Sadjad Alikhani committed on
Commit
aa5e7da
·
verified ·
1 Parent(s): 058904f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -9
app.py CHANGED
@@ -50,6 +50,13 @@ def load_custom_model():
50
  model.eval()
51
  return model
52
 
 
 
 
 
 
 
 
53
  # Function to process the uploaded .p file and perform inference using the custom model
54
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
55
  capture = PrintCapture()
@@ -59,32 +66,58 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
59
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
60
  model_repo_dir = "./LWM"
61
 
 
62
  if not os.path.exists(model_repo_dir):
63
  print(f"Cloning model repository from {model_repo_url}...")
64
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
65
-
 
66
  if os.path.exists(model_repo_dir):
67
  os.chdir(model_repo_dir)
68
  print(f"Changed working directory to {os.getcwd()}")
 
69
  else:
70
- return f"Directory {model_repo_dir} does not exist."
 
71
 
72
- # Add LWM repo path to Python module search path
73
  if model_repo_dir not in sys.path:
74
  sys.path.append(model_repo_dir)
75
-
76
- from lwm_model import LWM # Now this should work
77
- device = 'cpu'
78
- print(f"Loading the LWM model on {device}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  model = LWM.from_pretrained(device=device)
80
 
81
- from input_preprocess import tokenizer
 
 
 
 
 
82
 
 
83
  with open(uploaded_file.name, 'rb') as f:
84
  manual_data = pickle.load(f)
85
 
 
86
  preprocessed_chs = tokenizer(manual_data=manual_data)
87
 
 
88
  from inference import lwm_inference, create_raw_dataset
89
  output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
90
  output_raw = create_raw_dataset(preprocessed_chs, device)
@@ -92,13 +125,15 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
92
  print(f"Output Embeddings Shape: {output_emb.shape}")
93
  print(f"Output Raw Shape: {output_raw.shape}")
94
 
 
95
  return output_emb, output_raw, capture.get_output()
96
 
97
  except Exception as e:
 
98
  return str(e), str(e), capture.get_output()
99
 
100
  finally:
101
- sys.stdout = sys.__stdout__ # Reset print statements
102
 
103
  # Function to handle logic based on whether a file is uploaded or not
104
  def los_nlos_classification(file, percentage_idx, complexity_idx):
 
50
  model.eval()
51
  return model
52
 
53
+ import sys
54
+ import subprocess
55
+ import os
56
+ import pickle
57
+ import torch
58
+ import io
59
+
60
  # Function to process the uploaded .p file and perform inference using the custom model
61
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
62
  capture = PrintCapture()
 
66
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
67
  model_repo_dir = "./LWM"
68
 
69
+ # Step 1: Clone the model repository if not already cloned
70
  if not os.path.exists(model_repo_dir):
71
  print(f"Cloning model repository from {model_repo_url}...")
72
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
73
+
74
+ # Debugging: Check if the directory exists and print contents
75
  if os.path.exists(model_repo_dir):
76
  os.chdir(model_repo_dir)
77
  print(f"Changed working directory to {os.getcwd()}")
78
+ print(f"Directory content: {os.listdir(os.getcwd())}") # Debugging: Check repo content
79
  else:
80
+ print(f"Directory {model_repo_dir} does not exist.")
81
+ return
82
 
83
+ # Step 2: Add the cloned repo to sys.path for imports
84
  if model_repo_dir not in sys.path:
85
  sys.path.append(model_repo_dir)
86
+
87
+ # Debugging: Print sys.path to ensure the cloned repo is in the path
88
+ print(f"sys.path: {sys.path}")
89
+
90
+ # Step 3: Dynamically import the model after cloning
91
+ try:
92
+ from lwm_model import LWM # Custom model in the cloned repo
93
+ print("Successfully imported LWM model.")
94
+ except ImportError as e:
95
+ print(f"Error importing LWM model: {e}")
96
+ print("Make sure lwm_model.py exists in the cloned repository.")
97
+ return
98
+
99
+ # Step 4: Check if GPU is available and set the device
100
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
101
+ print(f"Using device: {device}")
102
+
103
+ # Load the model from the cloned repository
104
  model = LWM.from_pretrained(device=device)
105
 
106
+ # Step 5: Import the tokenizer
107
+ try:
108
+ from input_preprocess import tokenizer
109
+ except ImportError as e:
110
+ print(f"Error importing tokenizer: {e}")
111
+ return
112
 
113
+ # Step 6: Load the uploaded .p file (wireless channel matrix)
114
  with open(uploaded_file.name, 'rb') as f:
115
  manual_data = pickle.load(f)
116
 
117
+ # Step 7: Tokenize the data if needed (or perform any necessary preprocessing)
118
  preprocessed_chs = tokenizer(manual_data=manual_data)
119
 
120
+ # Step 8: Perform inference using the model
121
  from inference import lwm_inference, create_raw_dataset
122
  output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
123
  output_raw = create_raw_dataset(preprocessed_chs, device)
 
125
  print(f"Output Embeddings Shape: {output_emb.shape}")
126
  print(f"Output Raw Shape: {output_raw.shape}")
127
 
128
+ # Return the embeddings, raw output, and captured output
129
  return output_emb, output_raw, capture.get_output()
130
 
131
  except Exception as e:
132
+ # Handle exceptions and return the captured output
133
  return str(e), str(e), capture.get_output()
134
 
135
  finally:
136
+ sys.stdout = sys.__stdout__ # Reset stdout
137
 
138
  # Function to handle logic based on whether a file is uploaded or not
139
  def los_nlos_classification(file, percentage_idx, complexity_idx):