# Source code for tf_crnn.model

#!/usr/bin/env python
__author__ = 'solivr'
__license__ = "GPL"

from typing import Tuple

import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell

from .decoding import get_words_from_chars
from .config import Params, TrainingParams, CONST

def weightVar(shape, mean=0.0, stddev=0.02, name='weights'):
    """
    Build a ``tf.Variable`` whose initial value is drawn from a truncated normal distribution.

    :param shape: shape of the variable to create
    :param mean: mean of the truncated normal distribution
    :param stddev: standard deviation of the truncated normal distribution
    :param name: name given to the variable
    :return: the newly created ``tf.Variable``
    """
    return tf.Variable(
        tf.truncated_normal(shape=shape, mean=mean, stddev=stddev),
        name=name,
    )

def biasVar(shape, value=0.0, name='bias'):
    """
    Build a ``tf.Variable`` whose initial value is a constant-filled tensor.

    :param shape: shape of the variable to create
    :param value: constant used to fill the initial tensor
    :param name: name given to the variable
    :return: the newly created ``tf.Variable``
    """
    return tf.Variable(
        tf.constant(value=value, shape=shape),
        name=name,
    )

def conv2d(input, filter, strides=(1, 1, 1, 1), padding='SAME', name=None):
    """
    Thin wrapper around ``tf.nn.conv2d`` defaulting to unit strides and ``'SAME'`` padding.

    :param input: input tensor, forwarded unchanged to ``tf.nn.conv2d``
    :param filter: convolution kernel, forwarded unchanged to ``tf.nn.conv2d``
    :param strides: sequence of ints giving the stride of the sliding window (4 values by default)
    :param padding: padding algorithm, e.g. ``'SAME'`` or ``'VALID'``
    :param name: optional name for the operation
    :return: the result of ``tf.nn.conv2d``
    """
    # The default is an immutable tuple so that no list object is shared between calls.
    # It is converted back to a list, which is the form this wrapper has always handed
    # to TensorFlow.
    return tf.nn.conv2d(input, filter, strides=list(strides), padding=padding, name=name)

def _conv_layer(inputs, kernel_shape, scope, *, summaries=True, is_training=False,
                padding='SAME', batch_norm=False, conv_name=None, relu_name=None,
                pool_strides=None, pool_name=None):
    """
    Build one convolutional layer of ``deep_cnn``: conv -> bias -> [batch-norm] -> ReLU -> [max-pool].

    :param inputs: input tensor of the layer
    :param kernel_shape: shape of the convolution kernel; its last entry is also the bias size
    :param scope: name of the variable scope wrapping the layer
    :param summaries: when true, adds histograms of the layer's weights and bias
    :param is_training: forwarded to batch normalization as its ``training`` argument
    :param padding: padding of the convolution
    :param batch_norm: when true, inserts batch normalization between the bias and the ReLU
    :param conv_name: optional name of the convolution op
    :param relu_name: optional name of the ReLU op
    :param pool_strides: strides of the trailing max-pool; ``None`` means no pooling
    :param pool_name: name of the max-pool op
    :return: output tensor of the layer
    """
    with tf.variable_scope(scope):
        W = weightVar(kernel_shape)
        b = biasVar([kernel_shape[-1]])
        out = tf.nn.bias_add(conv2d(inputs, W, padding=padding, name=conv_name), b)
        if batch_norm:
            out = tf.layers.batch_normalization(out, axis=-1, training=is_training, name='batch-norm')
        out = tf.nn.relu(out, name=relu_name)
        if pool_strides is not None:
            out = tf.nn.max_pool(out, [1, 2, 2, 1], strides=pool_strides, padding='SAME', name=pool_name)

        if summaries:
            # Summarize the variables created just above rather than searching
            # tf.global_variables() for a hard-coded name, which stops matching as soon
            # as the network is built under an enclosing scope.
            # NOTE(review): summaries are assumed to live inside the layer scope - the
            # original indentation was lost, confirm against the upstream file.
            tf.summary.histogram('weights', W)
            tf.summary.histogram('bias', b)
    return out


