from __future__ import annotations

import os
import shutil
import subprocess
import urllib.request
from zipfile import ZipFile

import yaml
from tqdm import tqdm
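
# VERSION pins the data release: fetch_data() looks for a 'release_{VERSION}'
# directory under dir_home/bin and downloads release_{VERSION}.zip from
# https://leafmachine.org/LM2/ when it is missing or out of date.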
VERSION = 'v-2-1'

def fetch_data(logger, dir_home, cfg_file_path):
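    """Make sure the versioned model/data release is present under dir_home/bin.

    Returns True when the release files are already in place or were downloaded
    and unpacked successfully, False otherwise. (cfg_file_path is currently
    unused by this function.)
    """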
logger.name = 'Fetch Data'
ready_to_use = False
do_fetch = True
current = ''.join(['release_', VERSION])
# Make sure weights are present
if os.path.isfile(os.path.join(dir_home,'bin','version.yml')):
ver = load_version(dir_home)
if ver['version'] == VERSION:
if current in os.listdir(os.path.join(dir_home,'bin')): # The release dir is present
do_fetch = False
ready_to_use = True
logger.warning(f"Version file --- {os.path.join(dir_home,'bin','version.yml')}")
logger.warning(f"Current version --- {ver['version']}")
logger.warning(f"Last updated --- {ver['last_update']}")
else: # right version, no release dir yet
do_fetch = True
logger.warning(f"--------------------------------")
logger.warning(f" Downloading data files... ")
logger.warning(f"--------------------------------")
logger.warning(f"Version file --- {os.path.join(dir_home,'bin','version.yml')}")
logger.warning(f"Current version --- {ver['version']}")
logger.warning(f"Last updated --- {ver['last_update']}")
else:
do_fetch = True
logger.warning(f"--------------------------------")
logger.warning(f" Out of date... ")
logger.warning(f" Downloading data files... ")
logger.warning(f"--------------------------------")
logger.warning(f"Version file --- {os.path.join(dir_home,'bin','version.yml')}")
logger.warning(f"Current version --- {ver['version']}")
logger.warning(f"Last updated --- {ver['last_update']}")
    else:
        # No version.yml at all, so there is no 'ver' to report -- only the
        # expected location can be logged before fetching a fresh copy
        do_fetch = True
        logger.warning(f"--------------------------------")
        logger.warning(f"     Missing version.yml...     ")
        logger.warning(f"    Downloading data files...   ")
        logger.warning(f"--------------------------------")
        logger.warning(f"Version file --- {os.path.join(dir_home,'bin','version.yml')} (not found)")

    if do_fetch:
        logger.warning(f"Fetching files for version --> {VERSION}")
path_release = get_weights(dir_home, current, logger)
if path_release is not None:
logger.warning(f"Data download successful. Unzipping...")
move_data_to_home(path_release, dir_home, logger)
ready_to_use = True
logger.warning(f"--------------------------------")
logger.warning(f" LeafMachine2 is up to date ")
logger.warning(f"--------------------------------")
else:
logger.warning(f"--------------------------------")
logger.warning(f" LeafMachine2 is up to date ")
logger.warning(f"--------------------------------")
return ready_to_use
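
# A minimal sketch of how fetch_data might be called (illustrative only -- the
# logger setup and dir_home value here are assumptions, not LeafMachine2's
# actual entry point):
#
#   import logging
#   logger = logging.getLogger(__name__)
#   dir_home = os.path.dirname(os.path.abspath(__file__))
#   ready = fetch_data(logger, dir_home, cfg_file_path=None)
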
def get_weights(dir_home, current, logger):
    """Download and unzip the 'release_<VERSION>' archive into dir_home/bin.

    Returns the path of the extracted release directory, or None on failure.
    """
    try:
        path_release = os.path.join(dir_home, 'bin', current)
        zipfilename = os.path.join(dir_home, 'bin', current + '.zip')
        zipurl = ''.join(['https://leafmachine.org/LM2/', current, '.zip'])
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36'}
        req = urllib.request.Request(url=zipurl, headers=headers)

        # Get the file size from the Content-Length header
        with urllib.request.urlopen(req) as url_response:
            file_size = int(url_response.headers['Content-Length'])

        # Download the ZIP file in 4 KB chunks with a progress bar
        with tqdm(unit='B', unit_scale=True, unit_divisor=1024, total=file_size) as pbar:
            with urllib.request.urlopen(req) as url_response:
                with open(zipfilename, 'wb') as file:
                    while True:
                        chunk = url_response.read(4096)
                        if not chunk:
                            break
                        file.write(chunk)
                        pbar.update(len(chunk))

        # Extract the archive into dir_home/bin, then remove the ZIP itself
        with ZipFile(zipfilename, 'r') as zip_file:
            zip_file.extractall(os.path.join(dir_home, 'bin'))
        os.remove(zipfilename)

        print(f"{bcolors.CGREENBG2}Data extracted to {path_release}{bcolors.ENDC}")
        logger.warning(f"Data extracted to {path_release}")
        return path_release
    except Exception as e:
        print(f"{bcolors.CREDBG2}ERROR --- Could not download or extract machine learning models\n{e}{bcolors.ENDC}")
        logger.warning(f"ERROR --- Could not download or extract machine learning models")
        logger.warning(f"ERROR --- {e}")
        return None

def load_version(dir_home):
    """Load bin/version.yml; fall back to the copy two directory levels up."""
    try:
        with open(os.path.join(dir_home, 'bin', 'version.yml'), 'r') as ymlfile:
            ver = yaml.full_load(ymlfile)
    except FileNotFoundError:
        with open(os.path.join(os.path.dirname(os.path.dirname(dir_home)), 'bin', 'version.yml'), 'r') as ymlfile:
            ver = yaml.full_load(ymlfile)
    return ver
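
# version.yml is expected to provide at least the two keys read in fetch_data
# (the values below are illustrative, not the shipped file):
#
#   version: v-2-1
#   last_update: 2023-01-01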

def move_data_to_home(path_release, dir_home, logger):
    """Move each model file from the extracted release into its destination.

    path_list.yml (shipped inside the release) stores every destination as a
    '___'-delimited path relative to dir_home, e.g. a value like
    'bin___ruler_classifier' becomes dir_home/bin/ruler_classifier.
    """
    path_list_file = os.path.join(path_release, 'path_list.yml')
    with open(path_list_file, 'r') as file:
        path_list = yaml.safe_load(file)

    paths = {
        'path_ruler_classifier': os.path.join(dir_home, *path_list['path_ruler_classifier'].split('___')),
        'path_ruler_binary_classifier': os.path.join(dir_home, *path_list['path_ruler_binary_classifier'].split('___')),
        'path_ruler_classifier_binary_classes': os.path.join(dir_home, *path_list['path_ruler_classifier_binary_classes'].split('___')),
        'path_ruler_classifier_ruler_classes': os.path.join(dir_home, *path_list['path_ruler_classifier_ruler_classes'].split('___')),
        'path_DocEnTR': os.path.join(dir_home, *path_list['path_DocEnTR'].split('___')),
        'path_ACD': os.path.join(dir_home, *path_list['path_ACD'].split('___')),
        'path_PCD': os.path.join(dir_home, *path_list['path_PCD'].split('___')),
        'path_landmarks': os.path.join(dir_home, *path_list['path_landmarks'].split('___')),
        'path_YOLO': os.path.join(dir_home, *path_list['path_YOLO'].split('___')),
        'path_segment': os.path.join(dir_home, *path_list['path_segment'].split('___')),
    }

    # (release sub-directory, file name, destination key) for every shipped file
    files_to_move = [
        ('ruler_classifier', 'ruler_classifier_38classes_v-1.pt', 'path_ruler_classifier'),
        ('ruler_classifier', 'model_scripted_resnet_720_withCompression.pt', 'path_ruler_binary_classifier'),
        ('ruler_classifier', 'binary_classes.txt', 'path_ruler_classifier_binary_classes'),
        ('ruler_classifier', 'ruler_classes.txt', 'path_ruler_classifier_ruler_classes'),
        ('ruler_segment', 'small_256_8__epoch-10.pt', 'path_DocEnTR'),
        ('acd', 'best.pt', 'path_ACD'),
        ('pcd', 'best.pt', 'path_PCD'),
        ('landmarks', 'best.pt', 'path_landmarks'),
        ('YOLO', 'yolov5x6.pt', 'path_YOLO'),
        ('segmentation', 'model_final.pth', 'path_segment'),
    ]
    for subdir, filename, dest_key in files_to_move:
        source_file = os.path.join(path_release, subdir, filename)
        destination_dir = paths[dest_key]
        os.makedirs(destination_dir, exist_ok=True)
        try_move(logger, source_file, destination_dir)
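
# A sketch of the corresponding path_list.yml (hypothetical entries -- the real
# file ships inside the release archive):
#
#   path_ruler_classifier: bin___ruler_classifier
#   path_YOLO: bin___YOLO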

def git_pull_no_clean():
    """Pull the latest changes for the current repository."""
    # A plain pull suffices: git pull leaves untracked files in place, so no
    # extra clean handling is needed ('git pull' has no '--no-clean' option)
    git_cmd = ['git', 'pull']
    # Run the git command and capture its output
    result = subprocess.run(git_cmd, capture_output=True, text=True)
    # Print stdout on success, stderr on failure
    if result.returncode == 0:
        print(result.stdout)
    else:
        print(result.stderr)

def try_move(logger, source_file, destination_dir):
    """Move source_file into destination_dir, skipping files already in place."""
    destination_file = os.path.join(destination_dir, os.path.basename(source_file))
    # shutil.move() raises shutil.Error (not FileExistsError) when the target
    # already exists, so check up front and skip instead
    if os.path.exists(destination_file):
        logger.warning(f"{destination_file}\nalready exists. Skipping...")
        print(f"{destination_file}\nalready exists. Skipping...")
        return
    try:
        # Move the file, preserving metadata via shutil.copy2 if a copy is needed
        shutil.move(source_file, destination_dir, copy_function=shutil.copy2)
        logger.warning(f"{source_file}\nmoved to\n{destination_dir}")
        print(f"{source_file}\nmoved to\n{destination_dir}")
    except Exception as e:
        # Catch any other errors (permissions, missing source, etc.)
        logger.warning(f"[ERROR] occurred while moving:\n{source_file}:\n{str(e)}")
        print(f"ERROR occurred while moving:\n{source_file}:\n{str(e)}")

class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
    CGREENBG2 = '\033[102m'
    CREDBG2 = '\033[101m'
    CWHITEBG2 = '\033[107m'
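
# Usage pattern (as in get_weights above): wrap text in a color code and reset
# with ENDC, e.g. print(f"{bcolors.OKGREEN}Done{bcolors.ENDC}")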
if __name__ == '__main__':
git_pull_no_clean() |