Text classification framework, 利用pytorch实现高效的数据加载模块。
文本分类框架,可以完成:
另外,该框架还利用pytorch实现了高效的数据加载模块,包括以下特性:
训练文件的预处理包括:
dict
格式存放在pkl
文件中;np.array
的格式存放在pkl
文件中;运行方式:
$ python3 preprocessing.py -l --pd ./data/train.txt --ri ./data/train_idx/ --rv ./res/voc/ --re ./res/embed/ --pe ./path_to_embed_file
运行方式:
$ python3 preprocessing.py --pd ./data/test.txt --ri ./data/test_idx/
表. 参数说明
参数 | 类型 | 默认值 | 说明 |
---|---|---|---|
l | bool | False | label,是否带有标签(标志是否是训练集) |
pd | str | ./data/train.txt | path_data,训练(测试)数据路径 |
ri | str | ./data/train_idx/ | root_idx,训练数据索引文件根目录 |
rv | str | ./res/voc/ | root_voc,词表、label表根目录 |
re | str | ./res/embed/ | root_embed,embed文件根目录 |
pe | str | None | path_embed,预训练的embed文件路径,bin 或txt ;若不提供,则随机初始化 |
pt | int | 98 | percentile,构建词表时的百分位值 |
运行python3 preprocessing.py -h
可打印出帮助信息。
若预处理时root_idx
等参数使用的是默认值,则在训练时不需要设定相应参数。
运行方式:
$ CUDA_VISIBLE_DEVICES=0,1 python3 train.py --nc 2 --ml 40 --fs 3,4,5 --fn 400,300,200 --wd 64 --bs 256 -g
参数说明
参数 | 类型 | 默认值 | 说明 |
---|---|---|---|
ri | str | ./data/train_idx/ | root_idx,训练数据索引文件根目录 |
rv | str | ./res/voc/ | root_voc,词表、label表根目录 |
re | str | ./res/embed/ | root_embed,embed文件根目录 |
ml | int | 50 | max_len,句子最大长度 |
ds | float | 0.2 | dev_size,开发集占比 |
nc | int | 无 | nb_classes,分类类别数量 |
wd | int | 50 | word_dim,词向量维度 |
fs | str | 2,3,4 | filter_size,卷积核尺寸 |
fn | str | 256,256,256 | filter_num,卷积核大小 |
dp | float | 0.5 | dropout_rate,dropout rate |
ne | int | 100 | nb_epoch,迭代次数 |
mp | int | 5 | max_patience,最大耐心值,即开发集上性能超过mp次没有提示,则终止训练 |
rm | str | ./model/ | root_model,模型根目录 |
bs | int | 64 | batch_size,batch size |
g | bool | False | 是否使用GPU加速 |
nw | int | 4 | num_worker,加载数据时的线程数 |
运行python3 train.py -h
可打印出帮助信息。
运行方式:
$ CUDA_VISIBLE_DEVICES=0,1 python3 test.py --bs 256 -g --pr ./result.txt
参数 | 类型 | 默认值 | 说明 |
---|---|---|---|
ri | str | ./data/train_idx/ | root_idx,训练数据索引文件根目录 |
rv | str | ./res/voc/ | root_voc,词表、label表根目录 |
re | str | ./res/embed/ | root_embed,embed文件根目录 |
ml | int | 50 | max_len,句子最大长度 |
pm | str | 无 | path_model,模型路径 |
bs | int | 64 | batch_size,batch size |
g | bool | False | 是否使用GPU加速 |
nw | int | 4 | num_worker,加载数据时的线程数 |
pr | str | ./result.txt | 预测结果存放路径 |
运行python3 test.py -h
可打印出帮助信息。