sparse / ms-swift /tests /model_tag.py
Enxin's picture
Upload folder using huggingface_hub
96fe658 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import json
import requests
from swift.version import __version__
# Model tagging: report model test results to the tagging service.
class ModelTag(object):
    """Best-effort client that reports model test results to a tagging service.

    The service base URL is read from the ``MODEL_TAG_URL`` environment
    variable once, when the class body is executed.  Callers fill in the
    public attributes created in ``__init__`` and then call one of the
    request methods.  Request methods never raise: every failure is logged
    and ``None`` is returned instead.

    Example (submitting a unit-test result)::

        model_tag = ModelTag()
        model_tag.model = "XXX"
        model_tag.sdk_version = "0.3.7"
        model_tag.domain = "nlp"
        model_tag.task = "word-segmentation"
        item = model_tag.ItemResult()
        item.result = model_tag.MODEL_PASS
        item.name = "ALL"
        item.info = ""
        model_tag.item_result.append(item.to_json())
        model_tag.commit_ut_result()
    """

    _URL = os.environ.get('MODEL_TAG_URL', None)
    # NOTE: when MODEL_TAG_URL is unset, the endpoint constants below evaluate
    # to 'None/<endpoint>'; commit_ut_result() checks _URL before using them.
    # Endpoint: submit model test results.
    BATCH_COMMIT_RESULT_URL = f'{_URL}/batchCommitResult'
    # Endpoint: mark a test stage as finished.
    BATCH_REFRESH_STAGE_URL = f'{_URL}/batchRefreshStage'
    # Endpoint: query the result of a model stage.
    QUERY_MODEL_STAGE_URL = f'{_URL}/queryModelStage'
    HEADER = {'Content-Type': 'application/json'}
    # Seconds passed as ``timeout`` to ``requests.post``.  Without a timeout
    # an unresponsive service would block the caller forever.
    REQUEST_TIMEOUT = 60
    # Result codes stored in ItemResult.result.
    MODEL_SKIP = 0
    MODEL_FAIL = 1
    MODEL_PASS = 2

    class ItemResult(object):
        """One named check result; serialised with ``to_json``."""

        def __init__(self, name='', result=0, info=''):
            """Create a result entry.

            Args:
                name: Name of the checked item.
                result: Result code; the default 0 equals
                    ``ModelTag.MODEL_SKIP``.
                info: Free-form additional information.
            """
            self.result = result
            self.name = name
            self.info = info

        def to_json(self):
            """Return the entry as a plain dict with name/result/info keys."""
            return {'name': self.name, 'result': self.result, 'info': self.info}

    def __init__(self):
        """Initialise every request field to an empty value."""
        self.job_name = ''
        self.job_id = ''
        self.model = ''
        self.sdk_version = ''
        self.image_version = ''
        self.domain = ''
        self.task = ''
        self.source = ''
        self.stage = ''
        # List of ItemResult.to_json() dicts.
        self.item_result = []

    def _post_request(self, url, param):
        """POST ``param`` as UTF-8 JSON to ``url``.

        Args:
            url: Full endpoint URL.
            param: JSON-serialisable request body.

        Returns:
            The ``content`` field of the JSON reply when the HTTP status is
            200 and the reply's ``errorCode`` is 200; ``None`` in every
            other case, including any exception (which is logged).
        """
        try:
            # Serialise once and reuse the text for both the log and the body.
            payload = json.dumps(param, ensure_ascii=False)
            logging.info(url + ' query: ' + payload)
            res = requests.post(
                url=url,
                headers=self.HEADER,
                data=payload.encode('utf8'),
                timeout=self.REQUEST_TIMEOUT)
            if res.status_code != 200:
                logging.error(res.text)
                return None
            logging.info(f'{url} post结果: ' + res.text)
            res_json = json.loads(res.text)
            if int(res_json['errorCode']) == 200:
                return res_json['content']
            logging.error(res.text)
        except Exception as e:
            # Reporting is best-effort: log with traceback, never propagate.
            logging.error(e, exc_info=True)
        return None

    def batch_commit_result(self):
        """Submit the collected ``item_result`` entries for ``self.model``.

        Returns:
            Whatever ``_post_request`` returns; ``None`` on any failure.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'imageVersion': self.image_version,
                'source': self.source,
                'jobName': self.job_name,
                'jobId': self.job_id,
                'modelList': [{
                    'model': self.model,
                    'domain': self.domain,
                    'task': self.task,
                    'itemResult': self.item_result
                }]
            }
            return self._post_request(self.BATCH_COMMIT_RESULT_URL, param)
        except Exception as e:
            logging.error(e, exc_info=True)
        return None

    def batch_refresh_stage(self):
        """Tell the service that ``self.stage`` is finished for ``self.model``.

        Returns:
            Whatever ``_post_request`` returns; ``None`` on any failure.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'imageVersion': self.image_version,
                'source': self.source,
                'stage': self.stage,
                'modelList': [{
                    'model': self.model,
                    'domain': self.domain,
                    'task': self.task
                }]
            }
            return self._post_request(self.BATCH_REFRESH_STAGE_URL, param)
        except Exception as e:
            logging.error(e, exc_info=True)
        return None

    def query_model_stage(self):
        """Query the latest test result of ``self.stage`` for ``self.model``.

        Per the original author's note, only a single result is returned.

        Returns:
            Whatever ``_post_request`` returns; ``None`` on any failure.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'model': self.model,
                'stage': self.stage,
                'imageVersion': self.image_version
            }
            return self._post_request(self.QUERY_MODEL_STAGE_URL, param)
        except Exception as e:
            logging.error(e, exc_info=True)
        return None

    def commit_ut_result(self):
        """Submit the unit-test result and then refresh the stage.

        Does nothing when ``MODEL_TAG_URL`` was unset or empty.  Otherwise
        sets ``job_name='UT'``, ``source='dev'`` and ``stage='integration'``
        before calling ``batch_commit_result`` and ``batch_refresh_stage``.
        """
        if not self._URL:
            return
        self.job_name = 'UT'
        self.source = 'dev'
        self.stage = 'integration'
        self.batch_commit_result()
        self.batch_refresh_stage()
def commit_model_ut_result(model_name, ut_result):
    """Report one model's overall unit-test outcome via ``ModelTag``.

    Args:
        model_name: Model identifier; every ``'damo/'`` substring is removed
            before it is sent.
        ut_result: Result code stored on the single ``'ALL'`` item
            (presumably one of ``ModelTag.MODEL_SKIP`` / ``MODEL_FAIL`` /
            ``MODEL_PASS`` -- confirm against callers).
    """
    reporter = ModelTag()
    reporter.model = model_name.replace('damo/', '')
    reporter.sdk_version = __version__
    # domain and task are intentionally left at their empty defaults.

    # A single aggregate entry covers the whole unit-test run.
    overall = reporter.ItemResult()
    overall.name, overall.result, overall.info = 'ALL', ut_result, ''
    reporter.item_result.append(overall.to_json())

    reporter.commit_ut_result()