def deep_cnn(input_imgs: tf.Tensor, input_channels: int, is_training: bool, summaries: bool=True) -> tf.Tensor:
    """
    CNN part of the CRNN network.

    :param input_imgs: input images [B, H, W, C]
    :param input_channels: input channels, 1 for greyscale images, 3 for RGB color images
    :param is_training: flag to indicate training or not (forwarded to the batch-normalization layers)
    :param summaries: flag to enable bias and weight histograms to be visualized in Tensorboard
    :return: tensor of shape [batch, final_width, final_height x final_features]
    """
    assert input_channels in [1, 3], 'input_channels must be 1 (greyscale) or 3 (RGB)'

    # Following source code, not paper
    with tf.variable_scope('deep_cnn'):
        # - conv1 - maxPool2x2
        pool1 = _conv_layer(input_imgs, [3, 3, input_channels, 64], 'layer1',
                            summaries=summaries, conv_name='conv',
                            pool_strides=[1, 2, 2, 1], pool_name='pool')

        # - conv2 - maxPool 2x2
        pool2 = _conv_layer(pool1, [3, 3, 64, 128], 'layer2',
                            summaries=summaries,
                            pool_strides=[1, 2, 2, 1], pool_name='pool1')

        # - conv3 - w/batch-norm (as source code, not paper)
        conv3 = _conv_layer(pool2, [3, 3, 128, 256], 'layer3',
                            summaries=summaries, is_training=is_training,
                            batch_norm=True, relu_name='ReLU')

        # - conv4 - maxPool 2x1 (stride 2 on the second axis, 1 on the third)
        pool4 = _conv_layer(conv3, [3, 3, 256, 256], 'layer4',
                            summaries=summaries,
                            pool_strides=[1, 2, 1, 1], pool_name='pool4')

        # - conv5 - w/batch-norm
        conv5 = _conv_layer(pool4, [3, 3, 256, 512], 'layer5',
                            summaries=summaries, is_training=is_training,
                            batch_norm=True)

        # - conv6 - maxPool 2x1 (as source code, not paper)
        pool6 = _conv_layer(conv5, [3, 3, 512, 512], 'layer6',
                            summaries=summaries,
                            pool_strides=[1, 2, 1, 1], pool_name='pool6')

        # - conv 7 - w/batch-norm (as source code, not paper)
        conv7 = _conv_layer(pool6, [2, 2, 512, 512], 'layer7',
                            summaries=summaries, is_training=is_training,
                            padding='VALID', batch_norm=True)

        cnn_net = conv7

        # NOTE(review): 'Reshaping_cnn' is assumed to be nested inside 'deep_cnn'
        # (only op names depend on it) - confirm against the upstream file.
        with tf.variable_scope('Reshaping_cnn'):
            # Dynamic shape, so the reshape also works when the batch size is unknown
            shape = tf.shape(cnn_net)  # [batch, height, width, features]
            transposed = tf.transpose(cnn_net, perm=[0, 2, 1, 3],
                                      name='transposed')  # [batch, width, height, features]
            conv_reshaped = tf.reshape(transposed, [shape[0], shape[2], shape[1] * shape[3]],
                                       name='reshaped')  # [batch, width, height x features]

            # Reshaping with dynamic dimensions loses the static shape: restore it from
            # the static shape of the CNN output, leaving only the batch axis unknown
            shape_list = cnn_net.get_shape().as_list()
            conv_reshaped.set_shape([None, shape_list[2], shape_list[1] * shape_list[3]])

    return conv_reshaped
