3v324v23's picture
add
c310e19
raw
history blame
19.8 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# encoding=utf8
from collections import namedtuple
import rrc_evaluation_funcs_total_text as rrc_evaluation_funcs
import importlib
from prepare_results import prepare_results_for_evaluation
def evaluation_imports():
    """
    evaluation_imports: maps each python module needed by the evaluation ( key = module name )
    to the alias ( value ) under which the evaluation code refers to it.
    """
    aliasByModule = {}
    aliasByModule['Polygon'] = 'plg'
    aliasByModule['numpy'] = 'np'
    return aliasByModule
def default_evaluation_params():
"""
default_evaluation_params: Default parameters to use for the validation and evaluation.
"""
return {
'IOU_CONSTRAINT' :0.5,
'AREA_PRECISION_CONSTRAINT' :0.5,
'WORD_SPOTTING' :False,
'MIN_LENGTH_CARE_WORD' :3,
'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt',
'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt',
'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
'CRLF':False, # Lines are delimited by Windows CRLF format
'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated,
'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'',
'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
}
def validate_data(gtFilePath, submFilePath, evaluationParams):
    """
    Method validate_data: checks the line format of every file loaded from the ground truth
    archive and from the results archive.
    Every sample of the results must also exist in the ground truth.
    If some error is detected, the method raises the error
    """
    gtFiles = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
    submFiles = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)

    # Validate format of GroundTruth
    for sampleName in gtFiles:
        rrc_evaluation_funcs.validate_lines_in_file(
            sampleName,
            gtFiles[sampleName],
            evaluationParams['CRLF'],
            evaluationParams['LTRB'],
            True)

    # Validate format of results
    for sampleName in submFiles:
        if sampleName not in gtFiles:
            raise Exception("The sample %s not present in GT" %sampleName)

        rrc_evaluation_funcs.validate_lines_in_file(
            sampleName,
            submFiles[sampleName],
            evaluationParams['CRLF'],
            evaluationParams['LTRB'],
            True,
            evaluationParams['CONFIDENCES'])
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluate method and returns the results

    Parameters:
        - gtFilePath: path handed to rrc_evaluation_funcs.load_zip_file for the ground truth
        - submFilePath: path handed to rrc_evaluation_funcs.load_zip_file for the results
        - evaluationParams: dictionary with the keys produced by default_evaluation_params

    Results. Dictionary with the following values:
        - calculated: always True
        - Message: always an empty string
        - method: Global method metrics: { 'precision', 'recall', 'hmean', 'AP' }
        - per_sample: Per sample metrics, keyed by the sample key of the ground truth. Each entry holds
          'precision', 'recall', 'hmean', 'pairs', 'AP', 'iouMat', 'gtPolPoints', 'detPolPoints',
          'gtTrans', 'detTrans', 'gtDontCare', 'detDontCare', 'evaluationParams' and 'evaluationLog'

    Raises an Exception if a ground truth file cannot be decoded as UTF-8.
    """
    # The nested helpers refer to the modules through the global aliases (np, plg),
    # so the aliases are published as module globals before anything else runs.
    for module, alias in evaluation_imports().items():
        globals()[alias] = importlib.import_module(module)

    def polygon_from_points(points, correctOffset=False):
        """
        Returns a Polygon object to use with the Polygon2 class from a flat list of
        coordinates: x1,y1,x2,y2,x3,y3,x4,y4,...
        correctOffset is accepted but not used.
        """
        numPoints = int(len(points) / 2)
        # All x values go in the first half and all y values in the second half,
        # so that the reshape + transpose below yields one (x, y) row per point.
        resBoxes = np.empty([1, len(points)], dtype='int32')
        for i in range(numPoints):
            resBoxes[0, i] = int(points[2*i])
            resBoxes[0, numPoints + i] = int(points[2*i+1])
        pointMat = resBoxes[0].reshape([2, -1]).T
        return plg.Polygon(pointMat)

    def rectangle_to_polygon(rect):
        """
        Returns a Polygon object built from the four corners of rect (xmin, ymin, xmax, ymax).
        """
        # Same layout as polygon_from_points: positions 0-3 are x, positions 4-7 are y
        resBoxes = np.empty([1, 8], dtype='int32')
        resBoxes[0, 0] = int(rect.xmin)
        resBoxes[0, 4] = int(rect.ymax)
        resBoxes[0, 1] = int(rect.xmin)
        resBoxes[0, 5] = int(rect.ymin)
        resBoxes[0, 2] = int(rect.xmax)
        resBoxes[0, 6] = int(rect.ymin)
        resBoxes[0, 3] = int(rect.xmax)
        resBoxes[0, 7] = int(rect.ymax)
        pointMat = resBoxes[0].reshape([2, 4]).T
        return plg.Polygon(pointMat)

    def rectangle_to_points(rect):
        """
        Returns the four corners of rect as a flat list x1,y1,x2,y2,x3,y3,x4,y4.
        """
        points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
        return points

    def get_union(pD, pG):
        areaA = pD.area()
        areaB = pG.area()
        return areaA + areaB - get_intersection(pD, pG)

    def get_intersection_over_union(pD, pG):
        try:
            return get_intersection(pD, pG) / get_union(pD, pG)
        except Exception:
            # Best effort, e.g. a ZeroDivisionError when the union is zero.
            # Unlike a bare except, this lets KeyboardInterrupt / SystemExit through.
            return 0

    def get_intersection(pD, pG):
        pInt = pD & pG
        if len(pInt) == 0:
            return 0
        return pInt.area()

    def compute_ap(confList, matchList, numGtCare):
        """
        Returns the average precision of the detections: they are visited by decreasing
        confidence, the precision reached at every matched detection is accumulated and
        the sum is divided by numGtCare (when it is greater than 0).
        """
        correct = 0
        AP = 0
        if len(confList) > 0:
            confList = np.array(confList)
            matchList = np.array(matchList)
            sorted_ind = np.argsort(-confList)
            confList = confList[sorted_ind]
            matchList = matchList[sorted_ind]
            for n in range(len(confList)):
                match = matchList[n]
                if match:
                    correct += 1
                    AP += float(correct)/(n + 1)

            if numGtCare > 0:
                AP /= numGtCare

        return AP

    def transcription_match(transGt, transDet, specialCharacters='!?.:,*"()·[]/\'', onlyRemoveFirstLastCharacterGT=True):
        """
        Returns True if the detected transcription is accepted for the Ground Truth one.
        With onlyRemoveFirstLastCharacterGT a single special character may be dropped from
        the start and/or the end of the Ground Truth; otherwise all leading and trailing
        special characters are removed from both transcriptions before comparing.
        """
        if onlyRemoveFirstLastCharacterGT:
            #special characters in GT are allowed only at initial or final position
            if transGt == transDet:
                return True

            # An empty GT has no first/last character to drop (indexing it would raise IndexError)
            if len(transGt) == 0:
                return False

            startsSpecial = specialCharacters.find(transGt[0]) > -1
            endsSpecial = specialCharacters.find(transGt[-1]) > -1

            if startsSpecial and transGt[1:] == transDet:
                return True

            if endsSpecial and transGt[:-1] == transDet:
                return True

            if startsSpecial and endsSpecial and transGt[1:-1] == transDet:
                return True

            return False
        else:
            #Special characters are removed from the beginning and the end of both Detection and GroundTruth
            while len(transGt) > 0 and specialCharacters.find(transGt[0]) > -1:
                transGt = transGt[1:]

            while len(transDet) > 0 and specialCharacters.find(transDet[0]) > -1:
                transDet = transDet[1:]

            while len(transGt) > 0 and specialCharacters.find(transGt[-1]) > -1:
                transGt = transGt[0:len(transGt)-1]

            while len(transDet) > 0 and specialCharacters.find(transDet[-1]) > -1:
                transDet = transDet[0:len(transDet)-1]

            return transGt == transDet

    def include_in_dictionary_transcription(transcription):
        """
        Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
        """
        #special case 's at final
        if transcription[len(transcription)-2:] == "'s" or transcription[len(transcription)-2:] == "'S":
            transcription = transcription[0:len(transcription)-2]

        #hyphens at init or final of the word
        transcription = transcription.strip('-')

        specialCharacters = "'!?.:,*\"()·[]/"
        for character in specialCharacters:
            transcription = transcription.replace(character, ' ')

        transcription = transcription.strip()

        return transcription

    def include_in_dictionary(transcription):
        """
        Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care
        """
        # Same clean-up that is applied to the transcriptions that are kept
        transcription = include_in_dictionary_transcription(transcription)

        # A special character in the middle of the word has been turned into a space
        if len(transcription) != len(transcription.replace(" ", "")):
            return False

        if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
            return False

        notAllowed = "×÷·"

        # Inclusive code point ranges of the accepted characters. They are written as
        # numbers so that the table does not depend on how this file is encoded.
        allowedRanges = (
            (0x0061, 0x007A),  # a - z
            (0x0041, 0x005A),  # A - Z
            (0x00C0, 0x01BF),
            (0x01C4, 0x027F),
            (0x0386, 0x03FF),
            (0x002D, 0x002D),  # hyphen
        )

        for char in transcription:
            if notAllowed.find(char) != -1:
                return False

            charCode = ord(char)
            if not any(low <= charCode <= high for low, high in allowedRanges):
                return False

        return True

    perSampleMetrics = {}

    matchedSum = 0

    Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

    gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
    subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)

    numGlobalCareGt = 0
    numGlobalCareDet = 0

    arrGlobalConfidences = []
    arrGlobalMatches = []

    for resFile in gt:

        gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
        if gtFile is None:
            raise Exception("The file %s is not UTF-8" %resFile)

        recall = 0
        precision = 0
        hmean = 0
        detCorrect = 0
        # Placeholder reported when the sample has no GT or no detections
        # (zeros, so that no uninitialised memory ends up in the results)
        iouMat = np.zeros([1, 1])
        gtPols = []
        detPols = []
        gtTrans = []
        detTrans = []
        gtPolPoints = []
        detPolPoints = []
        gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
        detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
        detMatchedNums = []
        pairs = []

        arrSampleConfidences = []
        arrSampleMatch = []
        sampleAP = 0

        evaluationLog = ""

        pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, False)
        for n in range(len(pointsList)):
            points = pointsList[n]
            transcription = transcriptionsList[n]
            dontCare = transcription == "###"
            if evaluationParams['LTRB']:
                gtRect = Rectangle(*points)
                gtPol = rectangle_to_polygon(gtRect)
            else:
                gtPol = polygon_from_points(points)
            gtPols.append(gtPol)
            gtPolPoints.append(points)

            #On word spotting we will filter some transcriptions with special characters
            if evaluationParams['WORD_SPOTTING'] and not dontCare:
                if include_in_dictionary(transcription):
                    transcription = include_in_dictionary_transcription(transcription)
                else:
                    dontCare = True

            gtTrans.append(transcription)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols)-1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n")

        if resFile in subm:

            detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])

            pointsList, confidencesList, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, evaluationParams['CONFIDENCES'])

            for n in range(len(pointsList)):
                points = pointsList[n]
                transcription = transcriptionsList[n]

                if evaluationParams['LTRB']:
                    detRect = Rectangle(*points)
                    detPol = rectangle_to_polygon(detRect)
                else:
                    detPol = polygon_from_points(points)
                detPols.append(detPol)
                detPolPoints.append(points)
                detTrans.append(transcription)

                # A detection lying mostly inside a don't care GT polygon becomes don't care too
                for dontCareNum in gtDontCarePolsNum:
                    dontCarePol = gtPols[dontCareNum]
                    intersected_area = get_intersection(dontCarePol, detPol)
                    pdDimensions = detPol.area()
                    areaPrecision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                    if areaPrecision > evaluationParams['AREA_PRECISION_CONSTRAINT']:
                        detDontCarePolsNum.append(len(detPols)-1)
                        break

            evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n")

            # Both don't care lists are complete here; sets make the membership tests below cheap
            gtDontCareSet = set(gtDontCarePolsNum)
            detDontCareSet = set(detDontCarePolsNum)

            if len(gtPols) > 0 and len(detPols) > 0:
                #Calculate IoU and precision matrixs
                outputShape = [len(gtPols), len(detPols)]
                iouMat = np.empty(outputShape)
                gtRectMat = np.zeros(len(gtPols), np.int8)
                detRectMat = np.zeros(len(detPols), np.int8)
                for gtNum in range(len(gtPols)):
                    for detNum in range(len(detPols)):
                        pG = gtPols[gtNum]
                        pD = detPols[detNum]
                        iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)

                # Each GT and each detection takes part in at most one pair:
                # the first one, in index order, whose IoU is above the threshold
                for gtNum in range(len(gtPols)):
                    for detNum in range(len(detPols)):
                        if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCareSet and detNum not in detDontCareSet:
                            if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
                                gtRectMat[gtNum] = 1
                                detRectMat[detNum] = 1
                                #detection matched only if transcription is equal
                                if evaluationParams['WORD_SPOTTING']:
                                    correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
                                else:
                                    correct = transcription_match(gtTrans[gtNum].upper(), detTrans[detNum].upper(), evaluationParams['SPECIAL_CHARACTERS'], evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER']) == True
                                detCorrect += (1 if correct else 0)
                                if correct:
                                    detMatchedNums.append(detNum)
                                pairs.append({'gt':gtNum,'det':detNum,'correct':correct})
                                evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"

            if evaluationParams['CONFIDENCES']:
                detMatchedSet = set(detMatchedNums)
                for detNum in range(len(detPols)):
                    if detNum not in detDontCareSet:
                        #we exclude the don't care detections
                        match = detNum in detMatchedSet

                        arrSampleConfidences.append(confidencesList[detNum])
                        arrSampleMatch.append(match)

                        arrGlobalConfidences.append(confidencesList[detNum])
                        arrGlobalMatches.append(match)

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
            sampleAP = precision
        else:
            recall = float(detCorrect) / numGtCare
            precision = 0 if numDetCare == 0 else float(detCorrect) / numDetCare
            if evaluationParams['CONFIDENCES']:
                sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)

        hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)

        matchedSum += detCorrect
        numGlobalCareGt += numGtCare
        numGlobalCareDet += numDetCare

        perSampleMetrics[resFile] = {
            'precision':precision,
            'recall':recall,
            'hmean':hmean,
            'pairs':pairs,
            'AP':sampleAP,
            # The matrix is left out when the sample has more than 100 detections
            'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
            'gtPolPoints':gtPolPoints,
            'detPolPoints':detPolPoints,
            'gtTrans':gtTrans,
            'detTrans':detTrans,
            'gtDontCare':gtDontCarePolsNum,
            'detDontCare':detDontCarePolsNum,
            'evaluationParams': evaluationParams,
            'evaluationLog': evaluationLog
        }

    # Compute AP
    AP = 0
    if evaluationParams['CONFIDENCES']:
        AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)

    methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
    methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
    methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)

    methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP }

    resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics}

    return resDict
if __name__=='__main__':
    # Settings of the run:
    #   results_dir: result directory
    #   score_det: score of detection bounding box
    #   score_rec: score of the mask recognition branch
    #   score_rec_seq: score of the sequence recognition branch
    #   lexicon_type: 1 for generic; 2 for weak; 3 for strong
    results_dir = '../../../output/mixtrain/inference/total_text_test/model_0250000_1000_results/'
    score_det = 0.05
    score_rec = 0.5

    # Alternative pair of settings kept from the original script:
    #   use_lexicon = True with score_rec_seq = 0.8
    use_lexicon = False
    score_rec_seq = 0.9

    # Turn the raw results into the file that is handed to the evaluation below
    evaluate_result_path = prepare_results_for_evaluation(
        results_dir,
        use_lexicon=use_lexicon,
        cache_dir='./cache_files',
        score_det=score_det,
        score_rec=score_rec,
        score_rec_seq=score_rec_seq)

    # 'g', 'o' and 's' are the keys given to main_evaluation together with the callbacks
    p = dict(g="../gt.zip", o="./cache_files", s=evaluate_result_path)

    rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method)