"""Classify audio files with the Hugging Face Inference API and plot the results.

Each file named on the command line is posted to the hosted
MIT/ast-finetuned-audioset model. The first entry of each response (its label
and score) is collected and handed to ``draw_confusion.draw2``.
"""
import glob
import os
import sys

import requests
from tqdm import tqdm

from draw_confusion import draw, draw2

DEBUG = False

API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
# The access token is read from the environment instead of being committed to
# source control. NOTE(review): the token that used to be hard-coded here was
# published together with the code and should be revoked.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
headers = {"Authorization": "Bearer " + HF_API_TOKEN}

# Seconds to wait for the API before giving up on a single request.
REQUEST_TIMEOUT = 60


def request_api(filename, timeout=REQUEST_TIMEOUT):
    """Send one audio file to the inference API and return the decoded JSON.

    Args:
        filename: Path of the audio file; its raw bytes are the request body.
        timeout: Seconds to wait for the server before raising.

    Returns:
        The parsed JSON response. This script assumes it is a list of
        ``{"label": ..., "score": ...}`` dicts with the best match first.

    Raises:
        requests.HTTPError: The API answered with a 4xx/5xx status.
        requests.Timeout: No answer arrived within ``timeout`` seconds.
    """
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data, timeout=timeout)
    # Fail here with the real HTTP error instead of later with an opaque
    # KeyError when an error payload is indexed like a prediction list.
    response.raise_for_status()
    return response.json()


def batch_request_api(file_uris):
    """Classify every file and collect the first prediction for each.

    Args:
        file_uris: Iterable of audio file paths.

    Returns:
        A tuple ``(y_true, y_pred)`` of equal-length lists. ``y_true[i]`` is
        the first predicted label (as ``str``) for the i-th file, and
        ``y_pred[i]`` is that prediction's score rounded to one decimal place.

    Raises:
        RuntimeError: The API response for a file was not a non-empty list.
    """
    if DEBUG:
        print('batch_request_api')
    y_true = []
    y_pred = []
    for input_file in tqdm(file_uris):
        res = request_api(input_file)
        if not isinstance(res, list) or not res:
            raise RuntimeError(
                "unexpected API response for %s: %r" % (input_file, res))
        top = res[0]
        y_true.append(str(top['label']))
        y_pred.append(round(top['score'], 1))
    return y_true, y_pred


def _expand_args(args):
    """Expand wildcard arguments the shell left unexpanded (e.g. Windows cmd)."""
    files = []
    for arg in args:
        if os.path.exists(arg):
            # Existing paths are used verbatim, even if they contain [ ] * ?.
            files.append(arg)
        else:
            # Keep the literal argument when nothing matches so open() reports it.
            files.extend(sorted(glob.glob(arg)) or [arg])
    return files


def main():
    """Command-line entry point: classify the given files and plot the results."""
    if DEBUG:
        print('main, ' + str(sys.argv[1:]))
    if DEBUG:
        print('main, ' + str(len(sys.argv)))
    # Read the command-line arguments.
    if len(sys.argv) < 2:
        print("用法:python x.py <文件或通配符>")
        sys.exit(1)
    if not HF_API_TOKEN:
        print("HF_API_TOKEN environment variable is not set", file=sys.stderr)
        sys.exit(1)
    if DEBUG:
        print('main, batch_request_api')
    y_true, y_pred = batch_request_api(_expand_args(sys.argv[1:]))
    if DEBUG:
        print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
    draw2(y_true, y_pred)


if __name__ == "__main__":
    main()