Sadjad Alikhani committed on
Commit
2587718
·
verified ·
1 Parent(s): 56460f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -34,20 +34,45 @@ from transformers import AutoModel # Assuming you use a transformer-like model
34
  import numpy as np
35
  import importlib.util
36
 
37
# Fetch the pre-trained LWM model from the Hugging Face Hub.
def load_pretrained_model():
    """Download the pre-trained LWM model and return it ready for inference.

    Returns:
        The model loaded from the ``sadjadalikhani/LWM`` Hub repository,
        switched to evaluation mode (dropout / batch-norm updates disabled).
    """
    pretrained = AutoModel.from_pretrained("sadjadalikhani/LWM")
    pretrained.eval()  # inference only — freeze training-time behavior
    return pretrained
43
 
44
- # Function to process the uploaded .py file and perform inference using the model
45
  def process_python_file(uploaded_file, percentage_idx, complexity_idx):
46
  try:
47
- # Step 1: Load the model
48
- model = load_pretrained_model()
 
 
 
 
 
49
 
50
- # Step 2: Load the uploaded .py file that contains the wireless channel matrix
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Import the Python file dynamically
52
  spec = importlib.util.spec_from_file_location("uploaded_module", uploaded_file.name)
53
  uploaded_module = importlib.util.module_from_spec(spec)
@@ -56,13 +81,15 @@ def process_python_file(uploaded_file, percentage_idx, complexity_idx):
56
  # Assuming the uploaded file defines a variable called 'channel_matrix'
57
  channel_matrix = uploaded_module.channel_matrix # This should be defined in the uploaded file
58
 
59
- # Step 3: Perform inference on the channel matrix using the model
 
 
 
60
  with torch.no_grad():
61
- input_tensor = torch.tensor(channel_matrix).unsqueeze(0) # Add batch dimension
62
  output = model(input_tensor) # Perform inference
63
 
64
- # Step 4: Generate new images based on the inference results
65
- # You can modify this logic depending on how you want to visualize the results
66
  generated_raw_img = np.random.rand(300, 300, 3) * 255 # Placeholder: Replace with actual inference result
67
  generated_embeddings_img = np.random.rand(300, 300, 3) * 255 # Placeholder: Replace with actual inference result
68
 
@@ -82,7 +109,6 @@ def process_python_file(uploaded_file, percentage_idx, complexity_idx):
82
  except Exception as e:
83
  return str(e), str(e)
84
 
85
-
86
  # Function to handle logic based on whether a file is uploaded or not
87
  def los_nlos_classification(file, percentage_idx, complexity_idx):
88
  if file is not None:
 
34
  import numpy as np
35
  import importlib.util
36
 
37
+ import torch
38
+ import numpy as np
39
+ import importlib.util
40
+ import subprocess
41
+ import os
42
+
43
# Function to load the pre-trained model from the cloned LWM repository
def load_custom_model():
    """Instantiate the LWM model from the cloned repository in eval mode.

    Returns:
        The ``LWM`` model instance with evaluation mode enabled
        (dropout / batch-norm updates disabled).

    Raises:
        ImportError: if ``lwm_model.py`` cannot be resolved on ``sys.path``.
    """
    import os
    import sys
    # The caller chdir()s into the cloned repo before calling this function,
    # but for a script the current working directory is NOT automatically on
    # sys.path, so `from lwm_model import LWM` could fail. Add it explicitly.
    repo_dir = os.getcwd()
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    from lwm_model import LWM  # model class defined in the repo's lwm_model.py
    model = LWM()  # NOTE(review): assumes a no-argument constructor — confirm
    model.eval()  # inference only — freeze training-time behavior
    return model
50
 
51
+ # Function to process the uploaded .py file and perform inference using the custom model
52
  def process_python_file(uploaded_file, percentage_idx, complexity_idx):
53
  try:
54
+ # Clone the repository if not already done (for model and tokenizer)
55
+ model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
56
+ model_repo_dir = "./LWM"
57
+
58
+ if not os.path.exists(model_repo_dir):
59
+ print(f"Cloning model repository from {model_repo_url}...")
60
+ subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
61
 
62
+ # Change the working directory to the cloned LWM folder
63
+ if os.path.exists(model_repo_dir):
64
+ os.chdir(model_repo_dir)
65
+ print(f"Changed working directory to {os.getcwd()}")
66
+ else:
67
+ return f"Directory {model_repo_dir} does not exist."
68
+
69
+ # Step 1: Load the custom model
70
+ model = load_custom_model()
71
+
72
+ # Step 2: Import the tokenizer
73
+ from input_preprocess import tokenizer
74
+
75
+ # Step 3: Load the uploaded .py file that contains the wireless channel matrix
76
  # Import the Python file dynamically
77
  spec = importlib.util.spec_from_file_location("uploaded_module", uploaded_file.name)
78
  uploaded_module = importlib.util.module_from_spec(spec)
 
81
  # Assuming the uploaded file defines a variable called 'channel_matrix'
82
  channel_matrix = uploaded_module.channel_matrix # This should be defined in the uploaded file
83
 
84
+ # Step 4: Tokenize the data if needed (or perform any necessary preprocessing)
85
+ preprocessed_data = tokenizer(manual_data=channel_matrix, gen_raw=True)
86
+
87
+ # Step 5: Perform inference on the channel matrix using the model
88
  with torch.no_grad():
89
+ input_tensor = torch.tensor(preprocessed_data).unsqueeze(0) # Add batch dimension
90
  output = model(input_tensor) # Perform inference
91
 
92
+ # Step 6: Generate new images based on the inference results
 
93
  generated_raw_img = np.random.rand(300, 300, 3) * 255 # Placeholder: Replace with actual inference result
94
  generated_embeddings_img = np.random.rand(300, 300, 3) * 255 # Placeholder: Replace with actual inference result
95
 
 
109
  except Exception as e:
110
  return str(e), str(e)
111
 
 
112
  # Function to handle logic based on whether a file is uploaded or not
113
  def los_nlos_classification(file, percentage_idx, complexity_idx):
114
  if file is not None: