File size: 1,601 Bytes
53dc546
 
 
 
400bffb
 
53dc546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import glob
import os
import sys

import requests
from tqdm import tqdm

from draw_confusion import draw, draw2


# Set to True to print debug messages while the script runs.
DEBUG = False
# Inference API endpoint for the model MIT/ast-finetuned-audioset-10-10-0.4593.
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
# SECURITY: API tokens must not be committed to source control. The token is
# taken from the HF_API_TOKEN environment variable; the literal fallback only
# preserves the previous behavior and should be revoked and removed.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN") or "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"
headers = {"Authorization": "Bearer " + HF_API_TOKEN}


# Send a single inference request.
# Example: request_api('1.flac')
def request_api(filename, timeout=120):
    """POST the raw bytes of *filename* to ``API_URL`` and return the JSON reply.

    Args:
        filename: path of the file whose bytes are uploaded as the request body.
        timeout: seconds ``requests`` waits for the connection and the response
            before raising ``requests.Timeout``. Pass ``None`` to wait forever
            (the previous behavior).

    Returns:
        The decoded JSON body of the response.

    Raises:
        OSError: if *filename* cannot be opened.
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data, timeout=timeout)
    # Fail loudly on HTTP errors instead of returning an error payload to
    # callers that index the result as a list of predictions.
    response.raise_for_status()
    return response.json()


# Batch processing.
def batch_request_api(file_uris):
    """Send every file in *file_uris* through ``request_api`` and keep the first result.

    Only the first entry of each response list is used. Which entry comes first
    is decided by the remote API (presumably the top-scoring label — TODO
    confirm).

    Args:
        file_uris: iterable of file paths to send, processed in order.

    Returns:
        A tuple ``(y_true, y_pred)`` of equal-length lists in input order:
        ``y_true[i]`` is ``str`` of the first entry's ``'label'`` for file i,
        and ``y_pred[i]`` is that entry's ``'score'`` rounded to one decimal.
        NOTE(review): despite its name, ``y_true`` holds the *predicted*
        label, not ground truth; the names are kept for the ``draw2`` caller.

    Raises:
        ValueError: if a response is not a non-empty list.
    """
    if DEBUG:
        print('batch_request_api')
    y_true = []
    y_pred = []
    for input_file in tqdm(file_uris):
        res = request_api(input_file)
        # Guard against non-list or empty replies so the failure names the file
        # instead of surfacing as a bare KeyError/IndexError.
        if not isinstance(res, list) or not res:
            raise ValueError('unexpected API response for %s: %r' % (input_file, res))
        first = res[0]
        y_true.append(str(first['label']))
        y_pred.append(round(first['score'], 1))
    return y_true, y_pred


# Command-line handling.
def expand_file_args(args):
    """Expand wildcard patterns in *args* into the matching file paths.

    The usage message advertises "file or wildcard", but some shells (e.g.
    Windows cmd.exe) do not expand wildcards, so a pattern would otherwise reach
    ``open()`` verbatim. Rules per argument:

    * a path that exists is kept as-is, even if it contains glob characters;
    * otherwise it is treated as a glob pattern and its matches are added in
      sorted order;
    * a pattern with no matches is passed through unchanged, so the later
      ``open()`` still reports the missing file.

    Args:
        args: command-line arguments (paths or glob patterns).

    Returns:
        A new list of file paths.
    """
    files = []
    for arg in args:
        if os.path.exists(arg):
            files.append(arg)
            continue
        matches = sorted(glob.glob(arg))
        files.extend(matches if matches else [arg])
    return files


if __name__ == "__main__":
    if DEBUG:
        print('main, ' + str(sys.argv[1:]))
        print('main, ' + str(len(sys.argv)))

    # Read command-line arguments.
    if len(sys.argv) < 2:
        print("用法:python x.py <文件或通配符>")
        sys.exit(1)

    if DEBUG:
        print('main, batch_request_api')
    y_true, y_pred = batch_request_api(expand_file_args(sys.argv[1:]))
    if DEBUG:
        print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
    draw2(y_true, y_pred)