|
import { describe, expect, it } from "vitest"; |
|
import type { ModelEntry } from "./list-models"; |
|
import { listModels } from "./list-models"; |
|
|
|
/**
 * Tests for `listModels`.
 *
 * NOTE(review): the expected ids, counts and tags asserted below are not
 * fixtures defined in this file; they appear to depend on the data that
 * `listModels` returns at run time — confirm before relying on them offline.
 */
describe("listModels", () => {
	it("should list models for depth estimation", async () => {
		// Models whose complete entries are pinned in the deep-equality check
		// below; every other result only has its `task` checked.
		const pinnedNames = new Set(["Intel/dpt-large", "Intel/dpt-hybrid-midas"]);
		const results: ModelEntry[] = [];

		for await (const entry of listModels({
			search: { owner: "Intel", task: "depth-estimation" },
		})) {
			// Replace counters and the timestamp (presumably changing over time)
			// with fixed values so the comparison only covers stable metadata.
			// Values of an unexpected type are left as-is so the comparison
			// still fails on them.
			if (typeof entry.downloads === "number") {
				entry.downloads = 0;
			}
			if (typeof entry.likes === "number") {
				entry.likes = 0;
			}
			if (entry.updatedAt instanceof Date && !Number.isNaN(entry.updatedAt.getTime())) {
				entry.updatedAt = new Date(0);
			}

			if (!pinnedNames.has(entry.name)) {
				expect(entry.task).to.equal("depth-estimation");
				continue;
			}

			results.push(entry);
		}

		// Order by id so the assertion does not depend on iteration order.
		results.sort((a, b) => a.id.localeCompare(b.id));

		expect(results).deep.equal([
			{
				id: "621ffdc136468d709f17e709",
				name: "Intel/dpt-large",
				private: false,
				gated: false,
				downloads: 0,
				likes: 0,
				task: "depth-estimation",
				updatedAt: new Date(0),
			},
			{
				id: "638f07977559bf9a2b2b04ac",
				name: "Intel/dpt-hybrid-midas",
				gated: false,
				private: false,
				downloads: 0,
				likes: 0,
				task: "depth-estimation",
				updatedAt: new Date(0),
			},
		]);
	});

	it("should list indonesian models with gguf format", async () => {
		let count = 0;
		for await (const entry of listModels({
			search: { tags: ["gguf", "id"] },
			// `tags` is requested explicitly so it is present on each entry.
			additionalFields: ["tags"],
			limit: 2,
		})) {
			count++;
			expect(entry.tags).to.include("gguf");
			expect(entry.tags).to.include("id");
		}

		expect(count).to.equal(2);
	});

	it("should search model by name", async () => {
		let count = 0;
		for await (const entry of listModels({
			search: { query: "t5" },
			limit: 10,
		})) {
			count++;
			// Locale-independent lowercasing: a plain ASCII substring check must
			// not vary with the runtime locale.
			expect(entry.name.toLowerCase()).to.include("t5");
		}

		expect(count).to.equal(10);
	});

	it("should search model by inference provider", async () => {
		let count = 0;
		for await (const entry of listModels({
			search: { inferenceProviders: ["together"] },
			additionalFields: ["inferenceProviderMapping"],
			limit: 10,
		})) {
			count++;
			// Only checked when the mapping is an array; any other shape skips
			// the provider assertion for this entry.
			if (Array.isArray(entry.inferenceProviderMapping)) {
				expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together");
			}
		}

		expect(count).to.equal(10);
	});

	it("should search model by several inference providers", async () => {
		let count = 0;
		const inferenceProviders = ["together", "replicate"];
		for await (const entry of listModels({
			search: { inferenceProviders },
			additionalFields: ["inferenceProviderMapping"],
			limit: 10,
		})) {
			count++;
			// Each entry must be served by at least one of the requested
			// providers. Same array-only guard as the single-provider test.
			if (Array.isArray(entry.inferenceProviderMapping)) {
				const servedByRequested = entry.inferenceProviderMapping.some(({ provider }) =>
					inferenceProviders.includes(provider)
				);
				expect(servedByRequested).to.equal(true);
			}
		}

		expect(count).to.equal(10);
	});
});
|
|