"""Summarize per-file BLEU scores from JSONL generation results.

For every *.jsonl file in a directory, report the mean BLEU score, the number
of lines whose generation/reference length ratio is out of range, and the
number of lines flagged as degenerate repetition.
"""

import argparse
import json
import os
import re
import statistics
from collections import Counter, defaultdict


def is_duplicated(text, top_k=10, min_word_len=0):
    """Heuristically flag degenerate, repetitive text.

    Returns the offending (word, count) pairs when the text looks like a few
    words repeated many times, and False otherwise.
    """
    words = re.findall(r'\b\w+\b', text)
    word_freq = Counter(words)

    # Optionally drop words at or below the minimum word length.
    if min_word_len > 0:
        for word, count in list(word_freq.items()):
            if len(word) <= min_word_len:
                del word_freq[word]

    if len(word_freq) == 0:
        return False

    # A single distinct word repeated more than five times is degenerate.
    if len(word_freq) == 1 and word_freq.most_common(1)[0][1] > 5:
        return word_freq.most_common(1)

    top_items = word_freq.most_common(top_k)
    frequencies = [frequency for item, frequency in top_items]
    mean_frequency = sum(frequencies) / len(frequencies)

    # If even the most frequent words are rare, the text is not repetitive.
    if mean_frequency < 5:
        return False

    # Look for a sharp drop in frequency within the top-k words: a handful of
    # words that dominate the rest of the distribution indicates repetition.
    prev_frequency = 0
    index = 0
    for item, frequency in top_items:
        if (prev_frequency - frequency) > mean_frequency:
            if index <= 1:
                return False
            return top_items
        prev_frequency = frequency
        index += 1
    return False


def is_length_exceed(reference, generation, min_ratio=0.2, max_ratio=2):
    """Return True when the generation/reference length ratio is out of range."""
    return not min_ratio <= (len(generation) / len(reference)) <= max_ratio


def get_average(a):
    """Round the mean of a list to two decimals; pass scalars through as-is."""
    if isinstance(a, list):
        return round(sum(a) / len(a), 2)
    return a


def main():
    parser = argparse.ArgumentParser(
        description="Summarize BLEU, length-ratio, and duplication statistics "
                    "for JSONL generation results.")
    parser.add_argument(
        "directory",
        type=str,
        help="directory containing .jsonl result files",
    )
    parser.add_argument('--detail', action='store_true',
                        help='print per-line details for out-of-range lengths '
                             'and duplicated generations')
    args = parser.parse_args()

    # Per-file accumulators: BLEU scores per source line, length ratios, and
    # duplication flags plus their details.
    file_src_bleu_scores = defaultdict(list)
    file_length_ratio = defaultdict(list)
    file_duplicated = defaultdict(list)
    file_duplicated_detail = defaultdict(list)

    # Iterate over every file in the directory, processing only JSONL files.
    for filename in os.listdir(args.directory):
        if filename.endswith('.jsonl'):
            file_path = os.path.join(args.directory, filename)
            with open(file_path, 'r', encoding='utf-8') as file:
                for index, line in enumerate(file):
                    data = json.loads(line)
                    src = data['src']  # source sentence (not used in the aggregates)
                    bleu_score = data['bleu']
                    file_src_bleu_scores[filename].append(bleu_score)

                    # Check the generation/reference length ratio.
                    reference_length = len(data['reference'])
                    generation_length = len(data['generation'])
                    file_length_ratio[filename].append(
                        round(generation_length / reference_length, 1))

                    # Check for degenerate repetition in the generation.
                    word_count = is_duplicated(data['generation'])
                    file_duplicated[filename].append(0 if word_count is False else 1)
                    if word_count is not False:
                        file_duplicated_detail[filename].append(
                            {'index': index,
                             'count': word_count,
                             'generation': data['generation']})

    # Report files in ascending order of mean BLEU.
    sorted_items = sorted(file_src_bleu_scores.items(),
                          key=lambda x: statistics.mean(x[1]))

    print('bleu scores')
    for filename, src_bleu_scores in sorted_items:
        avg_bleu = sum(src_bleu_scores) / len(src_bleu_scores)

        # Collect lines whose (rounded) length ratio falls outside [0.2, 2.0],
        # mirroring the default bounds of is_length_exceed.
        length_ratio = []
        cur_length_ratio = file_length_ratio[filename]
        for index, ratio in enumerate(cur_length_ratio):
            if ratio < 0.2 or ratio > 2.0:
                length_ratio.append((index, ratio))

        print(f"{filename}: {avg_bleu:.2f}, "
              f"out_of_range_count={len(length_ratio)}, "
              f"duplicate={sum(file_duplicated[filename])}")
        if args.detail:
            print(f'\t error length:{length_ratio}')
            print('\t duplication')
            for info in file_duplicated_detail[filename]:
                print('\t\t', info)


if __name__ == "__main__":
    main()
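
# ---------------------------------------------------------------------------
# Usage sketch. The script and directory names below are illustrative; they
# are not fixed anywhere in this file:
#
#   python analyze_bleu.py results/ --detail
#
# Every line of each *.jsonl file in the directory must be a JSON object with
# at least the fields read in main():
#
#   {"src": "...", "bleu": 23.4, "reference": "...", "generation": "..."}
#
# Illustrative behavior of the duplication heuristic:
#
#   is_duplicated("spam spam spam spam spam spam")
#   # -> [('spam', 6)]   a single word repeated more than 5 times is flagged
#   is_duplicated("the quick brown fox jumps over the lazy dog")
#   # -> False           top-word mean frequency is below 5, so not flagged
# ---------------------------------------------------------------------------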