rename test.py
Browse files- test.py +56 -6
- test_dreamtalk.py +0 -58
test.py
CHANGED
@@ -1,8 +1,58 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
|
3 |
-
classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
|
4 |
-
res = classifier("1.flac")
|
5 |
-
#res = classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import sys
|
3 |
+
from draw_confusion import draw, draw2
|
4 |
+
from tqdm import tqdm
|
5 |
|
|
|
|
|
|
|
6 |
|
7 |
+
DEBUG = False
|
8 |
+
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
|
9 |
+
headers = {"Authorization": "Bearer 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"}
|
10 |
+
|
11 |
+
|
12 |
+
# 处理请求
|
13 |
+
# filename = '1.flac'
|
14 |
+
def request_api(filename):
|
15 |
+
with open(filename, "rb") as f:
|
16 |
+
data = f.read()
|
17 |
+
response = requests.post(API_URL, headers=headers, data=data)
|
18 |
+
return response.json()
|
19 |
+
|
20 |
+
|
21 |
+
# 批量处理
|
22 |
+
def batch_request_api(file_uris):
|
23 |
+
if DEBUG: print('batch_request_api')
|
24 |
+
y_len = len(file_uris)
|
25 |
+
y_true = [0] * y_len
|
26 |
+
y_pred = [0] * y_len
|
27 |
+
y_idx = 0
|
28 |
+
for input_file in tqdm(file_uris):
|
29 |
+
res = request_api(input_file)
|
30 |
+
# print('%s %s:' % (str(y_idx), str(input_file)) )
|
31 |
+
# print('%s' % str(res[:3]))
|
32 |
+
|
33 |
+
first_label = str(res[0]['label'])
|
34 |
+
first_score = res[0]['score']
|
35 |
+
# print(str(first_label))
|
36 |
+
# print(str(first_score))
|
37 |
+
|
38 |
+
y_true[y_idx] = first_label
|
39 |
+
y_pred[y_idx] = round(first_score, 1)
|
40 |
+
y_idx = y_idx + 1
|
41 |
+
|
42 |
+
return y_true, y_pred
|
43 |
+
|
44 |
+
|
45 |
+
# 处理命令行
|
46 |
+
if __name__ == "__main__":
|
47 |
+
if DEBUG: print('main, ' + str(sys.argv[1:]))
|
48 |
+
if DEBUG: print('main, ' + str(len(sys.argv)))
|
49 |
+
|
50 |
+
# 获取命令行参数
|
51 |
+
if len(sys.argv) < 2:
|
52 |
+
print("用法:python x.py <文件或通配符>")
|
53 |
+
sys.exit(1)
|
54 |
+
|
55 |
+
if DEBUG: print('main, batch_request_api')
|
56 |
+
y_true, y_pred = batch_request_api(sys.argv[1:])
|
57 |
+
if DEBUG: print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
|
58 |
+
draw2(y_true, y_pred)
|
test_dreamtalk.py
DELETED
@@ -1,58 +0,0 @@
|
|
1 |
-
import requests
|
2 |
-
import sys
|
3 |
-
from draw_confusion import draw, draw2
|
4 |
-
from tqdm import tqdm
|
5 |
-
|
6 |
-
|
7 |
-
DEBUG = True #False
|
8 |
-
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
|
9 |
-
headers = {"Authorization": "Bearer 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"}
|
10 |
-
|
11 |
-
|
12 |
-
# 处理请求
|
13 |
-
# filename = '1.flac'
|
14 |
-
def request_api(filename):
|
15 |
-
with open(filename, "rb") as f:
|
16 |
-
data = f.read()
|
17 |
-
response = requests.post(API_URL, headers=headers, data=data)
|
18 |
-
return response.json()
|
19 |
-
|
20 |
-
|
21 |
-
# 批量处理
|
22 |
-
def batch_request_api(file_uris):
|
23 |
-
if DEBUG: print('batch_request_api')
|
24 |
-
y_len = len(file_uris)
|
25 |
-
y_true = [0] * y_len
|
26 |
-
y_pred = [0] * y_len
|
27 |
-
y_idx = 0
|
28 |
-
for input_file in tqdm(file_uris):
|
29 |
-
res = request_api(input_file)
|
30 |
-
# print('%s %s:' % (str(y_idx), str(input_file)) )
|
31 |
-
# print('%s' % str(res[:3]))
|
32 |
-
|
33 |
-
first_label = str(res[0]['label'])
|
34 |
-
first_score = res[0]['score']
|
35 |
-
# print(str(first_label))
|
36 |
-
# print(str(first_score))
|
37 |
-
|
38 |
-
y_true[y_idx] = first_label
|
39 |
-
y_pred[y_idx] = round(first_score, 1)
|
40 |
-
y_idx = y_idx + 1
|
41 |
-
|
42 |
-
return y_true, y_pred
|
43 |
-
|
44 |
-
|
45 |
-
# 处理命令行
|
46 |
-
if __name__ == "__main__":
|
47 |
-
if DEBUG: print('main, ' + str(sys.argv[1:]))
|
48 |
-
if DEBUG: print('main, ' + str(len(sys.argv)))
|
49 |
-
|
50 |
-
# 获取命令行参数
|
51 |
-
if len(sys.argv) < 2:
|
52 |
-
print("用法:python x.py <文件或通配符>")
|
53 |
-
sys.exit(1)
|
54 |
-
|
55 |
-
if DEBUG: print('main, batch_request_api')
|
56 |
-
y_true, y_pred = batch_request_api(sys.argv[1:])
|
57 |
-
if DEBUG: print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
|
58 |
-
draw2(y_true, y_pred)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|