Brice Vandeputte commited on
Commit
ed31d02
·
1 Parent(s): 7c32702

change api return and change logger

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -2,7 +2,7 @@ import collections
2
  import heapq
3
  import json
4
  import os
5
- import logging
6
 
7
  import gradio as gr
8
  import numpy as np
@@ -13,9 +13,9 @@ from torchvision import transforms
13
 
14
  from templates import openai_imagenet_template
15
 
16
- log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
17
- logging.basicConfig(level=logging.INFO, format=log_format)
18
- logger = logging.getLogger()
19
 
20
  hf_token = os.getenv("HF_TOKEN")
21
 
@@ -155,7 +155,7 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
155
 
156
 
157
  @torch.no_grad()
158
- def api_classification(img, rank: int) -> dict[str, float]:
159
  """
160
  Predicts from the entire tree of life.
161
  If targeting a higher rank than species, then this function predicts among all
@@ -182,10 +182,9 @@ def api_classification(img, rank: int) -> dict[str, float]:
182
 
183
  logger.info(">>>>")
184
  logger.info(probs[0])
185
-
186
- topk_names = heapq.nlargest(k, output, key=output.get)
187
-
188
- return {name: output[name] for name in topk_names}
189
 
190
 
191
  def change_output(choice):
 
2
  import heapq
3
  import json
4
  import os
5
+ from accelerate.logging import get_logger
6
 
7
  import gradio as gr
8
  import numpy as np
 
13
 
14
  from templates import openai_imagenet_template
15
 
16
+ # log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
17
+ # logging.basicConfig(level=logging.INFO, format=log_format)
18
+ logger = get_logger(__name__, log_level="INFO")
19
 
20
  hf_token = os.getenv("HF_TOKEN")
21
 
 
155
 
156
 
157
  @torch.no_grad()
158
+ def api_classification(img, rank: int): # -> dict[str, float]:
159
  """
160
  Predicts from the entire tree of life.
161
  If targeting a higher rank than species, then this function predicts among all
 
182
 
183
  logger.info(">>>>")
184
  logger.info(probs[0])
185
+ return probs[0]
186
+ # topk_names = heapq.nlargest(k, output, key=output.get)
187
+ # return {name: output[name] for name in topk_names}
 
188
 
189
 
190
  def change_output(choice):