m-ric HF staff commited on
Commit
66c8bc6
·
verified ·
1 Parent(s): aee1d4b

Upload tool

Browse files
Files changed (3) hide show
  1. app.py +4 -2
  2. requirements.txt +1 -0
  3. tool.py +11 -5
app.py CHANGED
@@ -1,4 +1,6 @@
1
- from transformers import launch_gradio_demo
2
  from tool import HFModelDownloadsTool
3
 
4
- launch_gradio_demo(HFModelDownloadsTool)
 
 
 
1
+ from smolagents import launch_gradio_demo
2
  from tool import HFModelDownloadsTool
3
 
4
+ tool = HFModelDownloadsTool()
5
+
6
+ launch_gradio_demo(tool)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  agents_package
2
  huggingface_hub
 
 
1
  agents_package
2
  huggingface_hub
3
+ smolagents
tool.py CHANGED
@@ -1,15 +1,21 @@
1
- from agents import Tool
 
2
 
3
  class HFModelDownloadsTool(Tool):
4
  name = "model_download_counter"
5
  description = """
6
  This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
7
  It returns the name of the checkpoint."""
8
- inputs = {"task":{"type":"string","description":"the task category (such as text-classification, depth-estimation, etc)"}}
9
  output_type = "string"
10
 
11
  def forward(self, task: str):
12
- from huggingface_hub import list_models
 
 
 
 
 
 
 
13
 
14
- model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
15
- return model.id
 
1
+ from smolagents.tools import Tool
2
+ import huggingface_hub
3
 
4
  class HFModelDownloadsTool(Tool):
5
  name = "model_download_counter"
6
  description = """
7
  This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
8
  It returns the name of the checkpoint."""
9
+ inputs = {'task': {'type': 'string', 'description': 'the task category (such as text-classification, depth-estimation, etc)'}}
10
  output_type = "string"
11
 
12
  def forward(self, task: str):
13
+ from huggingface_hub import list_models
14
+
15
+ model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
16
+ return model.id
17
+
18
+
19
+ def __init__(self, *args, **kwargs):
20
+ self.is_initialized = False
21