
Token and Sentence Level Classification with Google's BERT (TensorFlow)


26hzhang/bert_classification


BERT Classification

Use Google's BERT (TensorFlow-based) for token-level and sentence-level classification.

Requirements

  • tensorflow>=1.11.0 (or tensorflow-gpu>=1.11.0)
  • numpy>=1.14.4
  • the official TensorFlow-based BERT code: clone https://github.com/google-research/bert.git and place it under this repository (as bert/).
  • pre-trained BERT models (chosen according to the task): download them and place them in the bert_ckpt/ directory.

Expected directory layout:
bert_classification/
    |____ bert/
    |____ bert_ckpt/
    |____ checkpoint/
    |____ datasets/
    |____ .gitignore
    |____ conlleval.pl
    |____ data_cls_helper.py
    |____ data_seq_helper.py
    |____ README.md
    |____ run_sequence_tagger.py
    |____ run_text_classifier.py

Dataset Overview

Token-level classification datasets (chunking and NER):

| Dataset | Language | Classes | Training tokens | Dev tokens | Test tokens |
|---|---|---|---|---|---|
| CoNLL2000 Chunk (en) | English | 23 | 211,727 | N.A. | 47,377 |
| CoNLL2002 NER (es) | Spanish | 9 | 207,484 (18,797) | 51,645 (4,351) | 52,098 (3,558) |
| CoNLL2002 NER (nl) | Dutch | 9 | 202,931 (13,344) | 37,761 (2,616) | 68,994 (3,941) |
| CoNLL2003 NER (en) | English | 9 | 204,567 (23,499) | 51,578 (5,942) | 46,666 (5,648) |
| CoNLL2003 NER (de 1) | German | 9 | 208,836 (16,839) | 51,444 (6,588) | 51,943 (5,171) |
| GermEval2014 NER (de 2) | German | 25 | 452,853 (42,089) | 41,653 (3,960) | 96,499 (8,969) |
| Chinese NER 1 (zh 1) | Chinese | 21 | 1,044,967 (311,637) | 86,454 (24,444) | 119,467 (38,854) |
| Chinese NER 2 (zh 2) | Chinese | 7 | 979,180 (110,093) | 109,870 (12,059) | 219,197 (25,012) |

All lines in these datasets are converted to (word, label) pairs separated by a tab (\t); lines starting with -DOCSTART- and other unwanted lines are dropped. Labels use the BIO2 scheme (Begin, Inside, Outside).
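The format above can be read back into sentences with a few lines of Python. This is a minimal sketch, not the code in data_seq_helper.py: it assumes blank lines separate sentences (the README does not say so explicitly), and the helper names `read_conll_pairs` and `iob1_to_bio2` are hypothetical. The second helper illustrates what "BIO2" means for corpora distributed with IOB1 tags (the original CoNLL-2003 files, for instance), where an entity may begin with an I- tag.

```python
from typing import List, Tuple

Sentence = List[Tuple[str, str]]


def read_conll_pairs(lines: List[str]) -> List[Sentence]:
    """Group tab-separated (word, label) lines into sentences.

    Assumption: blank lines separate sentences; -DOCSTART- lines are skipped.
    """
    sentences: List[Sentence] = []
    current: Sentence = []
    for raw in lines:
        line = raw.rstrip("\n")
        if line.startswith("-DOCSTART-"):
            continue
        if not line.strip():  # sentence boundary
            if current:
                sentences.append(current)
                current = []
            continue
        word, label = line.split("\t")
        current.append((word, label))
    if current:  # file may not end with a blank line
        sentences.append(current)
    return sentences


def iob1_to_bio2(labels: List[str]) -> List[str]:
    """Convert IOB1 tags to BIO2, where every entity starts with a B- tag."""
    out: List[str] = []
    prev = "O"
    for tag in labels:
        # An I- tag that does not continue an entity of the same type starts a new one.
        if tag.startswith("I-") and (prev == "O" or prev[2:] != tag[2:]):
            tag = "B-" + tag[2:]
        out.append(tag)
        prev = tag
    return out


print(iob1_to_bio2(["I-ORG", "O", "I-MISC", "I-MISC"]))  # ['B-ORG', 'O', 'B-MISC', 'I-MISC']
```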

Sentence-level classification datasets:

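| Dataset | Classes | Average sentence length | Train size | Dev size | Test size |
|---|---|---|---|---|---|
| CR | 2 | 19 | 3,395 | N.A. | 377 |
| MR | 2 | 20 | 9,595 | N.A. | 1,066 |
| SST2 | 2 | 11 | 67,349 | 872 | 1,821 |
| SST5 | 5 | 18 | 8,544 | 1,101 | 2,210 |
| SUBJ | 2 | 23 | 9,000 | N.A. | 1,000 |
| TREC | 6 | 10 | 5,452 | N.A. | 500 |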

All datasets are converted to UTF-8 with iconv -f <src format> -t utf-8 filename -o save_name. SUBJ, MR and CR are split into 90% train and 10% test, and their dev set is a copy of the test set. TREC likewise uses a copy of its test set as the dev set.
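The SUBJ/MR/CR protocol (90% train, 10% test, dev duplicated from test) can be sketched as follows. The README does not say whether the data was shuffled before splitting or with which seed, so the shuffle and `seed=42` are assumptions, and `split_train_test` is a hypothetical helper; it will not necessarily reproduce the exact sizes in the table.

```python
import random
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def split_train_test(
    examples: Iterable[T], test_ratio: float = 0.1, seed: int = 42
) -> Tuple[List[T], List[T], List[T]]:
    """Return (train, dev, test), where dev is a copy of test."""
    items = list(examples)
    random.Random(seed).shuffle(items)  # assumption: shuffle before splitting
    n_test = round(len(items) * test_ratio)
    test = items[:n_test]
    train = items[n_test:]
    dev = list(test)  # dev duplicates the test set
    return train, dev, test


train, dev, test = split_train_test(range(10000))
print(len(train), len(dev), len(test))  # 9000 1000 1000
```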

Natural language inference and related classification datasets (mostly sentence pairs; CoLA is a single-sentence acceptability task):

| Dataset | Classes | Train size | Dev size | Test size |
|---|---|---|---|---|
| MRPC | 2 | 4,077 | 1,726 | 1,726 |
| SICK | 3 | 4,501 | 501 | 4,928 |
| SNLI | 3 | 549,367 | 9,842 | 9,824 |
| CoLA | 2 | 8,551 | 527 | 516 |

MNLI and XNLI are already supported by the official BERT code; see run_classifier.py in google-research/bert (https://github.com/google-research/bert).
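For sentence-pair datasets, the two sentences are packed into a single BERT input. The sketch below shows the standard BERT layout ([CLS] A [SEP] B [SEP], segment id 0 for the first part and 1 for the second) together with a longest-first truncation in the spirit of the one in the official run_classifier.py. It assumes WordPiece tokenization has already been applied, and `pack_pair` is a hypothetical name, not a function of this repository.

```python
from typing import List, Tuple


def pack_pair(
    tokens_a: List[str], tokens_b: List[str], max_seq_length: int
) -> Tuple[List[str], List[int]]:
    """Build BERT tokens and segment ids for a sentence pair."""
    a, b = list(tokens_a), list(tokens_b)
    # Reserve room for [CLS] and two [SEP]; trim the longer sequence first.
    while len(a) + len(b) > max_seq_length - 3:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    tokens = ["[CLS]"] + a + ["[SEP]"] + b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and the first [SEP]; segment 1 covers the rest.
    segment_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    return tokens, segment_ids


tokens, segment_ids = pack_pair(["the", "cat"], ["a", "dog", "runs"], max_seq_length=8)
print(tokens)       # ['[CLS]', 'the', 'cat', '[SEP]', 'a', 'dog', 'runs', '[SEP]']
print(segment_ids)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Padding up to max_seq_length (with an input mask) would follow in a full pipeline and is omitted here.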

Usage

For token-level classification, run:

args=(
  --task_name=ner                                                      # task name
  --data_dir=datasets/CoNLL2003_en                                     # dataset folder
  --output_dir=checkpoint/conll2003_en                                 # where outputs and trained parameters are saved
  --bert_config_file=bert_ckpt/cased_L-12_H-768_A-12/bert_config.json  # pre-trained BERT config
  --init_checkpoint=bert_ckpt/cased_L-12_H-768_A-12/bert_model.ckpt    # pre-trained BERT parameters
  --vocab_file=bert_ckpt/cased_L-12_H-768_A-12/vocab.txt               # BERT vocabulary file
  --do_lower_case=False                                                # whether to lowercase input tokens (False for cased models)
  --max_seq_length=128                                                 # maximum sequence length
  --do_train=True                                                      # run training
  --do_eval=True                                                       # run evaluation
  --do_predict=True                                                    # run prediction
  --batch_size=32                                                      # batch size; reduce to 16 on out-of-memory (OOM) errors
  --num_train_epochs=6                                                 # number of training epochs
  --use_crf=True                                                       # decode with a CRF instead of a plain classifier
)
python3 run_sequence_tagger.py "${args[@]}"
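BERT splits words into WordPiece sub-tokens, so word-level labels must be mapped onto sub-tokens, and max_seq_length counts sub-tokens plus special tokens. The README does not describe how run_sequence_tagger.py does this; the sketch below assumes one common strategy, the one the BERT paper describes for NER: the first sub-token of each word carries the word's label, while the remaining sub-tokens, [CLS] and [SEP] get a placeholder label. `align_labels`, the `"X"` placeholder and `toy_tokenize` are all hypothetical.

```python
from typing import Callable, List, Tuple


def align_labels(
    words: List[str],
    labels: List[str],
    tokenize: Callable[[str], List[str]],
    max_seq_length: int,
    pad_label: str = "X",  # hypothetical placeholder label
) -> Tuple[List[str], List[str]]:
    """Map word-level labels onto sub-tokens; only the first sub-token keeps the label."""
    tokens, token_labels = ["[CLS]"], [pad_label]
    for word, label in zip(words, labels):
        for k, piece in enumerate(tokenize(word)):
            tokens.append(piece)
            token_labels.append(label if k == 0 else pad_label)
    # Truncate so that the final [SEP] still fits within max_seq_length.
    tokens = tokens[: max_seq_length - 1] + ["[SEP]"]
    token_labels = token_labels[: max_seq_length - 1] + [pad_label]
    return tokens, token_labels


def toy_tokenize(word: str) -> List[str]:
    """Hypothetical stand-in for a WordPiece tokenizer: splits words longer than 3 chars."""
    return [word[:3], "##" + word[3:]] if len(word) > 3 else [word]


tokens, token_labels = align_labels(["Johnson", "runs"], ["B-PER", "O"], toy_tokenize, 128)
print(list(zip(tokens, token_labels)))
# [('[CLS]', 'X'), ('Joh', 'B-PER'), ('##nson', 'X'), ('run', 'O'), ('##s', 'X'), ('[SEP]', 'X')]
```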

The token-level model has two decoding variants, selected with --use_crf: one decodes with a CRF, the other applies a classifier directly. In both, BERT's output sequence is first fed into a dense layer and then decoded by the CRF or the classifier; no intermediate RNN layers are used.
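The practical difference between the two variants: the classifier picks the arg-max label at each position independently, while a CRF adds learned label-transition scores and decodes the best-scoring whole sequence with the Viterbi algorithm. The pure-Python sketch below is illustrative only and is not the repository's TensorFlow implementation; it omits start/end transitions, the emission scores stand in for the dense-layer outputs, and the transition matrix is made up for the example.

```python
from typing import List, Sequence


def greedy_decode(emissions: Sequence[Sequence[float]]) -> List[int]:
    """Classifier variant: highest-scoring label at each position, independently."""
    return [max(range(len(row)), key=lambda j: row[j]) for row in emissions]


def viterbi_decode(
    emissions: Sequence[Sequence[float]], transitions: Sequence[Sequence[float]]
) -> List[int]:
    """CRF variant: best label sequence under emission + transition scores.

    transitions[i][j] is the score of moving from label i to label j.
    """
    num_labels = len(transitions)
    score = list(emissions[0])  # best path score ending in each label at t = 0
    backpointers: List[List[int]] = []
    for row in emissions[1:]:
        new_score, back = [], []
        for j in range(num_labels):
            best_i = max(range(num_labels), key=lambda i: score[i] + transitions[i][j])
            back.append(best_i)
            new_score.append(score[best_i] + transitions[best_i][j] + row[j])
        score = new_score
        backpointers.append(back)
    # Start from the best final label and follow the backpointers.
    path = [max(range(num_labels), key=lambda j: score[j])]
    for back in reversed(backpointers):
        path.append(back[path[-1]])
    return path[::-1]


# Toy labels: 0 = O, 1 = B-PER, 2 = I-PER.
emissions = [[2.0, 0.0, 0.0], [0.0, 0.5, 1.0]]
transitions = [[0.0, 0.0, -10.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # O -> I-PER penalized
print(greedy_decode(emissions))                # [0, 2]: O, I-PER (invalid in BIO2)
print(viterbi_decode(emissions, transitions))  # [0, 1]: O, B-PER
```

In a real CRF the transition scores are learned jointly with the rest of the model, which is what lets it avoid label sequences that never occur in the training data.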

For sentence-level classification, run:

args=(
  --task_name=mrpc                                                       # task name
  --data_dir=datasets/MRPC                                               # dataset folder
  --output_dir=checkpoint/mrpc                                           # where outputs and trained parameters are saved
  --bert_config_file=bert_ckpt/uncased_L-12_H-768_A-12/bert_config.json  # pre-trained BERT config
  --init_checkpoint=bert_ckpt/uncased_L-12_H-768_A-12/bert_model.ckpt    # pre-trained BERT parameters
  --vocab_file=bert_ckpt/uncased_L-12_H-768_A-12/vocab.txt               # BERT vocabulary file
  --do_lower_case=True                                                   # whether to lowercase input tokens (True for uncased models)
  --max_seq_length=128                                                   # maximum sequence length
  --do_train=True                                                        # run training
  --do_eval=True                                                         # run evaluation
  --do_predict=True                                                      # run prediction
  --batch_size=32                                                        # batch size; reduce to 16 on out-of-memory (OOM) errors
  --num_train_epochs=6                                                   # number of training epochs
)
python3 run_text_classifier.py "${args[@]}"

The sentence-level model feeds BERT's pooled output (the final hidden state of the [CLS] token passed through BERT's tanh pooling layer) directly into a classifier.
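Numerically, a classifier on top of the pooled output is a linear layer followed by a softmax. A minimal pure-Python sketch, assuming a single linear layer; dropout, typically applied during training, is omitted, and `classify_pooled` and the toy weights are hypothetical.

```python
import math
from typing import List, Sequence


def classify_pooled(
    pooled: Sequence[float],
    weights: Sequence[Sequence[float]],  # one row per class, each of size hidden_size
    bias: Sequence[float],
) -> List[float]:
    """Return softmax(W . pooled + b): one probability per class."""
    logits = [sum(w * x for w, x in zip(row, pooled)) + b for row, b in zip(weights, bias)]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy 2-dimensional "pooled output" and a 2-class identity layer.
probs = classify_pooled([1.0, -1.0], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
print(probs)  # approximately [0.881, 0.119]
```

The predicted class is the arg-max of these probabilities; during training they feed a cross-entropy loss.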

Experiment Results

All experiments were run on a single GeForce GTX 1080 Ti GPU.

Token-level classification datasets

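| Metric | en Chunk | es NER | nl NER | en NER | de NER 1 | de NER 2 | zh NER 1 | zh NER 2 |
|---|---|---|---|---|---|---|---|---|
| Precision (%) | 96.8 | 89.0 | 89.8 | 92.0 | 82.0 | 86.2 | 77.9 | 95.7 |
| Recall (%) | 96.4 | 88.6 | 90.0 | 90.8 | 86.4 | 85.4 | 73.1 | 95.7 |
| F1 (%) | 96.6 | 88.8 | 89.9 | 91.4 | 84.2 | 85.8 | 75.5 | 95.7 |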

CoNLL2002 Spanish/Dutch, CoNLL2003 German NER and GermEval2014 German NER use the multi_cased_L-12_H-768_A-12.zip pre-trained model (base, multilingual, cased).

CoNLL2000 Chunk and CoNLL2003 English NER use the cased_L-12_H-768_A-12.zip pre-trained model (base, English, cased).

Chinese NER uses the chinese_L-12_H-768_A-12.zip pre-trained model (base, Chinese).

The test F1 on CoNLL-2003 English NER is lower than the score reported in the BERT paper (91.4% vs. 92.4%). According to the paper, a difference of about 0.2% is reasonable, but here the gap is 1.0%. Some tricks are probably missing, for example the parameter settings of the output classifier or the data pre-processing strategy.
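The precision, recall and F1 values above are entity-level scores of the kind conlleval.pl computes: a predicted entity counts as correct only if its type and its exact span both match the gold annotation, and F1 = 2PR / (P + R). The sketch below is a simplified re-implementation of that idea for BIO2 tags, not a replacement for conlleval.pl (which, among other details, accumulates counts over the whole corpus rather than one sequence); `bio2_spans` and `entity_prf` are hypothetical helpers.

```python
from typing import List, Optional, Set, Tuple

Span = Tuple[str, int, int]  # (entity type, start index, end index exclusive)


def bio2_spans(tags: List[str]) -> Set[Span]:
    """Extract typed entity spans from a BIO2 tag sequence."""
    spans: Set[Span] = set()
    etype: Optional[str] = None
    start = 0
    for i, tag in enumerate(tags + ["O"]):  # trailing "O" closes a final entity
        # B-, O, or an I- of a different type ends the current entity (if any).
        if tag.startswith("B-") or tag == "O" or tag[2:] != etype:
            if etype is not None:
                spans.add((etype, start, i))
                etype = None
            if tag != "O":
                etype, start = tag[2:], i
    return spans


def entity_prf(gold: List[str], pred: List[str]) -> Tuple[float, float, float]:
    """Entity-level precision, recall and F1 (exact type and span match)."""
    g, p = bio2_spans(gold), bio2_spans(pred)
    correct = len(g & p)
    precision = correct / len(p) if p else 0.0
    recall = correct / len(g) if g else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


gold = ["B-PER", "I-PER", "O", "B-LOC"]
pred = ["B-PER", "I-PER", "O", "B-ORG"]
print(entity_prf(gold, pred))  # (0.5, 0.5, 0.5): the PER span matches, LOC vs. ORG does not
```

To approximate corpus-level scores, concatenate all sentences (separated by an "O" tag) before calling `entity_prf`.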

Sentence-level classification datasets

| Metric | CR | MR | SST2 | SST5 | SUBJ | TREC |
|---|---|---|---|---|---|---|
| Dev Accuracy (%) | N.A. | N.A. | 91.3 | 50.1 | N.A. | N.A. |
| Test Accuracy (%) | 89.2 | 85.4 | 93.5 | 53.3 | 97.3 | 96.6 |

All tasks use the uncased_L-12_H-768_A-12.zip pre-trained model (base, English, uncased).

Natural language inference datasets

| Metric | MRPC | SICK | SNLI | CoLA |
|---|---|---|---|---|
| Dev Accuracy (%) | N.A. | 86.4 | 91.1 | 83.1 |
| Test Accuracy (%) | 84.7 | 87.0 | 90.7 | 78.9 |

All tasks use the uncased_L-12_H-768_A-12.zip pre-trained model (base, English, uncased).

These results may differ from published results because I did not use the GLUE versions of these datasets.

