|
--- |
|
license: apache-2.0 |
|
metrics: |
|
- accuracy |
|
pipeline_tag: image-classification |
|
--- |
|
# Siamese CNN |
|
复现[Argot: Generating Adversarial Readable Chinese Texts IJCAI 2020](https://www.ijcai.org/Proceedings/2020/351) 字形变换的相似结构汉字筛选 |
|
## 介绍 |
|
|
|
基于CNN架构,采用孪生网络训练方式,对输入的汉字对进行编码并计算其欧式距离作为汉字字形相似度度量 |
|
|
|
![image/png](https://cdn-uploads.huggingface.co/production/uploads/637cd8039a5217b88b72b71c/e5w-gARTB_WgV3Ixg0K-Y.png) |
|
|
|
## 架构 |
|
三层Conv2D 大小(Input_channel, output_channel, filter_size)= (3,64,8),(64,128,8),(128,128,8) |
|
每层卷积层后添加MaxPool(2) |
|
lr = 0.002 |
|
## 数据集 |
|
汉字来源:https://github.com/zzboy/chinese |
|
采用pygame生成图片数据,默认采用黑体字体,图片大小为200*200 |
|
上述汉字每行作为相似字符,按照7:3划分数据集 |
|
并参考https://github.com/avilash/pytorch-siamese-triplet 生成三元组训练数据、测试数据,实际训练、测试时采用50000对、10000对三元组数据 |
|
## 评估 |
|
loss = MarginRankingLoss(margin=1) |
|
| 0% of margin | 20% of margin | 50% of margin | loss |epoch| |
|
| :--------------- | :---------------------- | :---|:---|:-----| |
|
|0.9012 |0.7998 | 0.5700| 0.4674| 10| |
|
|
|
0% of margin 相当于准确率 |
|
## 使用 |
|
采用Pytorch加载,一般用加载一个CNN模型就可以使用,注意删除state_dict中的字典名字 |
|
``` |
|
model_dict = torch.load('./checkpoint.pth')['state_dict'] |
|
model_dict_mod = {} |
|
for key, value in model_dict.items(): |
|
new_key = '.'.join(key.split('.')[1:]) |
|
model_dict_mod[new_key] = value |
|
self.model.load_state_dict(model_dict_mod) |
|
``` |
|
## 文件介绍 |
|
prepare_data.py生成数据集,将汉字转换为图片,默认黑体字体,也可以用别的,从C://Windows/Fonts Windows系统上下载 |
|
|
|
|