def deep_bidirectional_lstm(inputs: tf.Tensor, params: Params,
                            summaries: bool=True) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Recurrent part of the CRNN network. Uses a bidirectional LSTM.

    :param inputs: output of ``deep_cnn``
    :param params: parameters of the model (reads ``keep_prob_dropout`` and ``alphabet.n_classes``)
    :param summaries: flag to enable bias and weight histograms to be visualized in Tensorboard
    :return: Tuple : (tensor [width(time), batch, n_classes], raw transcription codes)
    """
    # Prepare data shape to match `bidirectional_rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input) "(batch, time, height)"

    list_n_hidden = [256, 256]

    # NOTE(review): the nesting of the scopes below was reconstructed from flattened text
    # (it only affects op/summary names) - confirm against the upstream file.
    with tf.name_scope('deep_bidirectional_lstm'):
        # Forward direction cells
        fw_cell_list = [BasicLSTMCell(nh, forget_bias=1.0) for nh in list_n_hidden]
        # Backward direction cells
        bw_cell_list = [BasicLSTMCell(nh, forget_bias=1.0) for nh in list_n_hidden]

        lstm_net, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(fw_cell_list,
                                                                        bw_cell_list,
                                                                        inputs,
                                                                        dtype=tf.float32)

        # Dropout layer
        lstm_net = tf.nn.dropout(lstm_net, keep_prob=params.keep_prob_dropout)

        with tf.variable_scope('Reshaping_rnn'):
            shape = tf.shape(lstm_net)  # [batch, width, 2*n_hidden]
            # Merge batch and time so one matmul applies the dense layer to every time step
            rnn_reshaped = tf.reshape(lstm_net, [shape[0] * shape[1], shape[2]])  # [batch x width, 2*n_hidden]

        with tf.variable_scope('fully_connected'):
            W = weightVar([list_n_hidden[-1] * 2, params.alphabet.n_classes])
            b = biasVar([params.alphabet.n_classes])
            fc_out = tf.nn.bias_add(tf.matmul(rnn_reshaped, W), b)

            if summaries:
                # Summarize the variables created just above rather than searching
                # tf.global_variables() for a hard-coded name, which stops matching as
                # soon as the network is built under an enclosing scope.
                tf.summary.histogram('weights', W)
                tf.summary.histogram('bias', b)

        lstm_out = tf.reshape(fc_out, [shape[0], shape[1], params.alphabet.n_classes],
                              name='reshape_out')  # [batch, width, n_classes]

        raw_pred = tf.argmax(tf.nn.softmax(lstm_out), axis=2, name='raw_prediction')

        # Swap batch and time axis
        lstm_out = tf.transpose(lstm_out, [1, 0, 2], name='transpose_time_major')  # [width(time), batch, n_classes]

        return lstm_out, raw_pred
def crnn_fn(features, labels, mode, params):
    """
    CRNN model definition for ``tf.Estimator``.
    Combines ``deep_cnn`` and ``deep_bidirectional_lstm`` to define the model and adds loss computation and CTC decoder.

    NOTE(review): this function was recovered from flattened text in which indentation and
    a few expressions were lost. Scope nesting and the expressions marked "restored" below
    are reconstructions - confirm against the upstream file.

    :param features: dictionary with keys : '`images`', '`images_widths`', '`filenames`'
        ('`filenames`' is optional)
    :param labels: string containing the transcriptions.
        Flattened (1D) array with encoded label (one code per character)
    :param mode: TRAIN, EVAL, PREDICT
    :param params: dictionary with keys: '`Params`', '`TrainingParams`'
    :return: a ``tf.estimator.EstimatorSpec`` carrying predictions, loss, train op,
        evaluation metrics and export outputs
    :raises ValueError: if ``TrainingParams.optimizer`` is not one of 'ada', 'adam', 'rms'
        (only checked when ``mode`` is not PREDICT)
    """
    parameters = params.get('Params')
    training_params = params.get('TrainingParams')
    assert isinstance(parameters, Params)
    assert isinstance(training_params, TrainingParams)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Dropout is only active while training (side effect: mutates the shared Params object)
    if is_training:
        parameters.keep_prob_dropout = 0.7
    else:
        parameters.keep_prob_dropout = 1.0

    conv = deep_cnn(features['images'], input_channels=parameters.input_channels,
                    is_training=is_training, summaries=False)
    net_output, raw_pred = deep_bidirectional_lstm(conv, params=parameters, summaries=False)

    # Compute seq_len from image width
    n_pools = CONST.DIMENSION_REDUCTION_W_POOLING  # 2x2 pooling in dimension W on layer 1 and 2
    seq_len_inputs = tf.divide(features['images_widths'], n_pools, name='seq_len_input_op') - 1

    predictions_dict = {'probs': tf.nn.softmax(net_output),
                        # 'raw_predictions': raw_pred,
                        }
    # Filenames are optional: forward them to the predictions only when provided
    try:
        predictions_dict['filenames'] = features['filenames']
    except KeyError:
        pass

    if mode != tf.estimator.ModeKeys.PREDICT:
        # Alphabet and codes
        keys_alphabet_units = parameters.alphabet.alphabet_units
        # NOTE(review): right-hand side restored (it was missing) - assumes the alphabet
        # exposes its integer codes as `.codes`; confirm.
        values_alphabet_codes = parameters.alphabet.codes

        # Convert string label to code label
        with tf.name_scope('str2code_conversion'):
            table_str2int = tf.contrib.lookup.HashTable(
                tf.contrib.lookup.KeyValueTensorInitializer(keys_alphabet_units, values_alphabet_codes), -1)
            labels_splited = tf.string_split(labels, delimiter=parameters.string_split_delimiter)
            codes = table_str2int.lookup(labels_splited.values)
            sparse_code_target = tf.SparseTensor(labels_splited.indices, codes, labels_splited.dense_shape)

        # Number of label units per sample: count the entries of each row of the sparse target
        seq_lengths_labels = tf.bincount(tf.cast(sparse_code_target.indices[:, 0], tf.int32),
                                         minlength=tf.shape(net_output)[1])

        # Loss
        # ----
        # >>> Cannot have longer labels than predictions -> error
        # NOTE(review): `tf.less_equal` yields a boolean tensor rather than an assertion op,
        # so this dependency presumably does not abort when violated - confirm whether
        # `tf.assert_less_equal` was intended.
        with tf.control_dependencies(
                [tf.less_equal(sparse_code_target.dense_shape[1],
                               tf.reduce_max(tf.cast(seq_len_inputs, tf.int64)))]):
            loss_ctc = tf.nn.ctc_loss(labels=sparse_code_target,
                                      inputs=net_output,
                                      sequence_length=tf.cast(seq_len_inputs, tf.int32),
                                      preprocess_collapse_repeated=False,
                                      ctc_merge_repeated=True,
                                      # ignore... = True : returns zero gradient in case it happens -> loss = NaN
                                      ignore_longer_outputs_than_inputs=True,
                                      time_major=True)
            loss_ctc = tf.reduce_mean(loss_ctc)
            loss_ctc = tf.Print(loss_ctc, [loss_ctc], message='* Loss : ')

        global_step = tf.train.get_or_create_global_step()
        # Create an ExponentialMovingAverage object
        ema = tf.train.ExponentialMovingAverage(decay=0.99, num_updates=global_step, zero_debias=True)
        # Create the shadow variables, and add op to maintain moving averages
        maintain_averages_op = ema.apply([loss_ctc])
        loss_ema = ema.average(loss_ctc)  # currently not used further down

        # Train op
        # --------
        learning_rate = tf.train.exponential_decay(training_params.learning_rate, global_step,
                                                   training_params.learning_decay_steps,
                                                   training_params.learning_decay_rate,
                                                   staircase=True)

        if training_params.optimizer == 'ada':
            optimizer = tf.train.AdadeltaOptimizer(learning_rate)
        elif training_params.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
        elif training_params.optimizer == 'rms':
            optimizer = tf.train.RMSPropOptimizer(learning_rate)
        else:
            # Fail with an explicit message instead of an unbound `optimizer` below
            raise ValueError("Unknown optimizer {!r}, expected one of 'ada', 'adam', 'rms'"
                             .format(training_params.optimizer))

        # UPDATE_OPS is the collection batch normalization registers its moving-statistics
        # updates in; they must run together with the optimization step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        opt_op = optimizer.minimize(loss_ctc, global_step=global_step)
        with tf.control_dependencies(update_ops + [opt_op]):
            # NOTE(review): right-hand side restored (it was missing) - grouping the moving
            # average update is the only use of `maintain_averages_op`; confirm.
            train_op = tf.group(maintain_averages_op)

        # Summaries
        # ---------
        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('losses/ctc_loss', loss_ctc)
    else:
        loss_ctc, train_op = None, None

    if mode in [tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT, tf.estimator.ModeKeys.TRAIN]:
        with tf.name_scope('code2str_conversion'):
            # NOTE(review): first argument restored (it was missing) - assumes the alphabet
            # exposes its integer codes as `.codes`; confirm.
            keys_alphabet_codes = tf.cast(parameters.alphabet.codes, tf.int64)
            values_alphabet_units = list(parameters.alphabet.alphabet_units)
            table_int2str = tf.contrib.lookup.HashTable(
                tf.contrib.lookup.KeyValueTensorInitializer(keys_alphabet_codes, values_alphabet_units), '?')

            # Output is 2 list of length NUM_BEAM_PATHS with tensors of shape [Batch, ...]
            sparse_code_pred, log_probability_ctc = tf.nn.ctc_beam_search_decoder(
                net_output,
                sequence_length=tf.cast(seq_len_inputs, tf.int32),
                merge_repeated=False,
                beam_width=100,
                top_paths=parameters.num_beam_paths)

            sequence_lengths_pred = tf.bincount(tf.cast(sparse_code_pred[0].indices[:, 0], tf.int32),
                                                minlength=tf.shape(net_output)[1])

            pred_chars = table_int2str.lookup(sparse_code_pred[0])
            predictions_dict['words'] = get_words_from_chars(pred_chars.values,
                                                             sequence_lengths=sequence_lengths_pred)
            predictions_dict['codes'] = tf.sparse_to_dense(sparse_indices=sparse_code_pred[0].indices,
                                                           output_shape=sparse_code_pred[0].dense_shape,
                                                           sparse_values=sparse_code_pred[0].values)

            tf.summary.text('predicted_words', predictions_dict['words'][:10])

        # Compute these values only when predicting, they're not useful during training/evaluation
        if mode == tf.estimator.ModeKeys.PREDICT:
            # Possible paths
            with tf.name_scope('get_best_paths_transcriptions'):
                sequence_lengths_pred = [tf.bincount(tf.cast(sp.indices[:, 0], tf.int32),
                                                     minlength=tf.shape(net_output)[1])
                                         for sp in sparse_code_pred]
                pred_chars = [table_int2str.lookup(sp) for sp in sparse_code_pred]

                predictions_dict['best_transcriptions'] = tf.stack(
                    [get_words_from_chars(char.values, sequence_lengths=length)
                     for char, length in zip(pred_chars, sequence_lengths_pred)]
                )

            # Score : around 10.0 -> seems pretty sure, less than 5.0 bit unsure, some errors/challenging images
            # NOTE(review): indexing column 1 presumably requires num_beam_paths >= 2 - confirm.
            predictions_dict['score'] = tf.subtract(log_probability_ctc[:, 0], log_probability_ctc[:, 1],
                                                    name='score_computation')

            # Logprobs ctc decoding :
            predictions_dict['logprob_ctc'] = log_probability_ctc

    # Evaluation ops
    # --------------
    if mode == tf.estimator.ModeKeys.EVAL:
        with tf.name_scope('evaluation'):
            CER = tf.metrics.mean(tf.edit_distance(sparse_code_pred[0],
                                                   tf.cast(sparse_code_target, dtype=tf.int64)),
                                  name='CER')

            # Convert label codes to decoding alphabet to compare predicted and ground-truth words
            target_chars = table_int2str.lookup(tf.cast(sparse_code_target, tf.int64))
            target_words = get_words_from_chars(target_chars.values, seq_lengths_labels)
            accuracy = tf.metrics.accuracy(target_words, predictions_dict['words'], name='accuracy')

            eval_metric_ops = {
                'eval/accuracy': accuracy,
                'eval/CER': CER,
            }
            # These rebindings happen after `eval_metric_ops` was built, so the dictionary
            # keeps referencing the metrics without the print wrappers
            CER = tf.Print(CER, [CER], message='-- CER : ')
            accuracy = tf.Print(accuracy, [accuracy], message='-- Accuracy : ')
    else:
        eval_metric_ops = None

    export_outputs = {'predictions': tf.estimator.export.PredictOutput(predictions_dict)}

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions_dict,
        loss=loss_ctc,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs,
        scaffold=tf.train.Scaffold()
        # scaffold=tf.train.Scaffold(init_fn=None)  # Specify init_fn to restore from previous model
    )