Source code for tf_crnn.config

#!/usr/bin/env python
__author__ = 'solivr'
__license__ = "GPL"

import csv
import json
import os
import string
from functools import reduce
from glob import glob
from typing import List, Union
import pandas as pd


class CONST:
    DIMENSION_REDUCTION_W_POOLING = 2*2  # 2x2 pooling in dimension W on layer 1 and 2
    PREPROCESSING_FOLDER = 'preprocessed'


class Alphabet:
    """
    Class for alphabet / symbols units.

    :ivar _blank_symbol: Blank symbol used for CTC
    :vartype _blank_symbol: str
    :ivar _alphabet_units: list of elements composing the alphabet. The units may be single
        characters or multiple characters.
    :vartype _alphabet_units: List[str]
    :ivar _codes: Each alphabet unit has a unique corresponding code.
    :vartype _codes: List[int]
    :ivar _nclasses: number of alphabet units.
    :vartype _nclasses: int
    """

    def __init__(self, lookup_alphabet_file: str=None, blank_symbol: str='$'):
        self._blank_symbol = blank_symbol

        if lookup_alphabet_file:
            lookup_alphabet = self.load_lookup_from_json(lookup_alphabet_file)

            # Blank symbol must have the largest value
            if self._blank_symbol in lookup_alphabet.keys():
                # TODO : check if blank symbol is the last one
                assert lookup_alphabet[self._blank_symbol] == max(lookup_alphabet.values()), \
                    "Blank symbol should have the largest code integer"
                lookup_alphabet[self._blank_symbol] = max(lookup_alphabet.values()) + 1
            else:
                lookup_alphabet.update({self._blank_symbol: max(lookup_alphabet.values()) + 1})

            self._alphabet_units = list(lookup_alphabet.keys())
            self._codes = list(lookup_alphabet.values())
            self._nclasses = len(self.codes) + 1  # n_classes should be + 1 of labels codes

            if 0 in self._codes:
                raise ValueError("Code 0 is in the lookup table, you shouldn't use it.")

            self.lookup_int2str = dict(zip(self.codes, self.alphabet_units))

    def check_input_file_alphabet(self, csv_filenames: List[str],
                                  discarded_chars: str=';|{}'.format(string.whitespace[1:]),
                                  csv_delimiter: str=";") -> None:
        """
        Checks that the labels of the input files contain only characters that are in the Alphabet.

        :param csv_filenames: list of csv filenames
        :param discarded_chars: characters to discard
        :param csv_delimiter: character delimiting fields in the csv file
        :return:
        """
        assert isinstance(csv_filenames, list), 'csv_filenames argument is not a list'

        alphabet_set = set(self.alphabet_units)
        for filename in csv_filenames:
            input_chars_set = set()

            with open(filename, 'r', encoding='utf8') as f:
                csvreader = csv.reader(f, delimiter=csv_delimiter, escapechar='\\', quoting=0)
                for line in csvreader:
                    input_chars_set.update(line[1])

            # Discard all whitespaces except space ' '
            for whitespace in discarded_chars:
                input_chars_set.discard(whitespace)

            extra_chars = input_chars_set - alphabet_set
            assert len(extra_chars) == 0, \
                'There are {} unknown chars in {} : {}'.format(len(extra_chars), filename, extra_chars)

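    # Illustrative usage (not part of the original module); 'train.csv' is a
    # hypothetical file formatted as path;label:
    #
    #   alphabet.check_input_file_alphabet(['train.csv'])
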
    @classmethod
    def map_lookup(cls, lookup_table: dict, unique_entry: bool = True) -> dict:
        """
        Converts an existing lookup table so that its codes span the minimal range
        ([1, len(lookup_table)]) and avoids multiple instances of the same code label (bijectivity).

        :param lookup_table: dictionary to be mapped {alphabet_unit : code label}
        :param unique_entry: if True, each alphabet unit has a unique code and each code a unique
            alphabet unit ('bijective'); only True is implemented for now
        :return: the mapped dictionary
        """
        # Create tuples (alphabet unit, code)
        tuple_char_code = list(zip(list(lookup_table.keys()), list(lookup_table.values())))
        # Sort by code
        tuple_char_code.sort(key=lambda x: x[1])

        # If each alphabet unit has a unique code and each code a unique alphabet unit ('bijective')
        if unique_entry:
            mapped_lookup = [[tp[0], i + 1] for i, tp in enumerate(tuple_char_code)]
        else:
            raise NotImplementedError  # TODO

        return dict(mapped_lookup)

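    # Illustrative example (not part of the original module): map_lookup remaps
    # arbitrary codes to the compact range [1, len(lookup_table)], preserving the
    # ordering of the original codes:
    #
    #   Alphabet.map_lookup({'b': 42, 'a': 7})  # -> {'a': 1, 'b': 2}
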
    @classmethod
    def create_lookup_from_labels(cls, csv_files: List[str], export_lookup_filename: str,
                                  original_lookup_filename: str=None):
        """
        Creates a lookup dictionary from csv files containing labels and exports it as a json file.

        :param csv_files: list of files to get the labels from (each line should have the format path;label)
        :param export_lookup_filename: filename to export the alphabet lookup dictionary to
        :param original_lookup_filename: original lookup filename to update (optional)
        :return:
        """
        if original_lookup_filename:
            with open(original_lookup_filename, 'r') as f:
                lookup = json.load(f)
            set_chars = set(list(lookup.keys()))
        else:
            set_chars = set(list(string.ascii_letters) + list(string.digits))
            lookup = dict()

        for filename in csv_files:
            data = pd.read_csv(filename, sep=';', encoding='utf8', error_bad_lines=False,
                               header=None, names=['path', 'transcription'], escapechar='\\')
            for index, row in data.iterrows():
                set_chars.update(row.transcription.split('|'))

        # Update (key, value) pairs of the lookup table
        for el in set_chars:
            if el not in lookup.keys():
                lookup[el] = max(lookup.values()) + 1 if lookup.values() else 0

        lookup = cls.map_lookup(lookup)

        # Save new lookup
        with open(export_lookup_filename, 'w', encoding='utf8') as f:
            json.dump(lookup, f)

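    # Illustrative usage (not part of the original module); the csv and json paths
    # are hypothetical. Each csv line is expected to look like
    # images/sample_01.png;h|e|l|l|o, with '|' separating alphabet units:
    #
    #   Alphabet.create_lookup_from_labels(['train.csv', 'eval.csv'],
    #                                      export_lookup_filename='lookup.json')
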
    @classmethod
    def load_lookup_from_json(cls, json_filenames: Union[List[str], str]) -> dict:
        """
        Loads a lookup table from a json file into a dictionary.

        :param json_filenames: either a filename or a list of filenames
        :return:
        """
        lookup = dict()
        if isinstance(json_filenames, list):
            for file in json_filenames:
                with open(file, 'r', encoding='utf8') as f:
                    data_dict = json.load(f)
                lookup.update(data_dict)
        elif isinstance(json_filenames, str):
            with open(json_filenames, 'r', encoding='utf8') as f:
                lookup = json.load(f)

        return cls.map_lookup(lookup)

    @classmethod
    def make_json_lookup_alphabet(cls, string_chars: str = None) -> dict:
        """
        Creates a lookup table from a string of characters.

        :param string_chars: for example string.ascii_letters, string.digits
        :return:
        """
        lookup = dict()
        if string_chars:
            # Add characters to lookup table
            lookup.update({char: ord(char) for char in string_chars})

        return cls.map_lookup(lookup)

    @property
    def n_classes(self):
        return self._nclasses

    @property
    def blank_symbol(self):
        return self._blank_symbol

    @property
    def codes(self):
        return self._codes

    @property
    def alphabet_units(self):
        return self._alphabet_units

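# Illustrative usage of the Alphabet class (not part of the original module);
# 'lookup.json' is a hypothetical path:
#
#   lookup = Alphabet.make_json_lookup_alphabet(string.ascii_letters + string.digits)
#   with open('lookup.json', 'w', encoding='utf8') as f:
#       json.dump(lookup, f)
#   alphabet = Alphabet(lookup_alphabet_file='lookup.json')
#   # codes 1..62 for the characters, 63 for the blank '$', n_classes = 64
#   assert alphabet.n_classes == len(lookup) + 2
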
class Params:
    """
    Class for parameters of the model and the experiment.

    :ivar input_shape: input shape of the image to batch (this is the shape after data augmentation).
        The original image will either be resized or padded depending on its original size
    :vartype input_shape: Tuple[int, int]
    :ivar input_channels: number of color channels for the input image (default: 1)
    :vartype input_channels: int
    :ivar cnn_features_list: a list of length `n_layers` containing the number of features for each
        convolutional layer (default: [16, 32, 64, 96, 128])
    :vartype cnn_features_list: List(int)
    :ivar cnn_kernel_size: a list of length `n_layers` containing the kernel size for each
        convolutional layer (default: [3, 3, 3, 3, 3])
    :vartype cnn_kernel_size: List(int)
    :ivar cnn_stride_size: a list of length `n_layers` containing the stride size for each
        convolutional layer (default: [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])
    :vartype cnn_stride_size: List((int, int))
    :ivar cnn_pool_size: a list of length `n_layers` containing the pool size for each MaxPool layer
        (default: [(2, 2), (2, 2), (2, 2), (2, 2), (1, 1)])
    :vartype cnn_pool_size: List((int, int))
    :ivar cnn_batch_norm: a list of length `n_layers` containing a bool that indicates whether or not
        to use batch normalization (default: [False, False, False, False, False])
    :vartype cnn_batch_norm: List(bool)
    :ivar rnn_units: a list containing the number of units per rnn layer (default: [256, 256])
    :vartype rnn_units: List(int)
    :ivar num_beam_paths: number of paths (transcriptions) to return for ctc beam search
        (only used when predicting)
    :vartype num_beam_paths: int
    :ivar csv_delimiter: character to delimit csv input files (default: ';')
    :vartype csv_delimiter: str
    :ivar string_split_delimiter: character that delimits each alphabet unit in the labels (default: '|')
    :vartype string_split_delimiter: str
    :ivar csv_files_train: csv filename which contains the (path;label) of each training sample
    :vartype csv_files_train: str
    :ivar csv_files_eval: csv filename which contains the (path;label) of each eval sample
    :vartype csv_files_eval: str
    :ivar lookup_alphabet_file: json file that contains the mapping alphabet units <-> codes
    :vartype lookup_alphabet_file: str
    :ivar blank_symbol: symbol to be considered as blank by the CTC decoder (default: '$')
    :vartype blank_symbol: str
    :ivar max_chars_per_string: maximum characters per sample (to avoid CTC decoder errors) (default: 75)
    :vartype max_chars_per_string: int
    :ivar data_augmentation: if True, augments data on the fly (default: True)
    :vartype data_augmentation: bool
    :ivar data_augmentation_max_rotation: maximum rotation to apply to the image during training,
        in radians (default: 0.005)
    :vartype data_augmentation_max_rotation: float
    :ivar data_augmentation_max_slant: maximum angle for slant augmentation (default: 0.7)
    :vartype data_augmentation_max_slant: float
    :ivar n_epochs: number of epochs to run the training (default: 50)
    :vartype n_epochs: int
    :ivar train_batch_size: batch size during training (default: 64)
    :vartype train_batch_size: int
    :ivar eval_batch_size: batch size during evaluation (default: 128)
    :vartype eval_batch_size: int
    :ivar learning_rate: initial learning rate (default: 1e-4)
    :vartype learning_rate: float
    :ivar evaluate_every_epoch: evaluate every 'evaluate_every_epoch' epochs (default: 5)
    :vartype evaluate_every_epoch: int
    :ivar save_interval: save the model every 'save_interval' epochs (default: 20)
    :vartype save_interval: int
    :ivar optimizer: which optimizer to use ('adam', 'rms', 'ada') (default: 'adam')
    :vartype optimizer: str
    :ivar output_model_dir: output directory where the model will be saved and exported
    :vartype output_model_dir: str
    :ivar restore_model: boolean to continue training with saved weights (default: False)
    :vartype restore_model: bool
    """

    def __init__(self, **kwargs):
        # model params
        self.input_shape = kwargs.get('input_shape', (96, 1400))
        self.input_channels = kwargs.get('input_channels', 1)
        self.cnn_features_list = kwargs.get('cnn_features_list', [16, 32, 64, 96, 128])
        self.cnn_kernel_size = kwargs.get('cnn_kernel_size', [3, 3, 3, 3, 3])
        self.cnn_stride_size = kwargs.get('cnn_stride_size', [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])
        self.cnn_pool_size = kwargs.get('cnn_pool_size', [(2, 2), (2, 2), (2, 2), (2, 2), (1, 1)])
        self.cnn_batch_norm = kwargs.get('cnn_batch_norm', [False, False, False, False, False])
        self.rnn_units = kwargs.get('rnn_units', [256, 256])
        # self._keep_prob_dropout = kwargs.get('keep_prob_dropout', 0.5)
        self.num_beam_paths = kwargs.get('num_beam_paths', 1)

        # csv params
        self.csv_delimiter = kwargs.get('csv_delimiter', ';')
        self.string_split_delimiter = kwargs.get('string_split_delimiter', '|')
        self.csv_files_train = kwargs.get('csv_files_train')
        self.csv_files_eval = kwargs.get('csv_files_eval')

        # alphabet params
        self.blank_symbol = kwargs.get('blank_symbol', '$')
        self.max_chars_per_string = kwargs.get('max_chars_per_string', 75)
        self.lookup_alphabet_file = kwargs.get('lookup_alphabet_file')

        # data augmentation params
        self.data_augmentation = kwargs.get('data_augmentation', True)
        self.data_augmentation_max_rotation = kwargs.get('data_augmentation_max_rotation', 0.005)
        self.data_augmentation_max_slant = kwargs.get('data_augmentation_max_slant', 0.7)

        # training params
        self.n_epochs = kwargs.get('n_epochs', 50)
        self.train_batch_size = kwargs.get('train_batch_size', 64)
        self.eval_batch_size = kwargs.get('eval_batch_size', 128)
        self.learning_rate = kwargs.get('learning_rate', 1e-4)
        self.optimizer = kwargs.get('optimizer', 'adam')
        self.output_model_dir = kwargs.get('output_model_dir', '')
        self.evaluate_every_epoch = kwargs.get('evaluate_every_epoch', 5)
        self.save_interval = kwargs.get('save_interval', 20)
        self.restore_model = kwargs.get('restore_model', False)

        self._assign_alphabet()

        # Total horizontal downscaling introduced by the conv/pool layers
        cnn_params = zip(self.cnn_pool_size, self.cnn_stride_size)
        self.downscale_factor = reduce(lambda i, j: i * j,
                                       map(lambda k: k[0][1] * k[1][1], cnn_params))

        # TODO add additional checks for the architecture
        assert len(self.cnn_features_list) == len(self.cnn_kernel_size) == len(self.cnn_stride_size) \
               == len(self.cnn_pool_size) == len(self.cnn_batch_norm), \
            "Lengths of the model parameters are not the same, check that all the layer parameters " \
            "have the same length."

        max_input_width = (self.max_chars_per_string + 1) * self.downscale_factor
        assert max_input_width <= self.input_shape[1], \
            "Maximum length of labels is {}, input width should be greater than or equal to {} " \
            "but is {}".format(self.max_chars_per_string, max_input_width, self.input_shape[1])

        assert self.optimizer in ['adam', 'rms', 'ada'], 'Unknown optimizer {}'.format(self.optimizer)

        if os.path.isdir(self.output_model_dir):
            print('WARNING : The output directory {} already exists.'.format(self.output_model_dir))

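    # Worked example with the default values (illustrative, not part of the
    # original module): only the widths (second element) of the pool and stride
    # sizes contribute to the horizontal downscaling, so
    #   downscale_factor = (2*1) * (2*1) * (2*1) * (2*1) * (1*1) = 16
    # and the input width check requires
    #   (max_chars_per_string + 1) * downscale_factor = (75 + 1) * 16 = 1216 <= 1400.
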
    def show_experiment_params(self) -> dict:
        """
        Returns a dictionary with the variables of the class.

        :return:
        """
        return vars(self)

    def _assign_alphabet(self):
        self.alphabet = Alphabet(lookup_alphabet_file=self.lookup_alphabet_file,
                                 blank_symbol=self.blank_symbol)

    # @property
    # def keep_prob_dropout(self):
    #     return self._keep_prob_dropout
    #
    # @keep_prob_dropout.setter
    # def keep_prob_dropout(self, value):
    #     assert (0.0 < value <= 1.0), 'Must be 0.0 < value <= 1.0'
    #     self._keep_prob_dropout = value

    def to_dict(self) -> dict:
        """
        Returns the parameters as a dictionary.

        :return:
        """
        new_dict = self.__dict__.copy()
        del new_dict['alphabet']
        del new_dict['downscale_factor']
        return new_dict

    @classmethod
    def from_json_file(cls, json_file: str):
        """
        Given a json file, creates a ``Params`` object.

        :param json_file: path to the json file
        :return: ``Params`` object
        """
        with open(json_file, 'r') as file:
            config = json.load(file)

        return cls(**config)

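# Illustrative usage of the Params class (not part of the original module); the
# json, csv and directory paths are hypothetical:
#
#   params = Params.from_json_file('config.json')
#   # or directly from keyword arguments:
#   params = Params(csv_files_train=['train.csv'], csv_files_eval=['eval.csv'],
#                   lookup_alphabet_file='lookup.json', output_model_dir='model/')
#   print(params.show_experiment_params())
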
def import_params_from_json(model_directory: str=None, json_filename: str=None) -> dict:
    """
    Reads the exported json file containing the parameters of the experiment.

    :param model_directory: directory where the model was exported
    :param json_filename: filename of the json file
    :return: a dictionary containing the parameters of the experiment
    """
    assert not all(p is None for p in [model_directory, json_filename]), \
        'At least one argument should not be None'

    if model_directory:
        # Import parameters from the json file
        try:
            json_filename = glob(os.path.join(model_directory, 'model_params*.json'))[-1]
        except IndexError:
            print('No json found in dir {}'.format(model_directory))
            raise FileNotFoundError
    else:
        if not os.path.isfile(json_filename):
            print('No json found with filename {}'.format(json_filename))
            raise FileNotFoundError

    with open(json_filename, 'r') as data_json:
        params_json = json.load(data_json)

    # Remove 'private' keys
    keys = list(params_json.keys())
    for key in keys:
        if key[0] == '_':
            params_json.pop(key)

    return params_json

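# Illustrative usage (not part of the original module); 'model/' is a hypothetical
# directory containing an exported model_params*.json file:
#
#   params_dict = import_params_from_json(model_directory='model/')
#   params = Params(**params_dict)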