dongjun-Lee/text-classification-models-tf
Tensorflow implementations of Text Classification Models.
repo name | dongjun-Lee/text-classification-models-tf |
repo link | https://github.com/dongjun-Lee/text-classification-models-tf |
language | Python |
size (curr.) | 11 kB |
stars (curr.) | 459 |
created | 2018-07-15 |
Text Classification Models with TensorFlow
TensorFlow implementations of several text classification models.
Implemented Models:
- Word-level CNN [paper]
- Character-level CNN [paper]
- Very Deep CNN [paper]
- Word-level Bidirectional RNN
- Attention-Based Bidirectional RNN [paper]
- RCNN [paper]
Semi-supervised text classification (transfer learning) models are implemented in [dongjun-Lee/transfer-learning-text-tf].
Requirements
- Python 3
- TensorFlow
- Other dependencies: pip install -r requirements.txt
Usage
Train
To train a classification model on the DBpedia dataset:
$ python train.py --model="<MODEL>"
(<MODEL>: word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn)
Test
To measure classification accuracy on the test data after training:
$ python test.py --model="<TRAINED_MODEL>"
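To train and then test every model in one go, a small driver along these lines could be used. It is a hypothetical helper, not part of the repository: it assumes it is run from the repository root and that `test.py` accepts the same `--model` names as `train.py`. By default it only collects the commands (`dry_run=True`); pass `dry_run=False` to actually execute them.

```python
import subprocess
import sys

# Model names accepted by --model, as listed in this README.
MODELS = ["word_cnn", "char_cnn", "vd_cnn", "word_rnn", "att_rnn", "rcnn"]


def build_commands(model):
    """Return the train and test command lines for one model name."""
    if model not in MODELS:
        raise ValueError(f"unknown model {model!r}; expected one of {MODELS}")
    flag = f"--model={model}"
    return [
        [sys.executable, "train.py", flag],
        [sys.executable, "test.py", flag],  # assumption: same name as used for training
    ]


def run_all(models=MODELS, dry_run=True):
    """Train then test each model; with dry_run=True only collect the commands."""
    executed = []
    for model in models:
        for cmd in build_commands(model):
            executed.append(cmd)
            if not dry_run:
                # check=True stops the sweep as soon as one step fails.
                subprocess.run(cmd, check=True)
    return executed


if __name__ == "__main__":
    for cmd in run_all(dry_run=True):
        print(" ".join(cmd))
```

Passing the arguments as a list avoids shell quoting, so `--model=word_cnn` here is equivalent to `--model="word_cnn"` typed at the prompt.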
Sample Test Results
Trained and tested on the DBpedia dataset (dbpedia_csv/train.csv, dbpedia_csv/test.csv).
Model | WordCNN | CharCNN | VDCNN | WordRNN | AttentionRNN | RCNN | *SA-LSTM | *LM-LSTM |
---|---|---|---|---|---|---|---|---|
Accuracy | 98.42% | 98.05% | 97.60% | 98.57% | 98.61% | 98.68% | 98.88% | 98.86% |
(*SA-LSTM and *LM-LSTM are semi-supervised models implemented in [dongjun-Lee/transfer-learning-text-tf].)
Models
1. Word-level CNN
Implementation of Convolutional Neural Networks for Sentence Classification.
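As a rough illustration of the idea in the paper, the NumPy sketch below runs a single forward pass: embed the words, apply convolutions of widths 3, 4 and 5 with ReLU, max-pool each feature map over time, concatenate, and classify with a softmax layer. It is not the repository's TensorFlow code: weights are random, dropout is omitted, and the names (`word_cnn_forward`, the `filters` dict layout) are hypothetical. The sentence must be at least as long as the widest filter.

```python
import numpy as np


def word_cnn_forward(token_ids, emb, filters, w_out, b_out):
    """Forward pass of a sentence-classification CNN for one sentence.

    token_ids: (T,) word indices; emb: (V, D) embedding matrix
    filters:   {width: (W, b)} with W of shape (width, D, F) and b of shape (F,)
    w_out:     (num_widths * F, C) output weights; b_out: (C,)
    """
    x = emb[token_ids]  # embedded sentence, (T, D)
    pooled = []
    for width, (W, b) in sorted(filters.items()):
        # Every window of `width` consecutive words yields F feature values.
        windows = np.stack([x[i:i + width] for i in range(x.shape[0] - width + 1)])
        feats = np.maximum(np.einsum("nwd,wdf->nf", windows, W) + b, 0.0)  # ReLU
        pooled.append(feats.max(axis=0))  # max-over-time pooling, (F,)
    h = np.concatenate(pooled)  # (num_widths * F,)
    logits = h @ w_out + b_out
    e = np.exp(logits - logits.max())  # numerically stable softmax
    return e / e.sum()


rng = np.random.default_rng(0)
V, D, F, C = 50, 8, 4, 14  # DBpedia has 14 classes
widths = (3, 4, 5)
emb = rng.normal(size=(V, D))
filters = {w: (rng.normal(size=(w, D, F)) * 0.1, np.zeros(F)) for w in widths}
w_out = rng.normal(size=(len(widths) * F, C)) * 0.1
probs = word_cnn_forward(rng.integers(0, V, size=12), emb, filters, w_out, np.zeros(C))
print(probs.shape)  # (14,)
```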
2. Char-level CNN
Implementation of Character-level Convolutional Networks for Text Classification.
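The character-level model reads one-hot encoded characters instead of words. The NumPy sketch below shows that quantization step (the paper fixes the input length at 1014 characters) followed by one temporal convolution with ReLU and non-overlapping max pooling of size 3. It is an illustration under assumptions, not the repository's code: the alphabet is modelled on, but not identical to, the paper's 70-character alphabet, unknown characters and padding map to all-zero vectors, and the function names are hypothetical.

```python
import numpy as np

# Hypothetical alphabet modelled on the paper: lowercase letters, digits, punctuation.
ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+=<>()[]{}"
CHAR_INDEX = {c: i for i, c in enumerate(ALPHABET)}


def quantize(text, max_len=1014):
    """One-hot encode `text` into a (max_len, len(ALPHABET)) matrix."""
    out = np.zeros((max_len, len(ALPHABET)), dtype=np.float32)
    for pos, ch in enumerate(text.lower()[:max_len]):
        idx = CHAR_INDEX.get(ch)
        if idx is not None:  # characters outside the alphabet stay all-zero
            out[pos, idx] = 1.0
    return out


def temporal_conv_pool(x, W, b, pool=3):
    """Valid temporal convolution + ReLU + non-overlapping max pooling.

    x: (T, C_in); W: (k, C_in, C_out); b: (C_out,)
    """
    k = W.shape[0]
    windows = np.stack([x[i:i + k] for i in range(x.shape[0] - k + 1)])
    h = np.maximum(np.einsum("tkc,kco->to", windows, W) + b, 0.0)
    T = (h.shape[0] // pool) * pool  # drop the ragged tail
    return h[:T].reshape(-1, pool, h.shape[1]).max(axis=1)


x = quantize("DBpedia is a crowd-sourced knowledge base.")
rng = np.random.default_rng(0)
W = rng.normal(size=(7, len(ALPHABET), 16)) * 0.1
h = temporal_conv_pool(x, W, np.zeros(16))
print(x.shape, h.shape)  # (1014, 68) (336, 16)
```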
3. Very Deep CNN (VDCNN)
Implementation of Very Deep Convolutional Networks for Text Classification.
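VDCNN stacks many small convolutional blocks (kernel size 3, each convolution followed by normalisation and ReLU) and ends with k-max pooling (k = 8 in the paper), which keeps the k strongest activations of each channel in their original order. The NumPy sketch below shows one such block and the k-max pooling step. It is a simplified illustration, not the repository's code: `norm_over_time` standardises each channel over the time axis of a single example instead of using true batch statistics and learned scale/shift, and all names are hypothetical.

```python
import numpy as np


def conv1d_same(x, W, b):
    """Temporal convolution, zero-padded so the sequence length is preserved.

    x: (T, C_in); W: (k, C_in, C_out) with odd k; b: (C_out,)
    """
    k = W.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (0, 0)))  # zero padding on the time axis only
    windows = np.stack([xp[i:i + k] for i in range(x.shape[0])])  # (T, k, C_in)
    return np.einsum("tkc,kco->to", windows, W) + b


def norm_over_time(h, eps=1e-5):
    """Simplified stand-in for batch norm: per-channel standardisation over time."""
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)


def conv_block(x, W1, b1, W2, b2):
    """Two rounds of conv(k=3) -> normalisation -> ReLU."""
    h = np.maximum(norm_over_time(conv1d_same(x, W1, b1)), 0.0)
    return np.maximum(norm_over_time(conv1d_same(h, W2, b2)), 0.0)


def k_max_pool(h, k=8):
    """Keep the k largest values of each channel, in their original temporal order."""
    top = np.sort(np.argsort(h, axis=0)[-k:], axis=0)  # (k, C) time indices per channel
    return np.take_along_axis(h, top, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 16))  # 20 time steps, 16 input channels
h = conv_block(x, rng.normal(size=(3, 16, 32)), np.zeros(32),
               rng.normal(size=(3, 32, 32)), np.zeros(32))
print(h.shape, k_max_pool(h).shape)  # (20, 32) (8, 32)
```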
4. Word-level Bi-RNN
Bi-directional RNN for Text Classification.
- Embedding layer
- Bidirectional RNN layer
- Concatenate the outputs of the RNN layer at all time steps
- Fully-connected layer
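The four steps listed above can be sketched in NumPy as a single forward pass. This is an illustration, not the repository's TensorFlow code: a plain tanh RNN cell stands in for a gated (LSTM/GRU) cell, weights are random, and all names are hypothetical. Because every time step's output is concatenated before the fully-connected layer, the sketch assumes inputs padded or truncated to a fixed length T.

```python
import numpy as np


def rnn_pass(x, Wx, Wh, b):
    """Run a simple tanh RNN over x (T, D); return all hidden states (T, H)."""
    h = np.zeros(Wh.shape[0])
    states = []
    for x_t in x:
        h = np.tanh(x_t @ Wx + h @ Wh + b)
        states.append(h)
    return np.stack(states)


def bi_rnn_classifier(token_ids, emb, fwd, bwd, w_out, b_out):
    """Embedding -> bidirectional RNN -> concat all outputs -> fully-connected layer.

    fwd, bwd: (Wx, Wh, b) parameter tuples for the two directions.
    w_out: (T * 2H, C), since the output of every time step is kept.
    """
    x = emb[token_ids]                            # embedding layer, (T, D)
    h_f = rnn_pass(x, *fwd)                       # left-to-right, (T, H)
    h_b = rnn_pass(x[::-1], *bwd)[::-1]           # right-to-left, re-aligned to time order
    outputs = np.concatenate([h_f, h_b], axis=1)  # (T, 2H)
    flat = outputs.reshape(-1)                    # concat over all time steps, (T * 2H,)
    return flat @ w_out + b_out                   # class logits, (C,)


rng = np.random.default_rng(0)
V, D, H, C, T = 50, 8, 6, 14, 10


def rnn_params():
    return (rng.normal(size=(D, H)) * 0.1, rng.normal(size=(H, H)) * 0.1, np.zeros(H))


logits = bi_rnn_classifier(rng.integers(0, V, size=T), rng.normal(size=(V, D)),
                           rnn_params(), rnn_params(),
                           rng.normal(size=(T * 2 * H, C)) * 0.1, np.zeros(C))
print(logits.shape)  # (14,)
```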
5. Attention-Based Bi-RNN
Implementation of Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification.
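In the paper, the Bi-LSTM outputs are pooled into one sentence vector by a learned attention vector w: M = tanh(H), α = softmax(wᵀM), r = Hαᵀ, h* = tanh(r). The NumPy sketch below writes those four equations with one row of H per time step (the transpose of the paper's layout). It is illustrative only: the function names are hypothetical and H would normally come from the trained bidirectional RNN rather than random numbers.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def attention_pool(H, w):
    """Attention pooling over Bi-RNN outputs.

    H: (T, d) outputs, one row per time step; w: (d,) trained attention vector.
    Returns the sentence vector h* (d,) and the attention weights alpha (T,).
    """
    M = np.tanh(H)            # M = tanh(H)
    alpha = softmax(M @ w)    # alpha = softmax(w^T M), one weight per time step
    r = alpha @ H             # r = H alpha^T, weighted sum of the outputs
    return np.tanh(r), alpha  # h* = tanh(r)


rng = np.random.default_rng(0)
H = rng.normal(size=(7, 5))  # 7 time steps, 5-dimensional outputs
h_star, alpha = attention_pool(H, rng.normal(size=5))
print(h_star.shape, alpha.shape)  # (5,) (7,)
```

With w = 0 every step gets the same weight, so h* reduces to tanh of the mean output; training w lets the model focus on the informative words.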
6. RCNN
Implementation of Recurrent Convolutional Neural Networks for Text Classification.
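The RCNN represents each word by its embedding plus a left and a right context vector computed recurrently, applies a tanh projection, max-pools over words, and classifies with softmax. The NumPy sketch below follows those equations for one document using a row-vector convention (so `c @ W` plays the role of the paper's `W c`). It is a sketch under assumptions, not the repository's code: the initial contexts are zeros rather than learned vectors, tanh is used as the context activation f, and all names are hypothetical.

```python
import numpy as np


def rcnn_forward(E, Wl, Wsl, Wr, Wsr, W2, b2, W4, b4):
    """Recurrent convolutional forward pass for one document.

    E: (T, De) word embeddings e(w_i); context size Dc.
    Wl, Wr: (Dc, Dc); Wsl, Wsr: (De, Dc); W2: (2*Dc + De, Hd); W4: (Hd, C).
    """
    T = E.shape[0]
    Dc = Wl.shape[0]
    cl = np.zeros((T, Dc))
    cr = np.zeros((T, Dc))
    # Left context: c_l(w_i) = f(c_l(w_{i-1}) W_l + e(w_{i-1}) W_sl); c_l(w_1) = 0 here.
    for i in range(1, T):
        cl[i] = np.tanh(cl[i - 1] @ Wl + E[i - 1] @ Wsl)
    # Right context, scanned from the end: c_r(w_i) = f(c_r(w_{i+1}) W_r + e(w_{i+1}) W_sr).
    for i in range(T - 2, -1, -1):
        cr[i] = np.tanh(cr[i + 1] @ Wr + E[i + 1] @ Wsr)
    X = np.concatenate([cl, E, cr], axis=1)  # x_i = [c_l; e; c_r], (T, 2*Dc + De)
    Y2 = np.tanh(X @ W2 + b2)                # latent semantic vectors, (T, Hd)
    y3 = Y2.max(axis=0)                      # element-wise max pooling over words, (Hd,)
    logits = y3 @ W4 + b4
    e = np.exp(logits - logits.max())
    return e / e.sum()                       # class probabilities, (C,)


rng = np.random.default_rng(0)
T, De, Dc, Hd, C = 9, 8, 5, 12, 14


def m(*shape):
    return rng.normal(size=shape) * 0.1


probs = rcnn_forward(m(T, De), m(Dc, Dc), m(De, Dc), m(Dc, Dc), m(De, Dc),
                     m(2 * Dc + De, Hd), np.zeros(Hd), m(Hd, C), np.zeros(C))
print(probs.shape)  # (14,)
```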