thelou1s commited on
Commit
53dc546
1 Parent(s): 49a9586

rename test.py

Browse files
Files changed (2) hide show
  1. test.py +56 -6
  2. test_dreamtalk.py +0 -58
test.py CHANGED
@@ -1,8 +1,58 @@
1
- from transformers import pipeline
 
 
 
2
 
3
- classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
4
- res = classifier("1.flac")
5
- #res = classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
6
 
7
- print(str(res))
8
- #[{'score': 0.997, 'label': '_unknown_'}, {'score': 0.002, 'label': 'left'}, {'score': 0.0, 'label': 'yes'}, {'score': 0.0, 'label': 'down'}, {'score': 0.0, 'label': 'stop'}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import sys
3
+ from draw_confusion import draw, draw2
4
+ from tqdm import tqdm
5
 
 
 
 
6
 
7
+ DEBUG = False
8
+ API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
9
+ headers = {"Authorization": "Bearer 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"}
10
+
11
+
12
+ # 处理请求
13
+ # filename = '1.flac'
14
+ def request_api(filename):
15
+ with open(filename, "rb") as f:
16
+ data = f.read()
17
+ response = requests.post(API_URL, headers=headers, data=data)
18
+ return response.json()
19
+
20
+
21
+ # 批量处理
22
+ def batch_request_api(file_uris):
23
+ if DEBUG: print('batch_request_api')
24
+ y_len = len(file_uris)
25
+ y_true = [0] * y_len
26
+ y_pred = [0] * y_len
27
+ y_idx = 0
28
+ for input_file in tqdm(file_uris):
29
+ res = request_api(input_file)
30
+ # print('%s %s:' % (str(y_idx), str(input_file)) )
31
+ # print('%s' % str(res[:3]))
32
+
33
+ first_label = str(res[0]['label'])
34
+ first_score = res[0]['score']
35
+ # print(str(first_label))
36
+ # print(str(first_score))
37
+
38
+ y_true[y_idx] = first_label
39
+ y_pred[y_idx] = round(first_score, 1)
40
+ y_idx = y_idx + 1
41
+
42
+ return y_true, y_pred
43
+
44
+
45
+ # 处理命令行
46
+ if __name__ == "__main__":
47
+ if DEBUG: print('main, ' + str(sys.argv[1:]))
48
+ if DEBUG: print('main, ' + str(len(sys.argv)))
49
+
50
+ # 获取命令行参数
51
+ if len(sys.argv) < 2:
52
+ print("用法:python x.py <文件或通配符>")
53
+ sys.exit(1)
54
+
55
+ if DEBUG: print('main, batch_request_api')
56
+ y_true, y_pred = batch_request_api(sys.argv[1:])
57
+ if DEBUG: print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
58
+ draw2(y_true, y_pred)
test_dreamtalk.py DELETED
@@ -1,58 +0,0 @@
1
- import requests
2
- import sys
3
- from draw_confusion import draw, draw2
4
- from tqdm import tqdm
5
-
6
-
7
- DEBUG = True #False
8
- API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
9
- headers = {"Authorization": "Bearer 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"}
10
-
11
-
12
- # 处理请求
13
- # filename = '1.flac'
14
- def request_api(filename):
15
- with open(filename, "rb") as f:
16
- data = f.read()
17
- response = requests.post(API_URL, headers=headers, data=data)
18
- return response.json()
19
-
20
-
21
- # 批量处理
22
- def batch_request_api(file_uris):
23
- if DEBUG: print('batch_request_api')
24
- y_len = len(file_uris)
25
- y_true = [0] * y_len
26
- y_pred = [0] * y_len
27
- y_idx = 0
28
- for input_file in tqdm(file_uris):
29
- res = request_api(input_file)
30
- # print('%s %s:' % (str(y_idx), str(input_file)) )
31
- # print('%s' % str(res[:3]))
32
-
33
- first_label = str(res[0]['label'])
34
- first_score = res[0]['score']
35
- # print(str(first_label))
36
- # print(str(first_score))
37
-
38
- y_true[y_idx] = first_label
39
- y_pred[y_idx] = round(first_score, 1)
40
- y_idx = y_idx + 1
41
-
42
- return y_true, y_pred
43
-
44
-
45
- # 处理命令行
46
- if __name__ == "__main__":
47
- if DEBUG: print('main, ' + str(sys.argv[1:]))
48
- if DEBUG: print('main, ' + str(len(sys.argv)))
49
-
50
- # 获取命令行参数
51
- if len(sys.argv) < 2:
52
- print("用法:python x.py <文件或通配符>")
53
- sys.exit(1)
54
-
55
- if DEBUG: print('main, batch_request_api')
56
- y_true, y_pred = batch_request_api(sys.argv[1:])
57
- if DEBUG: print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
58
- draw2(y_true, y_pred)