|
import { createApiError } from "../error"; |
|
import type { CredentialsParams } from "../types/public"; |
|
import { checkCredentials } from "./checkCredentials"; |
|
import { decompress as lz4_decompress } from "../vendor/lz4js"; |
|
import { RangeList } from "./RangeList"; |
|
|
|
const JWT_SAFETY_PERIOD = 60_000; // refresh the CAS JWT one minute before it expires

const JWT_CACHE_SIZE = 1_000; // maximum number of CAS JWTs kept in the cache
|
|
|
type XetBlobCreateOptions = { |
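/**
 * Custom fetch implementation to use instead of the global one (e.g. to add headers or go through a proxy).
 */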
|
|
|
|
|
|
|
fetch?: typeof fetch; |
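/** URL used to obtain a short-lived CAS access token (JWT) and the CAS endpoint */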
|
|
|
refreshUrl: string; |
|
size: number; |
|
listener?: (arg: { event: "read" } | { event: "progress"; progress: { read: number; total: number } }) => void; |
|
internalLogging?: boolean; |
|
} & ({ hash: string; reconstructionUrl?: string } | { hash?: string; reconstructionUrl: string }) & |
|
Partial<CredentialsParams>; |
|
|
|
export interface ReconstructionInfo { |
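/**
 * Ordered list of terms. Each term references a range of chunks inside a xorb (CAS block),
 * and concatenating the terms in order reconstructs the requested bytes.
 */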
|
|
|
|
|
|
|
terms: Array<{ |
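/** Hash of the xorb (CAS block) containing the chunks */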
|
|
|
hash: string; |
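/** Uncompressed size in bytes (informational; not used directly in this file) */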
|
|
|
unpacked_length: number; |
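/** Chunk index range inside the block: start inclusive, end exclusive */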
|
|
|
range: { start: number; end: number }; |
|
}>; |
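/**
 * For each block hash, the list of (presigned) URLs to download chunk ranges from,
 * along with the byte range to request on each URL.
 */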
|
|
|
|
|
|
|
|
|
fetch_info: Record< |
|
string, |
|
Array<{ |
|
url: string; |
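/** Chunk index range served by this URL */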
|
|
|
range: { start: number; end: number }; |
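/** Byte range to request on the URL (inclusive bounds, usable directly in a Range header) */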
|
|
|
|
|
|
|
|
|
|
|
url_range: { start: number; end: number }; |
|
}> |
|
>; |
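/** Number of bytes to skip at the beginning of the first term to reach the requested offset */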
|
|
|
|
|
|
|
offset_into_first_range: number; |
|
} |
|
|
|
enum CompressionScheme { |
|
None = 0, |
|
LZ4 = 1, |
|
ByteGroupingLZ4 = 2, |
|
} |
|
|
|
const compressionSchemeLabels: Record<CompressionScheme, string> = { |
|
[CompressionScheme.None]: "None", |
|
[CompressionScheme.LZ4]: "LZ4", |
|
[CompressionScheme.ByteGroupingLZ4]: "ByteGroupingLZ4", |
|
}; |
|
|
|
interface ChunkHeader {

version: number; // u8

compressed_length: number; // u24, little-endian

compression_scheme: CompressionScheme; // u8

uncompressed_length: number; // u24, little-endian

}
|
|
|
const CHUNK_HEADER_BYTES = 8; |
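
/**
 * A lazy Blob backed by Xet (CAS) storage: bytes are only downloaded when read, by first
 * fetching reconstruction info and then the referenced chunk ranges (LZ4 / BG4+LZ4 compressed).
 *
 * Minimal usage sketch (hypothetical hash and refresh URL):
 *
 *   const blob = new XetBlob({ hash: "<file hash>", refreshUrl: "<token endpoint>", size: 1_000_000 });
 *   const firstKiB = await blob.slice(0, 1024).arrayBuffer();
 */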
|
|
|
|
|
|
|
|
|
export class XetBlob extends Blob { |
|
fetch: typeof fetch; |
|
accessToken?: string; |
|
refreshUrl: string; |
|
reconstructionUrl?: string; |
|
hash?: string; |
|
start = 0; |
|
end = 0; |
|
internalLogging = false; |
|
reconstructionInfo: ReconstructionInfo | undefined; |
|
listener: XetBlobCreateOptions["listener"]; |
|
|
|
constructor(params: XetBlobCreateOptions) { |
|
super([]); |
|
|
|
this.fetch = params.fetch ?? fetch.bind(globalThis); |
|
this.accessToken = checkCredentials(params); |
|
this.refreshUrl = params.refreshUrl; |
|
this.end = params.size; |
|
this.reconstructionUrl = params.reconstructionUrl; |
|
this.hash = params.hash; |
|
this.listener = params.listener; |
|
this.internalLogging = params.internalLogging ?? false; |
|
|
} |
|
|
|
override get size(): number { |
|
return this.end - this.start; |
|
} |
|
|
|
#clone() { |
|
const blob = new XetBlob({ |
|
fetch: this.fetch, |
|
hash: this.hash, |
|
refreshUrl: this.refreshUrl, |
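// An existing blob always has either hash or reconstructionUrl set, so the assertion below is safe.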
|
|
|
reconstructionUrl: this.reconstructionUrl!, |
|
size: this.size, |
|
}); |
|
|
|
blob.accessToken = this.accessToken; |
|
blob.start = this.start; |
|
blob.end = this.end; |
|
blob.reconstructionInfo = this.reconstructionInfo; |
|
blob.listener = this.listener; |
|
blob.internalLogging = this.internalLogging; |
|
|
|
return blob; |
|
} |
|
|
|
override slice(start = 0, end = this.size): XetBlob { |
|
if (start < 0 || end < 0) { |
|
new TypeError("Unsupported negative start/end on XetBlob.slice"); |
|
} |
|
|
|
const slice = this.#clone(); |
|
|
|
slice.start = this.start + start; |
|
slice.end = Math.min(this.start + end, this.end); |
|
|
|
if (slice.start !== this.start || slice.end !== this.end) { |
|
slice.reconstructionInfo = undefined; |
|
} |
|
|
|
return slice; |
|
} |
|
|
|
#reconstructionInfoPromise?: Promise<ReconstructionInfo>; |
|
|
|
#loadReconstructionInfo() { |
|
if (this.#reconstructionInfoPromise) { |
|
return this.#reconstructionInfoPromise; |
|
} |
|
|
|
this.#reconstructionInfoPromise = (async () => { |
|
const connParams = await getAccessToken(this.accessToken, this.fetch, this.refreshUrl); |
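// Fetch the reconstruction info for the requested byte range, either from the provided
// reconstructionUrl or from the CAS API resolved through the token refresh response.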
|
|
|
|
|
|
|
|
|
|
|
const resp = await this.fetch(this.reconstructionUrl ?? `${connParams.casUrl}/reconstruction/${this.hash}`, { |
|
headers: { |
|
Authorization: `Bearer ${connParams.accessToken}`, |
|
Range: `bytes=${this.start}-${this.end - 1}`, |
|
}, |
|
}); |
|
|
|
if (!resp.ok) { |
|
throw await createApiError(resp); |
|
} |
|
|
|
this.reconstructionInfo = (await resp.json()) as ReconstructionInfo; |
|
|
|
return this.reconstructionInfo; |
|
})().finally(() => (this.#reconstructionInfoPromise = undefined)); |
|
|
|
return this.#reconstructionInfoPromise; |
|
} |
|
|
|
async #fetch(): Promise<ReadableStream<Uint8Array>> { |
|
if (!this.reconstructionInfo) { |
|
await this.#loadReconstructionInfo(); |
|
} |
|
|
|
const rangeLists = new Map<string, RangeList<Uint8Array[]>>(); |
|
|
|
if (!this.reconstructionInfo) { |
|
throw new Error("Failed to load reconstruction info"); |
|
} |
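// Register every term's chunk range. RangeList reference-counts overlapping ranges so that
// chunks shared between terms are downloaded once and reused.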
|
|
|
for (const term of this.reconstructionInfo.terms) { |
|
let rangeList = rangeLists.get(term.hash); |
|
if (!rangeList) { |
|
rangeList = new RangeList<Uint8Array[]>(); |
|
rangeLists.set(term.hash, rangeList); |
|
} |
|
|
|
rangeList.add(term.range.start, term.range.end); |
|
} |
|
const listener = this.listener; |
|
const log = this.internalLogging ? (...args: unknown[]) => console.log(...args) : () => {}; |
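// Async generator yielding the reconstructed bytes term by term, downloading and decompressing
// chunks on the fly and honouring the requested byte window (offset + maxBytes).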
|
|
|
async function* readData( |
|
reconstructionInfo: ReconstructionInfo, |
|
customFetch: typeof fetch, |
|
maxBytes: number, |
|
reloadReconstructionInfo: () => Promise<ReconstructionInfo> |
|
) { |
|
let totalBytesRead = 0; |
|
let readBytesToSkip = reconstructionInfo.offset_into_first_range; |
|
|
|
for (const term of reconstructionInfo.terms) { |
|
if (totalBytesRead >= maxBytes) { |
|
break; |
|
} |
|
|
|
const rangeList = rangeLists.get(term.hash); |
|
if (!rangeList) { |
|
throw new Error(`Failed to find range list for term ${term.hash}`); |
|
} |
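// If the chunks for this term were already downloaded for an earlier overlapping term, reuse them.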
|
|
|
{ |
|
const termRanges = rangeList.getRanges(term.range.start, term.range.end); |
|
|
|
if (termRanges.every((range) => range.data)) { |
|
log("all data available for term", term.hash, readBytesToSkip); |
|
rangeLoop: for (const range of termRanges) { |
|
|
|
for (let chunk of range.data!) { |
|
if (readBytesToSkip) { |
|
const skipped = Math.min(readBytesToSkip, chunk.byteLength); |
|
chunk = chunk.slice(skipped); |
|
readBytesToSkip -= skipped; |
|
if (!chunk.byteLength) { |
|
continue; |
|
} |
|
} |
|
if (chunk.byteLength > maxBytes - totalBytesRead) { |
|
chunk = chunk.slice(0, maxBytes - totalBytesRead); |
|
} |
|
totalBytesRead += chunk.byteLength; |
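// The same chunk may be yielded again for another term that references this range,
// so hand the consumer a copy while the range is still shared.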
|
|
|
|
|
yield range.refCount > 1 ? chunk.slice() : chunk; |
|
listener?.({ event: "progress", progress: { read: totalBytesRead, total: maxBytes } }); |
|
|
|
if (totalBytesRead >= maxBytes) { |
|
break rangeLoop; |
|
} |
|
} |
|
} |
|
rangeList.remove(term.range.start, term.range.end); |
|
continue; |
|
} |
|
} |
|
|
|
const fetchInfo = reconstructionInfo.fetch_info[term.hash].find( |
|
(info) => info.range.start <= term.range.start && info.range.end >= term.range.end |
|
); |
|
|
|
if (!fetchInfo) { |
|
throw new Error( |
|
`Failed to find fetch info for term ${term.hash} and range ${term.range.start}-${term.range.end}` |
|
); |
|
} |
|
|
|
log("term", term); |
|
log("fetchinfo", fetchInfo); |
|
log("readBytesToSkip", readBytesToSkip); |
|
|
|
let resp = await customFetch(fetchInfo.url, { |
|
headers: { |
|
Range: `bytes=${fetchInfo.url_range.start}-${fetchInfo.url_range.end}`, |
|
}, |
|
}); |
|
|
|
if (resp.status === 403) { |
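// The presigned URL has likely expired: refresh the reconstruction info and retry once.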
|
|
|
reconstructionInfo = await reloadReconstructionInfo(); |
|
resp = await customFetch(fetchInfo.url, { |
|
headers: { |
|
Range: `bytes=${fetchInfo.url_range.start}-${fetchInfo.url_range.end}`, |
|
}, |
|
}); |
|
} |
|
|
|
if (!resp.ok) { |
|
throw await createApiError(resp); |
|
} |
|
|
|
log( |
|
"expected content length", |
|
resp.headers.get("content-length"), |
|
"range", |
|
fetchInfo.url_range, |
|
resp.headers.get("content-range") |
|
); |
|
|
|
const reader = resp.body?.getReader(); |
|
if (!reader) { |
|
throw new Error("Failed to get reader from response body"); |
|
} |
|
|
|
let done = false; |
|
let chunkIndex = fetchInfo.range.start; |
|
const ranges = rangeList.getRanges(fetchInfo.range.start, fetchInfo.range.end); |
|
|
|
let leftoverBytes: Uint8Array | undefined = undefined; |
|
let totalFetchBytes = 0; |
|
|
|
fetchData: while (!done && totalBytesRead < maxBytes) { |
|
const result = await reader.read(); |
|
listener?.({ event: "read" }); |
|
|
|
done = result.done; |
|
|
|
log("read", result.value?.byteLength, "bytes", "total read", totalBytesRead, "toSkip", readBytesToSkip); |
|
|
|
if (!result.value) { |
|
log("no data in result, cancelled", result); |
|
continue; |
|
} |
|
|
|
totalFetchBytes += result.value.byteLength; |
|
|
|
if (leftoverBytes) {

// Prepend bytes left over from the previous read, without going through an intermediate plain array.
const merged = new Uint8Array(leftoverBytes.byteLength + result.value.byteLength);
merged.set(leftoverBytes);
merged.set(result.value, leftoverBytes.byteLength);
result.value = merged;

leftoverBytes = undefined;

}
|
|
|
while (totalBytesRead < maxBytes && result.value.byteLength) { |
|
if (result.value.byteLength < CHUNK_HEADER_BYTES) {
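// Not enough bytes to parse the chunk header yet; keep them for the next read.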
|
|
|
leftoverBytes = result.value; |
|
continue fetchData; |
|
} |
|
|
|
const header = new DataView(result.value.buffer, result.value.byteOffset, CHUNK_HEADER_BYTES); |
|
const chunkHeader: ChunkHeader = { |
|
version: header.getUint8(0), |
|
compressed_length: header.getUint8(1) | (header.getUint8(2) << 8) | (header.getUint8(3) << 16), |
|
compression_scheme: header.getUint8(4), |
|
uncompressed_length: header.getUint8(5) | (header.getUint8(6) << 8) | (header.getUint8(7) << 16), |
|
}; |
|
|
|
log("chunk header", chunkHeader, "to skip", readBytesToSkip); |
|
|
|
if (chunkHeader.version !== 0) { |
|
throw new Error(`Unsupported chunk version ${chunkHeader.version}`); |
|
} |
|
|
|
if ( |
|
chunkHeader.compression_scheme !== CompressionScheme.None && |
|
chunkHeader.compression_scheme !== CompressionScheme.LZ4 && |
|
chunkHeader.compression_scheme !== CompressionScheme.ByteGroupingLZ4 |
|
) { |
|
throw new Error( |
|
`Unsupported compression scheme ${ |
|
compressionSchemeLabels[chunkHeader.compression_scheme] ?? chunkHeader.compression_scheme |
|
}` |
|
); |
|
} |
|
|
|
if (result.value.byteLength < chunkHeader.compressed_length + CHUNK_HEADER_BYTES) { |
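// The full chunk has not arrived yet; stash what we have and wait for more data.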
|
|
|
leftoverBytes = result.value; |
|
continue fetchData; |
|
} |
|
|
|
result.value = result.value.slice(CHUNK_HEADER_BYTES); |
|
|
|
let uncompressed = |
|
chunkHeader.compression_scheme === CompressionScheme.LZ4 |
|
? lz4_decompress(result.value.slice(0, chunkHeader.compressed_length), chunkHeader.uncompressed_length) |
|
: chunkHeader.compression_scheme === CompressionScheme.ByteGroupingLZ4 |
|
? bg4_regoup_bytes( |
|
lz4_decompress( |
|
result.value.slice(0, chunkHeader.compressed_length), |
|
chunkHeader.uncompressed_length |
|
) |
|
) |
|
: result.value.slice(0, chunkHeader.compressed_length); |
|
|
|
const range = ranges.find((range) => chunkIndex >= range.start && chunkIndex < range.end); |
|
const shouldYield = chunkIndex >= term.range.start && chunkIndex < term.range.end; |
|
const minRefCountToStore = shouldYield ? 2 : 1; |
|
let stored = false; |
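// Keep the chunk in the range list when it will be needed again: for chunks being yielded now,
// only if another term still references the range (refCount >= 2); for prefetched chunks outside
// the current term, as soon as any term references the range.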
|
|
|
|
|
if (range && range.refCount >= minRefCountToStore) { |
|
range.data ??= []; |
|
range.data.push(uncompressed); |
|
stored = true; |
|
} |
|
|
|
if (shouldYield) { |
|
if (readBytesToSkip) { |
|
const skipped = Math.min(readBytesToSkip, uncompressed.byteLength); |
|
uncompressed = uncompressed.slice(skipped);
|
readBytesToSkip -= skipped; |
|
} |
|
|
|
if (uncompressed.byteLength > maxBytes - totalBytesRead) { |
|
uncompressed = uncompressed.slice(0, maxBytes - totalBytesRead); |
|
} |
|
|
|
if (uncompressed.byteLength) { |
|
log( |
|
"yield", |
|
uncompressed.byteLength, |
|
"bytes", |
|
result.value.byteLength, |
|
"total read", |
|
totalBytesRead, |
|
stored |
|
); |
|
totalBytesRead += uncompressed.byteLength; |
|
yield stored ? uncompressed.slice() : uncompressed; |
|
listener?.({ event: "progress", progress: { read: totalBytesRead, total: maxBytes } }); |
|
} |
|
} |
|
|
|
chunkIndex++; |
|
result.value = result.value.slice(chunkHeader.compressed_length); |
|
} |
|
} |
|
|
|
if ( |
|
done && |
|
totalBytesRead < maxBytes && |
|
totalFetchBytes < fetchInfo.url_range.end - fetchInfo.url_range.start + 1 |
|
) { |
|
log("done", done, "total read", totalBytesRead, maxBytes, totalFetchBytes); |
|
log("failed to fetch all data for term", term.hash); |
|
throw new Error( |
|
`Failed to fetch all data for term ${term.hash}, fetched ${totalFetchBytes} bytes out of ${ |
|
fetchInfo.url_range.end - fetchInfo.url_range.start + 1 |
|
}` |
|
); |
|
} |
|
|
|
log("done", done, "total read", totalBytesRead, maxBytes, totalFetchBytes); |
|
|
|
|
|
log("cancel reader"); |
|
await reader.cancel(); |
|
} |
|
} |
|
|
|
const iterator = readData( |
|
this.reconstructionInfo, |
|
this.fetch, |
|
this.end - this.start, |
|
this.#loadReconstructionInfo.bind(this) |
|
); |
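// Wrap the async generator in a byte stream so it can be consumed by Response / pipeThrough.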
|
|
|
|
|
return new ReadableStream<Uint8Array>( |
|
{ |
|
|
|
async pull(controller) { |
|
const result = await iterator.next(); |
|
|
|
if (result.value) { |
|
controller.enqueue(result.value); |
|
} |
|
|
|
if (result.done) { |
|
controller.close(); |
|
} |
|
}, |
|
type: "bytes", |
|
|
|
}, |
|
|
|
{ |
|
highWaterMark: 1_000, |
|
} |
|
); |
|
} |
|
|
|
override async arrayBuffer(): Promise<ArrayBuffer> { |
|
const result = await this.#fetch(); |
|
|
|
return new Response(result).arrayBuffer(); |
|
} |
|
|
|
override async text(): Promise<string> { |
|
const result = await this.#fetch(); |
|
|
|
return new Response(result).text(); |
|
} |
|
|
|
async response(): Promise<Response> { |
|
const result = await this.#fetch(); |
|
|
|
return new Response(result); |
|
} |
|
|
|
override stream(): ReturnType<Blob["stream"]> { |
|
const stream = new TransformStream(); |
|
|
|
this.#fetch() |
|
.then((response) => response.pipeThrough(stream)) |
|
.catch((error) => stream.writable.abort(error.message)); |
|
|
|
return stream.readable; |
|
} |
|
} |
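
// In-flight token refreshes, deduplicated per (refreshUrl, initialAccessToken) key.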
|
|
|
const jwtPromises: Map<string, Promise<{ accessToken: string; casUrl: string }>> = new Map(); |
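// Cache of resolved CAS JWTs and their expiry, keyed like jwtPromises.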
|
|
|
|
|
|
|
const jwts: Map< |
|
string, |
|
{ |
|
accessToken: string; |
|
expiresAt: Date; |
|
casUrl: string; |
|
} |
|
> = new Map(); |
|
|
|
function cacheKey(params: { refreshUrl: string; initialAccessToken: string | undefined }): string { |
|
return JSON.stringify([params.refreshUrl, params.initialAccessToken]); |
|
} |
|
|
|
|
|
export function bg4_regoup_bytes(bytes: Uint8Array): Uint8Array { |
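// BG4 ("byte grouping") stores the 1st byte of every 4-byte word first, then all 2nd bytes,
// then 3rd, then 4th, to help LZ4 compression. This re-interleaves the four groups back into
// the original byte order, e.g. [a0,a1, b0,b1, c0,c1, d0,d1] -> [a0,b0,c0,d0, a1,b1,c1,d1].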
|
|
const split = Math.floor(bytes.byteLength / 4); |
|
const rem = bytes.byteLength % 4; |
|
const g1_pos = split + (rem >= 1 ? 1 : 0); |
|
const g2_pos = g1_pos + split + (rem >= 2 ? 1 : 0); |
|
const g3_pos = g2_pos + split + (rem == 3 ? 1 : 0); |
|
|
|
const ret = new Uint8Array(bytes.byteLength); |
|
for (let i = 0, j = 0; i < bytes.byteLength; i += 4, j++) { |
|
ret[i] = bytes[j]; |
|
} |
|
|
|
for (let i = 1, j = g1_pos; i < bytes.byteLength; i += 4, j++) { |
|
ret[i] = bytes[j]; |
|
} |
|
|
|
for (let i = 2, j = g2_pos; i < bytes.byteLength; i += 4, j++) { |
|
ret[i] = bytes[j]; |
|
} |
|
|
|
for (let i = 3, j = g3_pos; i < bytes.byteLength; i += 4, j++) { |
|
ret[i] = bytes[j]; |
|
} |
|
|
|
return ret;
|
} |
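
// Returns a valid CAS access token and CAS URL for the given refresh URL, using the module-level
// cache and deduplicating concurrent refresh requests.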
|
|
|
async function getAccessToken( |
|
initialAccessToken: string | undefined, |
|
customFetch: typeof fetch, |
|
refreshUrl: string |
|
): Promise<{ accessToken: string; casUrl: string }> { |
|
const key = cacheKey({ refreshUrl, initialAccessToken }); |
|
|
|
const jwt = jwts.get(key); |
|
|
|
if (jwt && jwt.expiresAt > new Date(Date.now() + JWT_SAFETY_PERIOD)) { |
|
return { accessToken: jwt.accessToken, casUrl: jwt.casUrl }; |
|
} |
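// A refresh for the same key may already be in flight; reuse it instead of issuing another request.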
|
|
|
|
|
const existingPromise = jwtPromises.get(key); |
|
if (existingPromise) { |
|
return existingPromise; |
|
} |
|
|
|
const promise = (async () => { |
|
const resp = await customFetch(refreshUrl, { |
|
headers: { |
|
...(initialAccessToken |
|
? { |
|
Authorization: `Bearer ${initialAccessToken}`, |
|
} |
|
: {}), |
|
}, |
|
}); |
|
|
|
if (!resp.ok) { |
|
throw new Error(`Failed to get JWT token: ${resp.status} ${await resp.text()}`); |
|
} |
|
|
|
const json: { accessToken: string; casUrl: string; exp: number } = await resp.json(); |
|
const jwt = { |
|
accessToken: json.accessToken, |
|
expiresAt: new Date(json.exp * 1000), |
|
initialAccessToken, |
|
refreshUrl, |
|
casUrl: json.casUrl, |
|
}; |
|
|
|
jwtPromises.delete(key); |
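// Evict entries that are about to expire (the map is in insertion order, so stop at the first
// still-valid one), then enforce the cache size cap before storing the new token.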
|
|
|
for (const [key, value] of jwts.entries()) { |
|
if (value.expiresAt < new Date(Date.now() + JWT_SAFETY_PERIOD)) { |
|
jwts.delete(key); |
|
} else { |
|
break; |
|
} |
|
} |
|
if (jwts.size >= JWT_CACHE_SIZE) { |
|
const keyToDelete = jwts.keys().next().value; |
|
if (keyToDelete) { |
|
jwts.delete(keyToDelete); |
|
} |
|
} |
|
jwts.set(key, jwt); |
|
|
|
return { |
|
accessToken: json.accessToken, |
|
casUrl: json.casUrl, |
|
}; |
|
})(); |
|
|
|
jwtPromises.set(key, promise); |
|
|
|
return promise; |
|
} |
|
|