Author: Hironsan

Project description:
Implementation of CRF layer in Keras.
Language: Python
Repository: git://github.com/Hironsan/keras-crf-layer.git
Created: 2017-09-03T12:54:38Z
Project page: https://github.com/Hironsan/keras-crf-layer

License: MIT License


Keras-CRF-Layer

The Keras-CRF-Layer module implements a linear-chain CRF layer for learning to predict tag sequences.
This variant of the CRF is factored into unary potentials for every element in the sequence and binary potentials for every transition between output tags.
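
To make the factorization concrete, here is a minimal NumPy sketch of how a linear-chain CRF scores a single tag sequence: it sums one unary potential per position and one binary (transition) potential per adjacent pair of tags. The function name `sequence_score` and the toy values are illustrative assumptions and are not part of this module's API.

```python
import numpy as np


def sequence_score(unary, transitions, tags):
    """Unnormalized score of one tag sequence under a linear-chain CRF.

    unary:       (seq_len, n_tags) array; unary[t, k] scores tag k at position t.
    transitions: (n_tags, n_tags) array; transitions[i, j] scores tag i followed by tag j.
    tags:        seq_len tag indices.
    """
    tags = np.asarray(tags)
    # One unary potential per element of the sequence.
    unary_score = unary[np.arange(len(tags)), tags].sum()
    # One binary potential per transition between consecutive tags.
    binary_score = transitions[tags[:-1], tags[1:]].sum()
    return unary_score + binary_score


# Toy example (hypothetical values): 3 positions, 2 tags.
unary = np.array([[1.0, 0.0],
                  [0.0, 2.0],
                  [0.5, 0.5]])
transitions = np.array([[0.1, 0.3],
                        [0.2, 0.4]])
score = sequence_score(unary, transitions, [0, 1, 1])
print(score)
```

Training a CRF normalizes this score over all possible tag sequences, which this sketch leaves out.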

Usage

The example below uses the API to train a CRF on random data.
The embedding layer in the example can be replaced by any neural network that outputs one score per tag at each position.

    import numpy as np
    from keras.layers import Embedding, Input
    from keras.models import Model
    from crf import CRFLayer

    # Hyperparameter settings.
    vocab_size = 20
    n_classes = 11
    batch_size = 2
    maxlen = 2

    # Random features.
    x = np.random.randint(1, vocab_size, size=(batch_size, maxlen))

    # Random tag indices representing the gold sequence, converted to one-hot vectors.
    y = np.random.randint(n_classes, size=(batch_size, maxlen))
    y = np.eye(n_classes)[y]

    # All sequences in this example have the same length, but they can be variable in a real model.
    s = np.asarray([maxlen] * batch_size, dtype='int32')

    # Build an example model.
    word_ids = Input(batch_shape=(batch_size, maxlen), dtype='int32')
    sequence_lengths = Input(batch_shape=[batch_size, 1], dtype='int32')
    word_embeddings = Embedding(vocab_size, n_classes)(word_ids)
    crf = CRFLayer()
    pred = crf(inputs=[word_embeddings, sequence_lengths])
    model = Model(inputs=[word_ids, sequence_lengths], outputs=[pred])
    model.compile(loss=crf.loss, optimizer='sgd')

    # Train on one batch.
    model.train_on_batch([x, s], y)

    # Save the model.
    model.save('model.h5')
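
The example notes that sequence lengths can vary in a real model. One possible way to prepare such a batch is to pad every sequence to a common length and keep the true lengths in a separate array, as sketched below. The helper `pad_batch` and the choice of 0 as the padding id are assumptions for illustration (the random features above start at 1, which leaves 0 free); they are not part of this module.

```python
import numpy as np


def pad_batch(sequences, pad_id=0):
    """Pad variable-length id sequences to a common length.

    Returns the padded (batch_size, maxlen) int32 array and the
    original length of each sequence as an int32 array.
    """
    lengths = np.asarray([len(seq) for seq in sequences], dtype='int32')
    maxlen = int(lengths.max())
    padded = np.full((len(sequences), maxlen), pad_id, dtype='int32')
    for i, seq in enumerate(sequences):
        # Copy the real tokens; the tail keeps the padding id.
        padded[i, :len(seq)] = seq
    return padded, lengths


# Two sequences of different lengths (hypothetical word ids).
x_padded, s_lengths = pad_batch([[3, 7, 2], [5, 1]])
print(x_padded)
print(s_lengths)
```

The resulting arrays would play the roles of `x` and `s` in the example above, with the length array telling the CRF which positions are padding.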

Model loading

Loading a saved model that has a CRF output with ‘keras.models.load_model’
alone won’t work properly, because the loss function’s reference to the
transition parameters is lost. To fix this, pass the ‘custom_objects’
parameter as follows:

    from keras.models import load_model
    from crf import create_custom_objects

    model = load_model('model.h5', custom_objects=create_custom_objects())