# thelou1s's picture
# rename test.py
# 53dc546
import os
import sys

import requests
from tqdm import tqdm

from draw_confusion import draw, draw2
# When True, the functions below print trace messages to stdout.
DEBUG = False

# Endpoint that request_api posts the raw file bytes to.
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"

# SECURITY: the fallback token below is committed in plain text and should be
# revoked and removed. Set the HF_TOKEN environment variable to supply a
# token without editing this file; the literal is kept only so existing
# invocations keep working.
headers = {
    "Authorization": "Bearer "
    + os.environ.get("HF_TOKEN", "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")
}
# Send a single request
# filename = '1.flac'
def request_api(filename, timeout=120):
    """Post the contents of *filename* to API_URL and return the decoded JSON.

    Args:
        filename: Path of the file whose raw bytes are sent as the request body.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests waits indefinitely, so one stalled connection
            would hang a whole batch run.

    Returns:
        Whatever ``response.json()`` decodes from the reply body.

    Raises:
        OSError: If *filename* cannot be opened or read.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    # Read the whole file up front so the handle is closed before the
    # network round trip starts.
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data, timeout=timeout)
    return response.json()
# Batch processing
def batch_request_api(file_uris):
    """Call request_api for every file and collect the first result of each.

    Args:
        file_uris: Iterable of file paths, processed in order.

    Returns:
        A ``(y_true, y_pred)`` tuple of equal-length lists. For each file,
        ``y_true`` holds the ``'label'`` of the first response entry as a
        string, and ``y_pred`` holds its ``'score'`` rounded to one decimal.

    Raises:
        RuntimeError: If the response for a file is not a non-empty list.
    """
    if DEBUG:
        print('batch_request_api')

    y_true = []
    y_pred = []
    for input_file in tqdm(file_uris):
        res = request_api(input_file)
        # Presumably the service replies with a dict (e.g. an error payload)
        # rather than a list when it cannot classify the file -- TODO confirm.
        # Either way, fail with the file name and payload instead of an
        # opaque KeyError/IndexError from res[0].
        if not isinstance(res, list) or not res:
            raise RuntimeError(
                'unexpected API response for %s: %s' % (input_file, res)
            )
        first = res[0]
        y_true.append(str(first['label']))
        y_pred.append(round(first['score'], 1))
    return y_true, y_pred
# Command-line entry point
if __name__ == "__main__":
    # Everything after the script name is passed on as an input file path.
    cli_args = sys.argv[1:]
    if DEBUG:
        print('main, ' + str(cli_args))
        print('main, ' + str(len(sys.argv)))

    # Without at least one argument there is nothing to process: print the
    # usage hint and exit with a non-zero status.
    if not cli_args:
        print("用法:python x.py <文件或通配符>")
        sys.exit(1)

    if DEBUG:
        print('main, batch_request_api')
    y_true, y_pred = batch_request_api(cli_args)
    if DEBUG:
        print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))

    # Hand the collected labels and scores to the plotting helper.
    draw2(y_true, y_pred)