Sadjad Alikhani commited on
Commit
d2d9264
·
verified ·
1 Parent(s): 1bbecff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -40
app.py CHANGED
@@ -54,63 +54,35 @@ def load_custom_model():
54
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
55
  capture = PrintCapture()
56
  sys.stdout = capture # Redirect print statements to capture
57
-
58
  try:
59
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
60
  model_repo_dir = "./LWM"
61
 
62
- # Step 1: Clone the model repository if not already cloned
63
  if not os.path.exists(model_repo_dir):
64
  print(f"Cloning model repository from {model_repo_url}...")
65
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
66
-
67
- # Debugging: Check if the directory exists and print contents
68
  if os.path.exists(model_repo_dir):
69
  os.chdir(model_repo_dir)
70
  print(f"Changed working directory to {os.getcwd()}")
71
- print(f"Directory content: {os.listdir(os.getcwd())}") # Debugging: Check repo content
72
  else:
73
- print(f"Directory {model_repo_dir} does not exist.")
74
- return
75
 
76
- # Step 2: Add the cloned repo to sys.path for imports
77
- if model_repo_dir not in sys.path:
78
- sys.path.append(model_repo_dir)
79
-
80
- # Debugging: Print sys.path to ensure the cloned repo is in the path
81
- print(f"sys.path: {sys.path}")
82
-
83
- # Step 3: Dynamically import the model after cloning
84
- try:
85
- from lwm_model import LWM # Custom model in the cloned repo
86
- print("Successfully imported LWM model.")
87
- except ImportError as e:
88
- print(f"Error importing LWM model: {e}")
89
- print("Make sure lwm_model.py exists in the cloned repository.")
90
- return
91
-
92
- # Step 4: Check if GPU is available and set the device
93
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
94
- print(f"Using device: {device}")
95
-
96
- # Load the model from the cloned repository
97
  model = LWM.from_pretrained(device=device)
98
 
99
- # Step 5: Import the tokenizer
100
- try:
101
- from input_preprocess import tokenizer
102
- except ImportError as e:
103
- print(f"Error importing tokenizer: {e}")
104
- return
105
 
106
- # Step 6: Load the uploaded .p file (wireless channel matrix)
107
  with open(uploaded_file.name, 'rb') as f:
108
  manual_data = pickle.load(f)
109
 
110
- # Step 7: Tokenize the data if needed (or perform any necessary preprocessing)
111
  preprocessed_chs = tokenizer(manual_data=manual_data)
112
 
113
- # Step 8: Perform inference using the model
114
  from inference import lwm_inference, create_raw_dataset
115
  output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
116
  output_raw = create_raw_dataset(preprocessed_chs, device)
@@ -118,15 +90,13 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
118
  print(f"Output Embeddings Shape: {output_emb.shape}")
119
  print(f"Output Raw Shape: {output_raw.shape}")
120
 
121
- # Return the embeddings, raw output, and captured output
122
  return output_emb, output_raw, capture.get_output()
123
 
124
  except Exception as e:
125
- # Handle exceptions and return the captured output
126
  return str(e), str(e), capture.get_output()
127
 
128
  finally:
129
- sys.stdout = sys.__stdout__ # Reset stdout
130
 
131
  # Function to handle logic based on whether a file is uploaded or not
132
  def los_nlos_classification(file, percentage_idx, complexity_idx):
 
54
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
55
  capture = PrintCapture()
56
  sys.stdout = capture # Redirect print statements to capture
57
+
58
  try:
59
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
60
  model_repo_dir = "./LWM"
61
 
 
62
  if not os.path.exists(model_repo_dir):
63
  print(f"Cloning model repository from {model_repo_url}...")
64
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
65
+
 
66
  if os.path.exists(model_repo_dir):
67
  os.chdir(model_repo_dir)
68
  print(f"Changed working directory to {os.getcwd()}")
 
69
  else:
70
+ return f"Directory {model_repo_dir} does not exist."
 
71
 
72
+ print("a")
73
+ from lwm_model import LWM
74
+ print("b")
75
+ device = 'cpu'
76
+ print(f"Loading the LWM model on {device}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  model = LWM.from_pretrained(device=device)
78
 
79
+ from input_preprocess import tokenizer
 
 
 
 
 
80
 
 
81
  with open(uploaded_file.name, 'rb') as f:
82
  manual_data = pickle.load(f)
83
 
 
84
  preprocessed_chs = tokenizer(manual_data=manual_data)
85
 
 
86
  from inference import lwm_inference, create_raw_dataset
87
  output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
88
  output_raw = create_raw_dataset(preprocessed_chs, device)
 
90
  print(f"Output Embeddings Shape: {output_emb.shape}")
91
  print(f"Output Raw Shape: {output_raw.shape}")
92
 
 
93
  return output_emb, output_raw, capture.get_output()
94
 
95
  except Exception as e:
 
96
  return str(e), str(e), capture.get_output()
97
 
98
  finally:
99
+ sys.stdout = sys.__stdout__ # Reset print statements
100
 
101
  # Function to handle logic based on whether a file is uploaded or not
102
  def los_nlos_classification(file, percentage_idx, complexity_idx):