Spaces:
Configuration error
Configuration error
package downloader | |
import ( | |
"crypto/sha256" | |
"fmt" | |
"io" | |
"net/http" | |
"net/url" | |
"os" | |
"path/filepath" | |
"strconv" | |
"strings" | |
ocispec "github.com/opencontainers/image-spec/specs-go/v1" | |
"github.com/mudler/LocalAI/pkg/oci" | |
"github.com/mudler/LocalAI/pkg/utils" | |
"github.com/rs/zerolog/log" | |
) | |
// Scheme prefixes recognised by URI. The github and huggingface forms are
// shorthands rewritten by ResolveURL, oci:// and ollama:// are pulled by
// DownloadFile through the oci package, and file:// URIs are read from the
// local filesystem.
const (
	// Plain web URLs.
	HTTPPrefix  = "http://"
	HTTPSPrefix = "https://"

	// Registry / hub shorthands.
	HuggingFacePrefix = "huggingface://"
	OCIPrefix         = "oci://"
	OllamaPrefix      = "ollama://"

	// GitHub shorthands: both "github:org/repo/path" and
	// "github://org/repo/path" are accepted.
	GithubURI  = "github:"
	GithubURI2 = "github://"

	// Local files.
	LocalPrefix = "file://"
)

// URI references a downloadable resource: a plain URL, a file:// path, or
// one of the shorthand schemes listed in the constants of this block.
type URI string
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error { | |
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f) | |
} | |
func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error { | |
url := uri.ResolveURL() | |
if strings.HasPrefix(url, LocalPrefix) { | |
rawURL := strings.TrimPrefix(url, LocalPrefix) | |
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified. | |
resolvedFile, err := filepath.EvalSymlinks(rawURL) | |
if err != nil { | |
return err | |
} | |
resolvedBasePath, err := filepath.EvalSymlinks(basePath) | |
if err != nil { | |
return err | |
} | |
// Check if the local file is rooted in basePath | |
err = utils.InTrustedRoot(resolvedFile, resolvedBasePath) | |
if err != nil { | |
log.Debug().Str("resolvedFile", resolvedFile).Str("basePath", basePath).Msg("downloader.GetURI blocked an attempt to ready a file url outside of basePath") | |
return err | |
} | |
// Read the response body | |
body, err := os.ReadFile(resolvedFile) | |
if err != nil { | |
return err | |
} | |
// Unmarshal YAML data into a struct | |
return f(url, body) | |
} | |
// Send a GET request to the URL | |
req, err := http.NewRequest("GET", url, nil) | |
if err != nil { | |
return err | |
} | |
if authorization != "" { | |
req.Header.Add("Authorization", authorization) | |
} | |
response, err := http.DefaultClient.Do(req) | |
if err != nil { | |
return err | |
} | |
defer response.Body.Close() | |
// Read the response body | |
body, err := io.ReadAll(response.Body) | |
if err != nil { | |
return err | |
} | |
// Unmarshal YAML data into a struct | |
return f(url, body) | |
} | |
func (u URI) FilenameFromUrl() (string, error) { | |
f, err := filenameFromUrl(string(u)) | |
if err != nil || f == "" { | |
f = utils.MD5(string(u)) | |
if strings.HasSuffix(string(u), ".yaml") || strings.HasSuffix(string(u), ".yml") { | |
f = f + ".yaml" | |
} | |
err = nil | |
} | |
return f, err | |
} | |
// filenameFromUrl returns the unescaped last path element of urlstr.
//
// Anything from the first '@' onward (e.g. an "@branch" suffix) is dropped
// before parsing. It returns "" (with a nil error) when the URL has no usable
// path element, i.e. an empty path or a path made only of slashes. It returns
// an error when the URL cannot be parsed or its path cannot be unescaped.
func filenameFromUrl(urlstr string) (string, error) {
	urlstr, _, _ = strings.Cut(urlstr, "@")

	u, err := url.Parse(urlstr)
	if err != nil {
		return "", fmt.Errorf("error due to parsing url: %w", err)
	}
	x, err := url.QueryUnescape(u.EscapedPath())
	if err != nil {
		return "", fmt.Errorf("error due to escaping: %w", err)
	}

	base := filepath.Base(x)
	// filepath.Base yields "." for an empty path and a lone separator for
	// "/". Neither is a file name, so report "" and let the caller fall back.
	if base == "." || base == "/" || base == string(filepath.Separator) {
		return "", nil
	}
	return base, nil
}
func (u URI) LooksLikeURL() bool { | |
return strings.HasPrefix(string(u), HTTPPrefix) || | |
strings.HasPrefix(string(u), HTTPSPrefix) || | |
strings.HasPrefix(string(u), HuggingFacePrefix) || | |
strings.HasPrefix(string(u), GithubURI) || | |
strings.HasPrefix(string(u), OllamaPrefix) || | |
strings.HasPrefix(string(u), OCIPrefix) || | |
strings.HasPrefix(string(u), GithubURI2) | |
} | |
func (s URI) LooksLikeOCI() bool { | |
return strings.HasPrefix(string(s), OCIPrefix) || strings.HasPrefix(string(s), OllamaPrefix) | |
} | |
func (s URI) ResolveURL() string { | |
switch { | |
case strings.HasPrefix(string(s), GithubURI2): | |
repository := strings.Replace(string(s), GithubURI2, "", 1) | |
repoParts := strings.Split(repository, "@") | |
branch := "main" | |
if len(repoParts) > 1 { | |
branch = repoParts[1] | |
} | |
repoPath := strings.Split(repoParts[0], "/") | |
org := repoPath[0] | |
project := repoPath[1] | |
projectPath := strings.Join(repoPath[2:], "/") | |
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) | |
case strings.HasPrefix(string(s), GithubURI): | |
parts := strings.Split(string(s), ":") | |
repoParts := strings.Split(parts[1], "@") | |
branch := "main" | |
if len(repoParts) > 1 { | |
branch = repoParts[1] | |
} | |
repoPath := strings.Split(repoParts[0], "/") | |
org := repoPath[0] | |
project := repoPath[1] | |
projectPath := strings.Join(repoPath[2:], "/") | |
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) | |
case strings.HasPrefix(string(s), HuggingFacePrefix): | |
repository := strings.Replace(string(s), HuggingFacePrefix, "", 1) | |
// convert repository to a full URL. | |
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf | |
owner := strings.Split(repository, "/")[0] | |
repo := strings.Split(repository, "/")[1] | |
branch := "main" | |
if strings.Contains(repo, "@") { | |
branch = strings.Split(repository, "@")[1] | |
} | |
filepath := strings.Split(repository, "/")[2] | |
if strings.Contains(filepath, "@") { | |
filepath = strings.Split(filepath, "@")[0] | |
} | |
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) | |
} | |
return string(s) | |
} | |
func removePartialFile(tmpFilePath string) error { | |
_, err := os.Stat(tmpFilePath) | |
if err == nil { | |
log.Debug().Msgf("Removing temporary file %s", tmpFilePath) | |
err = os.Remove(tmpFilePath) | |
if err != nil { | |
err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err) | |
log.Warn().Msg(err1.Error()) | |
return err1 | |
} | |
} | |
return nil | |
} | |
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { | |
url := uri.ResolveURL() | |
if uri.LooksLikeOCI() { | |
progressStatus := func(desc ocispec.Descriptor) io.Writer { | |
return &progressWriter{ | |
fileName: filePath, | |
total: desc.Size, | |
hash: sha256.New(), | |
fileNo: fileN, | |
totalFiles: total, | |
downloadStatus: downloadStatus, | |
} | |
} | |
if strings.HasPrefix(url, OllamaPrefix) { | |
url = strings.TrimPrefix(url, OllamaPrefix) | |
return oci.OllamaFetchModel(url, filePath, progressStatus) | |
} | |
url = strings.TrimPrefix(url, OCIPrefix) | |
img, err := oci.GetImage(url, "", nil, nil) | |
if err != nil { | |
return fmt.Errorf("failed to get image %q: %v", url, err) | |
} | |
return oci.ExtractOCIImage(img, filepath.Dir(filePath)) | |
} | |
// Check if the file already exists | |
_, err := os.Stat(filePath) | |
if err == nil { | |
// File exists, check SHA | |
if sha != "" { | |
// Verify SHA | |
calculatedSHA, err := calculateSHA(filePath) | |
if err != nil { | |
return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err) | |
} | |
if calculatedSHA == sha { | |
// SHA matches, skip downloading | |
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath) | |
return nil | |
} | |
// SHA doesn't match, delete the file and download again | |
err = os.Remove(filePath) | |
if err != nil { | |
return fmt.Errorf("failed to remove existing file %q: %v", filePath, err) | |
} | |
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) | |
} else { | |
// SHA is missing, skip downloading | |
log.Debug().Msgf("File %q already exists. Skipping download", filePath) | |
return nil | |
} | |
} else if !os.IsNotExist(err) { | |
// Error occurred while checking file existence | |
return fmt.Errorf("failed to check file %q existence: %v", filePath, err) | |
} | |
log.Info().Msgf("Downloading %q", url) | |
// Download file | |
resp, err := http.Get(url) | |
if err != nil { | |
return fmt.Errorf("failed to download file %q: %v", filePath, err) | |
} | |
defer resp.Body.Close() | |
if resp.StatusCode >= 400 { | |
return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) | |
} | |
// Create parent directory | |
err = os.MkdirAll(filepath.Dir(filePath), 0750) | |
if err != nil { | |
return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err) | |
} | |
// save partial download to dedicated file | |
tmpFilePath := filePath + ".partial" | |
// remove tmp file | |
err = removePartialFile(tmpFilePath) | |
if err != nil { | |
return err | |
} | |
// Create and write file content | |
outFile, err := os.Create(tmpFilePath) | |
if err != nil { | |
return fmt.Errorf("failed to create file %q: %v", tmpFilePath, err) | |
} | |
defer outFile.Close() | |
progress := &progressWriter{ | |
fileName: tmpFilePath, | |
total: resp.ContentLength, | |
hash: sha256.New(), | |
fileNo: fileN, | |
totalFiles: total, | |
downloadStatus: downloadStatus, | |
} | |
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) | |
if err != nil { | |
return fmt.Errorf("failed to write file %q: %v", filePath, err) | |
} | |
err = os.Rename(tmpFilePath, filePath) | |
if err != nil { | |
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) | |
} | |
if sha != "" { | |
// Verify SHA | |
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) | |
if calculatedSHA != sha { | |
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) | |
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) | |
} | |
} else { | |
log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath) | |
} | |
log.Info().Msgf("File %q downloaded and verified", filePath) | |
if utils.IsArchive(filePath) { | |
basePath := filepath.Dir(filePath) | |
log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath) | |
if err := utils.ExtractArchive(filePath, basePath); err != nil { | |
log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error()) | |
return err | |
} | |
} | |
return nil | |
} | |
// formatBytes renders a byte count with binary (1024-based) units, e.g.
// 1536 -> "1.5 KiB". Counts below 1024 (including negative values) are
// printed as an exact integer followed by " B".
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	const prefixes = "KMGTPE"
	divisor := int64(unit)
	idx := 0
	// Keep growing the divisor while the scaled value is still >= one unit.
	for bytes/divisor >= unit {
		divisor *= unit
		idx++
	}
	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(divisor), prefixes[idx])
}
// calculateSHA returns the lowercase hex-encoded SHA-256 digest of the file
// at filePath, or the error from opening or reading it.
func calculateSHA(filePath string) (string, error) {
	digest := sha256.New()

	f, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer f.Close()

	if _, err = io.Copy(digest, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", digest.Sum(nil)), nil
}