File size: 4,894 Bytes
96fe658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import json
import requests
from swift.version import __version__
# Model tagging: report model test results to the tagging service.
class ModelTag(object):
    """Client for the model-tagging HTTP service.

    The service base URL is taken from the ``MODEL_TAG_URL`` environment
    variable once, when this class is defined (i.e. at module import).
    Fill in the instance fields, then call ``batch_commit_result``,
    ``batch_refresh_stage``, ``query_model_stage`` or ``commit_ut_result``.
    All network methods are best-effort: failures are logged and the
    method returns None instead of raising.
    """

    _URL = os.environ.get('MODEL_TAG_URL', None)
    # Endpoint for submitting model test results.
    BATCH_COMMIT_RESULT_URL = f'{_URL}/batchCommitResult'
    # Endpoint for marking a test stage as finished.
    BATCH_REFRESH_STAGE_URL = f'{_URL}/batchRefreshStage'
    # Endpoint for querying a model's latest result at a given stage.
    QUERY_MODEL_STAGE_URL = f'{_URL}/queryModelStage'
    HEADER = {'Content-Type': 'application/json'}
    # Seconds to wait for the service. Without a timeout, ``requests`` can
    # block indefinitely on an unresponsive server.
    TIMEOUT = 60

    # Per-item check results.
    MODEL_SKIP = 0
    MODEL_FAIL = 1
    MODEL_PASS = 2

    class ItemResult(object):
        """A single named check result."""

        def __init__(self):
            self.result = 0  # same value as MODEL_SKIP
            self.name = ''
            self.info = ''

        def to_json(self):
            """Return this result as a JSON-serializable dict."""
            return {'name': self.name, 'result': self.result, 'info': self.info}

    def __init__(self):
        self.job_name = ''
        self.job_id = ''
        self.model = ''
        self.sdk_version = ''
        self.image_version = ''
        self.domain = ''
        self.task = ''
        self.source = ''
        self.stage = ''
        # Item results: dicts from ItemResult.to_json(), or ItemResult
        # instances (converted when the request payload is built).
        self.item_result = []

    def _item_result_json(self):
        """Return ``item_result`` with ItemResult instances converted to dicts."""
        return [
            item.to_json() if isinstance(item, self.ItemResult) else item
            for item in self.item_result
        ]

    def _post_request(self, url, param):
        """POST ``param`` as UTF-8 encoded JSON to ``url``.

        Returns:
            The ``content`` field of the response body when the HTTP status
            is 200 and the body's ``errorCode`` is 200; otherwise None.
            Any exception (network error, timeout, malformed body) is
            logged with its traceback and None is returned.
        """
        try:
            body = json.dumps(param, ensure_ascii=False)
            logging.info(url + ' query: ' + body)
            res = requests.post(
                url=url,
                headers=self.HEADER,
                data=body.encode('utf8'),
                timeout=self.TIMEOUT)
            if res.status_code == 200:
                logging.info(f'{url} post结果: ' + res.text)
                res_json = json.loads(res.text)
                if int(res_json['errorCode']) == 200:
                    return res_json['content']
                logging.error(res.text)
            else:
                logging.error(res.text)
        except Exception as e:
            # Reporting is best-effort: keep the traceback, don't propagate.
            logging.exception(e)
        return None

    def batch_commit_result(self):
        """Submit this model's test results (``item_result``).

        Returns the service's ``content`` payload, or None on failure.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'imageVersion': self.image_version,
                'source': self.source,
                'jobName': self.job_name,
                'jobId': self.job_id,
                'modelList': [{
                    'model': self.model,
                    'domain': self.domain,
                    'task': self.task,
                    'itemResult': self._item_result_json(),
                }],
            }
            return self._post_request(self.BATCH_COMMIT_RESULT_URL, param)
        except Exception as e:
            logging.exception(e)
            return None

    def batch_refresh_stage(self):
        """Mark ``self.stage`` as finished for this model.

        Returns the service's ``content`` payload, or None on failure.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'imageVersion': self.image_version,
                'source': self.source,
                'stage': self.stage,
                'modelList': [{
                    'model': self.model,
                    'domain': self.domain,
                    'task': self.task,
                }],
            }
            return self._post_request(self.BATCH_REFRESH_STAGE_URL, param)
        except Exception as e:
            logging.exception(e)
            return None

    def query_model_stage(self):
        """Query the latest test result of this model at ``self.stage``.

        Only a single result is returned: the service's ``content``
        payload, or None on failure.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'model': self.model,
                'stage': self.stage,
                'imageVersion': self.image_version,
            }
            return self._post_request(self.QUERY_MODEL_STAGE_URL, param)
        except Exception as e:
            logging.exception(e)
            return None

    def commit_ut_result(self):
        """Submit unit-test results and mark the 'integration' stage finished.

        Sets ``job_name='UT'``, ``source='dev'`` and ``stage='integration'``,
        then calls ``batch_commit_result`` followed by
        ``batch_refresh_stage``. Does nothing when ``MODEL_TAG_URL`` is
        unset or empty.

        Example::

            model_tag = ModelTag()
            model_tag.model = "XXX"
            model_tag.sdk_version = "0.3.7"
            model_tag.domain = "nlp"
            model_tag.task = "word-segmentation"
            item = model_tag.ItemResult()
            item.result = model_tag.MODEL_PASS
            item.name = "ALL"
            item.info = ""
            model_tag.item_result.append(item.to_json())
            model_tag.commit_ut_result()
        """
        # _URL is either the env-var string or None; both None and '' mean
        # "not configured".
        if not self._URL:
            return
        self.job_name = 'UT'
        self.source = 'dev'
        self.stage = 'integration'
        self.batch_commit_result()
        self.batch_refresh_stage()
def commit_model_ut_result(model_name, ut_result):
    """Report one aggregate unit-test result (item name 'ALL') for a model.

    Args:
        model_name: Model id; a leading ``damo/`` namespace is stripped
            before reporting.
        ut_result: Stored as the item's ``result``; presumably one of
            ModelTag.MODEL_SKIP / MODEL_FAIL / MODEL_PASS.

    The actual submission goes through ``ModelTag.commit_ut_result``. That
    method sends nothing unless ``MODEL_TAG_URL`` is configured.
    """
    prefix = 'damo/'
    # Strip only a leading namespace. str.replace would also mangle ids
    # whose namespace merely ends in "damo" (e.g. "xdamo/m" -> "xm").
    if model_name.startswith(prefix):
        model_name = model_name[len(prefix):]
    model_tag = ModelTag()
    model_tag.model = model_name
    model_tag.sdk_version = __version__
    item = model_tag.ItemResult()
    item.result = ut_result
    item.name = 'ALL'
    item.info = ''
    model_tag.item_result.append(item.to_json())
    model_tag.commit_ut_result()
|