File size: 5,493 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import logging
import os
from typing import TYPE_CHECKING, Optional
import wandb
from wandb.sdk.lib.paths import LogicalPath
if TYPE_CHECKING:
    # Only needed by static type checkers for the quoted annotations used in this
    # module; these modules are never imported at runtime.
    from wandb.filesync import dir_watcher, stats, step_upload
    from wandb.sdk.internal import file_stream, internal_api

# Module-level logger for upload progress and failures.
logger = logging.getLogger(__name__)
class UploadJob:
    """Upload one local file, either through a custom ``save_fn`` or a server URL.

    Call :meth:`run` to execute the job. Progress, dedup and failure information
    is reported to the shared ``stats`` object, and ``file_stream.push_success``
    is called once the upload completes without raising.
    """

    def __init__(
        self,
        stats: "stats.Stats",
        api: "internal_api.Api",
        file_stream: "file_stream.FileStreamApi",
        silent: bool,
        save_name: LogicalPath,
        path: "dir_watcher.PathStr",
        artifact_id: Optional[str],
        md5: Optional[str],
        copied: bool,
        save_fn: Optional["step_upload.SaveFn"],
        digest: Optional[str],
    ) -> None:
        """A file uploader.

        Args:
            stats: Receives per-file progress, dedup and failure updates.
            api: Internal API client used to obtain upload URLs and send the file.
            file_stream: Notified via ``push_success`` after a successful upload.
            silent: When true, suppresses the terminal error message on the
                URL-based upload path. The ``save_fn`` path reports errors
                to the terminal regardless of this flag.
            save_name: Logical location of the file relative to the run directory.
            path: Actual path of the file to upload on the local filesystem.
            artifact_id: Artifact the file belongs to, if any; forwarded to
                ``create_artifact_manifest`` and ``push_success``.
            md5: When set (and no ``save_fn`` is given), the artifact-manifest
                flow (``create_artifact_manifest``) is used instead of
                ``upload_urls``.
            copied: When true, ``path`` is deleted after the upload attempt,
                whether or not it succeeded.
            save_fn: Optional custom uploader. It is called with a progress
                callback, must perform its own retries, and returns a truthy
                value when the file was deduplicated instead of uploaded.
            digest: Stored on the job; not read anywhere in this class.
        """
        self._stats = stats
        self._api = api
        self._file_stream = file_stream
        self.silent = silent
        self.save_name = save_name
        self.save_path = path
        self.artifact_id = artifact_id
        self.md5 = md5
        self.copied = copied
        self.save_fn = save_fn
        self.digest = digest
        super().__init__()

    def run(self) -> None:
        """Upload the file, clean up a copied file, then report success.

        Exceptions from the upload propagate to the caller; ``push_success`` is
        only called when both the upload and the cleanup finished without raising.
        """
        try:
            self.push()
        finally:
            # Copied files are removed regardless of the upload outcome.
            if self.copied and os.path.isfile(self.save_path):
                os.remove(self.save_path)
        self._file_stream.push_success(self.artifact_id, self.save_name)  # type: ignore

    def push(self) -> None:
        """Upload the file via ``save_fn`` if given, otherwise via an upload URL.

        Raises:
            Exception: Any error from the actual upload is recorded in ``stats``,
                logged, reported to Sentry, and then re-raised.
        """
        if self.save_fn:
            self._push_with_save_fn(self.save_fn)
            return

        if self.md5:
            # This is the new artifact manifest upload flow, in which we create the
            # database entry for the manifest file before creating it. This is used
            # for artifact L0 files. Which now is only artifact_manifest.json
            _, response = self._api.create_artifact_manifest(
                self.save_name, self.md5, self.artifact_id
            )
            upload_url = response["uploadUrl"]
            upload_headers = response["uploadHeaders"]
        else:
            # The classic file upload flow. We get a signed url and upload the file
            # then the backend handles the cloud storage metadata callback to create
            # the file entry. This flow has aged like a fine wine.
            project = self._api.get_project()
            _, upload_headers, result = self._api.upload_urls(
                project, [self.save_name]
            )
            file_info = result[self.save_name]
            upload_url = file_info["uploadUrl"]

        # No URL means the server does not need the bytes: treat as deduplicated.
        if upload_url is None:
            logger.info("Skipped uploading %s", self.save_path)
            self._stats.set_file_deduped(self.save_name)
            return

        self._upload_to_url(upload_url, upload_headers)

    def _push_with_save_fn(self, save_fn: "step_upload.SaveFn") -> None:
        """Upload through the caller-supplied ``save_fn`` (which owns retries)."""
        # NOTE(review): this path keys ``stats`` by ``save_path`` while the URL path
        # keys by ``save_name``; presumably it matches how the caller registered the
        # file for this flow -- confirm before unifying.
        try:
            deduped = save_fn(
                lambda _, t: self._stats.update_uploaded_file(self.save_path, t)
            )
        except Exception as e:
            self._stats.update_failed_file(self.save_path)
            logger.exception("Failed to upload file: %s", self.save_path)
            wandb._sentry.exception(e)
            message = str(e)
            # TODO: this is usually XML, but could be JSON
            # Exceptions such as ``requests.RequestException`` always define
            # ``response`` but may set it to ``None``; guard so that building the
            # message cannot raise and mask the original error.
            response = getattr(e, "response", None)
            if response is not None:
                message = getattr(response, "content", message)
            wandb.termerror(
                f'Error uploading "{self.save_path}": {type(e).__name__}, {message}'
            )
            raise

        if deduped:
            logger.info("Skipped uploading %s", self.save_path)
            self._stats.set_file_deduped(self.save_path)
        else:
            logger.info("Uploaded file %s", self.save_path)

    def _upload_to_url(self, upload_url: str, upload_headers) -> None:
        """Send the local file to ``upload_url``.

        ``upload_headers`` is an iterable of ``"Key:Value"`` strings supplied by
        the server; they are sent on top of the API's extra HTTP headers.
        """
        # Work on a copy: writing into ``_extra_http_headers`` in place would leak
        # this file's upload headers into every later upload that starts from the
        # same dict.
        extra_headers = dict(self._api._extra_http_headers or {})
        for upload_header in upload_headers:
            key, val = upload_header.split(":", 1)
            # Whitespace around the separator is not part of an HTTP field value.
            extra_headers[key.strip()] = val.strip()

        # Copied from push TODO(artifacts): clean up
        # If the upload URL is relative, fill it in with the base URL,
        # since it's a proxied file store like the on-prem VM.
        if upload_url.startswith("/"):
            upload_url = f"{self._api.api_url}{upload_url}"

        try:
            with open(self.save_path, "rb") as f:
                self._api.upload_file_retry(
                    upload_url,
                    f,
                    lambda _, t: self.progress(t),
                    extra_headers=extra_headers,
                )
            logger.info("Uploaded file %s", self.save_path)
        except Exception as e:
            self._stats.update_failed_file(self.save_name)
            logger.exception("Failed to upload file: %s", self.save_path)
            wandb._sentry.exception(e)
            if not self.silent:
                wandb.termerror(
                    f'Error uploading "{self.save_name}": {type(e).__name__}, {e}'
                )
            raise

    def progress(self, total_bytes: int) -> None:
        """Forward an upload-progress byte count to ``stats`` for ``save_name``."""
        self._stats.update_uploaded_file(self.save_name, total_bytes)
|