This is part of a blog series on building a full stack ML application with PyTorch. This post covers the main process and learnings from re-implementing labs 3-5 in PT. These three labs implement a model for line recognition (given an image of a line of text, output a string), plus the related experiment management.

TLDR:
1. CTCLoss is tricky to debug
2. Overfitting a small sample still requires hyperparameter tuning
3. When Adam does not work, try RMSprop

NN-wise

The DL work is more involved in these labs. Two loss functions are explored: an NLLLoss-based model and a CTCLoss-based model (paper). On the network side, three major architectures are explored: a vanilla CNN, an LSTM-based network, and a CRNN. I struggled quite a bit with CTC-based training and prediction.

CTC-related util functions:
The CTC-based model depends on a sliding window to create patches for classification; the CTCLoss then finds the best path through this series of classifications and compares it with the label. To switch to PT, there are a few util functions to change: format_batch_ctc, slide_window, ctc_decode.

format_batch_ctc()
This is a helper function that folds the loss into the NN output, since in Keras the .fit() API is rigid. By contrast, in PT you can simply compute the loss in (or right next to) your .forward() function, so here we can safely remove format_batch_ctc(). This is a reminder that in Keras you need to conform to the .fit() schema and express anything else as a Callback or as a separate model combined into a KerasModel, while in PT you can do whatever you like; there's no schema, just 'best practice'.
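To make this concrete, here is a minimal sketch (the class and argument names are mine, not the lab's) of computing the CTC loss right next to the forward pass instead of going through a batch-formatting helper:

import torch
import torch.nn as nn

class LineModelWithCTC(nn.Module):
    def __init__(self, backbone, blank_id=0):
        super().__init__()
        self.backbone = backbone              # any net emitting (T, N, C) log-probs
        self.ctc = nn.CTCLoss(blank=blank_id)

    def forward(self, images, targets=None, target_lengths=None):
        log_probs = self.backbone(images)     # (T, N, C), T = number of windows
        if targets is None:
            return log_probs                  # inference: just return the predictions
        T, N, _ = log_probs.shape
        input_lengths = torch.full((N,), T, dtype=torch.long)
        loss = self.ctc(log_probs, targets, input_lengths, target_lengths)
        return log_probs, loss                # training: predictions plus loss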

slide_window()
The CTCLoss expects a sequence of per-window predictions. Since our input is an image of a handwritten line, we need a sliding window to create patches. In Keras it was

tf.extract_image_patches(image, kernel, strides, [1, 1, 1, 1], 'VALID')

while in PT it is

images.unfold(3, window_width, window_stride) # (b,c,h,p,window)

Although the Keras API still looks rigid, the fact that such util functions are wrapped in a dedicated Lambda layer actually seems clearer to me in terms of separating out the utils.
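As a quick sanity check, here is a sketch of what unfold produces (the image shape and window parameters are just for illustration, not the lab's actual values):

import torch

images = torch.randn(8, 1, 28, 952)                        # (b, c, h, w), shapes for illustration
window_width, window_stride = 28, 14
patches = images.unfold(3, window_width, window_stride)    # (b, c, h, p, window_width)
print(patches.shape)                                       # torch.Size([8, 1, 28, 67, 28])
patches = patches.permute(0, 3, 1, 2, 4)                   # (b, p, c, h, window_width): one patch per window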

ctc_decode()
The raw outputs of the NN contain repeated tokens, e.g. [0, 1, 1, empty, 1, 2, 2, 3, 3], so we need to decode them back into a normal sequence, e.g. [0, 1, empty, 1, 2, 3] after collapsing repeats (the blanks are then dropped as well). Keras has the prebuilt function tf.sparse_to_dense, but it also comes with extra requirements like padding. By contrast, you need to reimplement this function for PT (it is very straightforward, though).
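A minimal greedy decoder sketch (the signature is my own; the lab's version may differ): collapse consecutive repeats, then drop the blanks.

def ctc_decode(pred_indices, blank_id):
    # pred_indices: 1-D sequence of token ids for one sample (argmax per window)
    decoded = []
    prev = None
    for idx in pred_indices:
        idx = int(idx)                        # unwrap tensors to plain ints early (see Annoyances below)
        if idx != prev and idx != blank_id:   # skip repeats and blanks
            decoded.append(idx)
        prev = idx
    return decoded

# [0, 1, 1, blank, 1, 2, 2, 3, 3] -> collapse -> [0, 1, blank, 1, 2, 3] -> drop blanks -> [0, 1, 1, 2, 3]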

Roadblocks:
After the implementation, training always converged to predicting nothing but blank_id, even when trying to overfit a sample of 100. Since I was able to get the original (Keras) pipeline running, I assumed I had a bug, but in the end I never found one. I did come across a reported issue with PT's implementation of CTCLoss (ref), and seemingly there are some implicit bugs (1 and 2).

I decided to switch away from PT's CTCLoss and unfortunately entered dependency hell: the commonly used warp-ctc requires PT 0.4, and another implementation, pytorch-baidu-ctc, has issues with PT 1.1 (while claiming to be tested against it) (ref) - exactly the version I was using.

After a lot of hacks around local compilation and moving C bindings, I had to give up and revert to PT 1.0, which for me also required a GPU driver update due to a different CUDA dependency. All in all, it was painful. With baidu-ctc-loss I could get training running and the loss decreasing; this was the first time I could confirm my re-implementation was correct. After that I switched back to PT's CTCLoss and let it run overnight (typically this training takes ~30min to converge to something good enough). It turned out to train properly as well, just much more slowly.

I re-examined my PT implementation against the original in every detail and realized I was using Adam as the optimizer instead of the RMSprop used originally. After switching to RMSprop, training takes 30min again! All in all, with everything else untouched, both baidu-CTC vs PT's native CTCLoss and Adam vs RMSprop had significant impacts on training, and somehow PT's CTCLoss has bad chemistry with Adam. This is really some dark magic of deep learning.
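For reference, the swap itself is a one-line change (the learning rate here is illustrative, not necessarily the lab's setting):

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)      # stalled at all-blank predictions with PT's CTCLoss
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)      # converges in ~30min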

Annoyances:
1. PT's implicit integer coercion when dividing a uint8 tensor:

if x.dtype == torch.uint8:
    # NOTE: cast to float before dividing, otherwise the result gets coerced to 0
    x = x.to(torch.float32) / 255

2. Comparisons between different data types:
When leaving the tensor realm for user-facing interactions (e.g. decoding a prediction into a character sequence), make sure to unwrap your tensors, e.g. with tensor.item(). When implementing ctc_decode(), at one point I kept getting empty strings during decoding. It turned out that my vocabulary.get(char_idx) never got a hit and thus never produced a non-empty string. The reason: in my buggy implementation char_idx - supposed to be an int - was actually a tensor (e.g. tensor(1)), and the bug stayed hidden because tensor(1) == 1 is True, so it still satisfied some other condition. It's trivial in retrospect, but it was extremely confusing, because ctc_decode() only gets called during training and evaluation (i.e. it is not used standalone), so your first reaction to such a bug is that you've messed up some training logic.
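The pitfall is easy to reproduce with a toy vocabulary (names here are mine):

import torch

vocabulary = {0: '<blank>', 1: 'a', 2: 'b'}
char_idx = torch.tensor(1)

print(bool(char_idx == 1))              # True -- equality checks happily pass...
print(vocabulary.get(char_idx))         # None -- ...but dict lookup misses, since tensors hash by identity
print(vocabulary.get(char_idx.item()))  # 'a' -- unwrap to a plain int first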

Data-wise

Since we have multiple datasets, consolidating their API and expected output is very important. For example, the data labels are one-hot encoded in the original setup to conform to Keras' loss function, while in PT we need integer/scalar labels. So, as a quick hack, I transform the OHE back to scalars inside DataSequence, keeping it backward compatible with upstream util functions that might expect OHE-shaped data.
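A sketch of that quick hack (the class name follows the text; the exact fields are my assumption): argmax the one-hot labels back to integer ids right before an item leaves the dataset.

import numpy as np

class DataSequence:
    def __init__(self, x, y_ohe):
        self.x = x            # images
        self.y = y_ohe        # labels, one-hot: (N, seq_len, num_classes)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        y = np.argmax(self.y[idx], axis=-1)   # OHE -> integer class ids, as NLLLoss/CTCLoss expect
        return x, y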

Ops-wise

The only task is to implement prepare_experiment.py, which creates sets of config files to run experiments. Easy.
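Something along these lines (the grid keys and output paths are my invention, not the lab's actual layout): expand a hyperparameter grid into one JSON config per run.

import itertools, json, pathlib

grid = {
    "network": ["line_cnn", "line_lstm_ctc"],
    "optimizer": ["RMSprop"],
    "lr": [1e-3, 3e-4],
}

out_dir = pathlib.Path("experiments/configs")
out_dir.mkdir(parents=True, exist_ok=True)

for i, values in enumerate(itertools.product(*grid.values())):
    config = dict(zip(grid.keys(), values))
    (out_dir / f"experiment_{i:02d}.json").write_text(json.dumps(config, indent=2))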

Future Direction

The struggle with CTCLoss is not settled yet: I made the mysterious observation that the loss of a perfect prediction can be higher than that of an all-blank_id prediction. I've filed a question on PT's forum, which remains unanswered, so it's probably a good research topic.
