| | package hfapi |
| |
|
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path/filepath"
	"strings"
	"time"
)
| |
|
| | |
// Model is one entry of the JSON array returned by the models search endpoint
// (decoded by SearchModels). Fields are filled by encoding/json according to
// their tags; keys absent from the response leave the field's zero value.
//
// NOTE(review): some tags (pipelineTag, model_index, mask_token,
// tokenizer_class, updatedAt) may not match the keys the Hub API actually
// emits — e.g. the pipeline key may be spelled "pipeline_tag". Confirm against
// a live response before relying on those fields.
type Model struct {
	ModelID        string                 `json:"modelId"`
	Author         string                 `json:"author"`
	Downloads      int                    `json:"downloads"`
	LastModified   string                 `json:"lastModified"` // kept as the raw string from the response
	PipelineTag    string                 `json:"pipelineTag"`
	Private        bool                   `json:"private"`
	Tags           []string               `json:"tags"`
	CreatedAt      string                 `json:"createdAt"`
	UpdatedAt      string                 `json:"updatedAt"`
	Sha            string                 `json:"sha"`
	Config         map[string]interface{} `json:"config"` // free-form; schema not checked here
	ModelIndex     string                 `json:"model_index"`
	LibraryName    string                 `json:"library_name"`
	MaskToken      string                 `json:"mask_token"`
	TokenizerClass string                 `json:"tokenizer_class"`
}
| |
|
| | |
// FileInfo is one entry returned by the repository tree endpoint (decoded by
// listFilesInPath). Type distinguishes files ("file") from directories, which
// listFilesInPath recognises as "directory" or "folder". LFS is non-nil only
// when the entry carries an "lfs" object.
type FileInfo struct {
	Type    string   `json:"type"`
	Oid     string   `json:"oid"`  // object id of the entry as reported by the API
	Size    int64    `json:"size"`
	Path    string   `json:"path"` // location of the entry inside the repository
	LFS     *LFSInfo `json:"lfs,omitempty"`
	XetHash string   `json:"xetHash,omitempty"`
}
| |
|
| | |
// LFSInfo is the "lfs" object attached to a tree entry. When Oid is non-empty,
// GetFileSHA and GetModelDetails report it in preference to FileInfo.Oid.
// NOTE(review): Oid is presumably the SHA-256 of the file content — confirm.
type LFSInfo struct {
	Oid         string `json:"oid"`
	Size        int64  `json:"size"`
	PointerSize int    `json:"pointerSize"`
}
| |
|
| | |
// ModelFile is a flattened view of one repository file, built by
// GetModelDetails from a FileInfo.
type ModelFile struct {
	Path string // repository path, copied from FileInfo.Path
	Size int64  // size as reported by the tree endpoint (presumably bytes)
	// SHA256 is the LFS oid when available, otherwise FileInfo.Oid.
	// NOTE(review): the fallback is probably a git object id, not a SHA-256
	// of the content — verify before using it to check downloads.
	SHA256   string
	IsReadme bool   // base name contains "readme", case-insensitively
	URL      string // direct "resolve/main" download URL
}
| |
|
| | |
// ModelDetails aggregates the file listing GetModelDetails builds for one
// repository.
type ModelDetails struct {
	ModelID    string      // the repository ID passed to GetModelDetails
	Author     string      // derived from the text before the first "/" in ModelID
	Files      []ModelFile // every file found in the repository
	ReadmeFile *ModelFile  // a copy of a README entry from Files, or nil if none
	// ReadmeContent is not filled by GetModelDetails; presumably callers set it
	// from GetReadmeContent.
	ReadmeContent string
}
| |
|
| | |
// SearchParams holds the query parameters SearchModels sends. All four are
// always sent, including zero values. The json tags are not used by
// SearchModels, which builds the query string directly from the fields.
type SearchParams struct {
	Sort      string `json:"sort"`      // field to sort by, e.g. "lastModified"
	Direction int    `json:"direction"` // sort direction; GetLatest uses -1
	Limit     int    `json:"limit"`
	Search    string `json:"search"`
}
| |
|
| | |
// Client talks to the Hugging Face Hub HTTP API. Create one with NewClient;
// the zero value has a nil HTTP client and is not usable.
type Client struct {
	baseURL string       // models endpoint, e.g. "https://huggingface.co/api/models"
	client  *http.Client // used for every request the client issues
}
| |
|
| | |
| | func NewClient() *Client { |
| | return &Client{ |
| | baseURL: "https://huggingface.co/api/models", |
| | client: &http.Client{}, |
| | } |
| | } |
| |
|
| | |
| | func (c *Client) SearchModels(params SearchParams) ([]Model, error) { |
| | req, err := http.NewRequest("GET", c.baseURL, nil) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to create request: %w", err) |
| | } |
| |
|
| | |
| | q := req.URL.Query() |
| | q.Add("sort", params.Sort) |
| | q.Add("direction", fmt.Sprintf("%d", params.Direction)) |
| | q.Add("limit", fmt.Sprintf("%d", params.Limit)) |
| | q.Add("search", params.Search) |
| | req.URL.RawQuery = q.Encode() |
| |
|
| | |
| | resp, err := c.client.Do(req) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to make request: %w", err) |
| | } |
| | defer resp.Body.Close() |
| |
|
| | if resp.StatusCode != http.StatusOK { |
| | return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode) |
| | } |
| |
|
| | |
| | body, err := io.ReadAll(resp.Body) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to read response body: %w", err) |
| | } |
| |
|
| | |
| | var models []Model |
| | if err := json.Unmarshal(body, &models); err != nil { |
| | return nil, fmt.Errorf("failed to parse JSON response: %w", err) |
| | } |
| |
|
| | return models, nil |
| | } |
| |
|
| | |
| | func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) { |
| | params := SearchParams{ |
| | Sort: "lastModified", |
| | Direction: -1, |
| | Limit: limit, |
| | Search: searchTerm, |
| | } |
| |
|
| | return c.SearchModels(params) |
| | } |
| |
|
| | |
| | func (c *Client) BaseURL() string { |
| | return c.baseURL |
| | } |
| |
|
| | |
| | func (c *Client) SetBaseURL(url string) { |
| | c.baseURL = url |
| | } |
| |
|
| | |
| | func (c *Client) listFilesInPath(repoID, path string) ([]FileInfo, error) { |
| | baseURL := strings.TrimSuffix(c.baseURL, "/api/models") |
| | var url string |
| | if path == "" { |
| | url = fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID) |
| | } else { |
| | url = fmt.Sprintf("%s/api/models/%s/tree/main/%s", baseURL, repoID, path) |
| | } |
| |
|
| | req, err := http.NewRequest("GET", url, nil) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to create request: %w", err) |
| | } |
| |
|
| | resp, err := c.client.Do(req) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to make request: %w", err) |
| | } |
| | defer resp.Body.Close() |
| |
|
| | if resp.StatusCode != http.StatusOK { |
| | return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode) |
| | } |
| |
|
| | body, err := io.ReadAll(resp.Body) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to read response body: %w", err) |
| | } |
| |
|
| | var items []FileInfo |
| | if err := json.Unmarshal(body, &items); err != nil { |
| | return nil, fmt.Errorf("failed to parse JSON response: %w", err) |
| | } |
| |
|
| | var allFiles []FileInfo |
| | for _, item := range items { |
| | switch item.Type { |
| | |
| | case "directory", "folder": |
| | |
| | subPath := item.Path |
| | if path != "" { |
| | subPath = fmt.Sprintf("%s/%s", path, item.Path) |
| | } |
| |
|
| | |
| | |
| | subFiles, err := c.listFilesInPath(repoID, subPath) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to list files in subfolder %s: %w", subPath, err) |
| | } |
| |
|
| | allFiles = append(allFiles, subFiles...) |
| | case "file": |
| | |
| | |
| | |
| | |
| | allFiles = append(allFiles, item) |
| | } |
| | } |
| |
|
| | return allFiles, nil |
| | } |
| |
|
| | |
| | func (c *Client) ListFiles(repoID string) ([]FileInfo, error) { |
| | return c.listFilesInPath(repoID, "") |
| | } |
| |
|
| | |
| | func (c *Client) GetFileSHA(repoID, fileName string) (string, error) { |
| | files, err := c.ListFiles(repoID) |
| | if err != nil { |
| | return "", fmt.Errorf("failed to list files while getting SHA: %w", err) |
| | } |
| |
|
| | for _, file := range files { |
| | if filepath.Base(file.Path) == fileName { |
| | if file.LFS != nil && file.LFS.Oid != "" { |
| | |
| | return file.LFS.Oid, nil |
| | } |
| | |
| | return file.Oid, nil |
| | } |
| | } |
| |
|
| | return "", fmt.Errorf("file %s not found", fileName) |
| | } |
| |
|
| | |
| | func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) { |
| | files, err := c.ListFiles(repoID) |
| | if err != nil { |
| | return nil, fmt.Errorf("failed to list files: %w", err) |
| | } |
| |
|
| | details := &ModelDetails{ |
| | ModelID: repoID, |
| | Author: strings.Split(repoID, "/")[0], |
| | Files: make([]ModelFile, 0, len(files)), |
| | } |
| |
|
| | |
| | baseURL := strings.TrimSuffix(c.baseURL, "/api/models") |
| | for _, file := range files { |
| | fileName := filepath.Base(file.Path) |
| | isReadme := strings.Contains(strings.ToLower(fileName), "readme") |
| |
|
| | |
| | sha256 := "" |
| | if file.LFS != nil && file.LFS.Oid != "" { |
| | sha256 = file.LFS.Oid |
| | } else { |
| | sha256 = file.Oid |
| | } |
| |
|
| | |
| | |
| | fileURL := fmt.Sprintf("%s/%s/resolve/main/%s", baseURL, repoID, file.Path) |
| |
|
| | modelFile := ModelFile{ |
| | Path: file.Path, |
| | Size: file.Size, |
| | SHA256: sha256, |
| | IsReadme: isReadme, |
| | URL: fileURL, |
| | } |
| |
|
| | details.Files = append(details.Files, modelFile) |
| |
|
| | |
| | if isReadme && details.ReadmeFile == nil { |
| | details.ReadmeFile = &modelFile |
| | } |
| | } |
| |
|
| | return details, nil |
| | } |
| |
|
| | |
| | func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) { |
| | baseURL := strings.TrimSuffix(c.baseURL, "/api/models") |
| | url := fmt.Sprintf("%s/%s/raw/main/%s", baseURL, repoID, readmePath) |
| |
|
| | req, err := http.NewRequest("GET", url, nil) |
| | if err != nil { |
| | return "", fmt.Errorf("failed to create request: %w", err) |
| | } |
| |
|
| | resp, err := c.client.Do(req) |
| | if err != nil { |
| | return "", fmt.Errorf("failed to make request: %w", err) |
| | } |
| | defer resp.Body.Close() |
| |
|
| | if resp.StatusCode != http.StatusOK { |
| | return "", fmt.Errorf("failed to fetch readme content. Status code: %d", resp.StatusCode) |
| | } |
| |
|
| | body, err := io.ReadAll(resp.Body) |
| | if err != nil { |
| | return "", fmt.Errorf("failed to read response body: %w", err) |
| | } |
| |
|
| | return string(body), nil |
| | } |
| |
|
| | |
| | func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile { |
| | var filtered []ModelFile |
| | for _, file := range files { |
| | fileName := filepath.Base(file.Path) |
| | if strings.Contains(strings.ToLower(fileName), strings.ToLower(quantization)) { |
| | filtered = append(filtered, file) |
| | } |
| | } |
| | return filtered |
| | } |
| |
|
| | |
| | func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile { |
| | for _, preference := range preferences { |
| | for i := range files { |
| | fileName := filepath.Base(files[i].Path) |
| | if strings.Contains(strings.ToLower(fileName), strings.ToLower(preference)) { |
| | return &files[i] |
| | } |
| | } |
| | } |
| | return nil |
| | } |
| |
|