1. Overview

Reader-Translator-Generator (RTG) is a Neural Machine Translation toolkit based on PyTorch.

1.1. Features

  • Reproducible experiments: one conf.yml that has everything — data paths, params, and hyper params — required to reproduce experiments.

  • Pre-processing options: sentencepiece or nlcodec (or add your own)

    • vocabulary types: word, char, BPE, etc.

    • shared vocabulary, separate vocabulary

      • one-way, two-way, three-way tied embeddings

  • Transformer model from "Attention Is All You Need"

    • Automatically detects and parallelizes across multiple GPUs

    • Many transformer variants (width-varying, skip transformer, etc.), configurable from YAML files

  • RNN-based encoder-decoder with attention (we no longer use it, but it is available for experimentation)

  • Language Modeling: RNN, Transformer

  • And more …

    • Easy and interpretable code (for those who read code as much as papers)

    • Object-oriented design (not too many levels of functions and function factories, unlike Tensor2Tensor)

    • Experiments and reproducibility are the main focus. To control an experiment, you edit a YAML file that is inside the experiment directory.

    • Wherever possible, we prefer convention over configuration. Have a look at the experiment directory structure (below).

1.2. Google Colab Example

Use this Google Colab notebook for learning how to train your NMT model with RTG: colab.research.google.com/drive/198KbkUcCGXJXnWiM7IyEiO1Mq2hdVq8T?usp=sharing

1.3. Setup


pip install rtg

1.3.1. Development Setup:

This mode of setup is required only if you are developing (i.e. modifying RTG code). If you are planning to use RTG without modifying source code, then pip install rtg should be all you need.

Add the root of this repo to PYTHONPATH, or install it in editable mode via pip (pip install --editable .).

There are two versions of the code: rtg and rtg-in.

Both rtg and rtg-in have the same code on their master branches. rtg has a stable code base and is meant to be used by anyone, so it is recommended for new users. rtg-in is internal to ISI NLP, with some unfinished/work-in-progress ideas (maybe unpublished) and issues and pull requests by members of the USC ISI team, and it is often less stable. We sync both code bases often (sync-xt.sh at the root of the repo). If you would like to collaborate with us and/or get access to rtg-in, email TG or Jonathan May.

git clone https://github.com/isi-nlp/rtg.git
cd rtg                # go to the code

conda create -n rtg python=3.7   # creates a conda env named rtg
conda activate rtg       # activate it


pip install --editable .
# The requirements are in setup.py; you may customize it if you wish

export PYTHONPATH=$PWD # or add it to PYTHONPATH

1.3.2. Requirements

The required libraries are automatically installed by pip, so manual installation is not required. We are listing the requirements here for informative purposes only.
To view or modify the version numbers of libraries, please go to setup.py at the root of this project.

The following libraries are used:

Table 1. Summary of libraries used

Library         Purpose
torch           deep learning library
tensorboard     logging and visualizing training and validation losses
sacrebleu       BLEU scorer
sacremoses      tokenization and detokenization
tqdm            progress bar
ruamel.yaml     configuration management
sentencepiece   (optional) vocabulary creation using word, char, BPE
nlcodec         (optional) similar to sentencepiece, but easily customizable; scales to big datasets using pyspark, offers efficient storage of encoded parallel data
flask, jinja    (optional) HTTP API and web interface for serving the models
pyspark         (optional) parallelized data preparation (using nlcodec) for massive datasets

Thanks to all the awesome developers of the tools above.

1.4. Usage

Refer to scripts/rtg-pipeline.sh bash script and examples/transformer.base.yml file for specific examples.

The pipeline takes source (.src) and target (.tgt) files. The sources are in one language and the targets in another. At a minimum, supply a training source, training target, validation source, and validation target. It is best to use .tok files for training. (.tok means tokenized.)

Example of training and running a model:

# if you wish to disable gpu, unset
# export CUDA_VISIBLE_DEVICES=

python -m rtg.pipeline experiments/sample-exp/

# or use CLI tool installed by pip install
rtg-pipe experiments/sample-exp/

# or use shell script, edit it to your needs, to submit to Slurm/SGE
scripts/rtg-pipeline.sh -d experiments/sample-exp/ -c experiments/sample-exp/conf.yml

# Then to use the model to translate something:
# (VERY poor translation due to small training data)
echo "Chacun voit midi à sa porte." | python -m rtg.decode experiments/sample-exp/

An experiment directory (here named 001-tfm) looks like this:

001-tfm
├── _PREPARED    <-- Flag file indicating experiment is prepared
├── _TRAINED     <-- Flag file indicating experiment is trained
├── conf.yml     <-- Where all the params and hyper params are! You should look into this
├── data
│   ├── samples.tsv.gz          <-- samples to log after each check point during training
│   ├── sentpiece.shared.model  <-- as the name says, sentence piece model, shared
│   ├── sentpiece.shared.vocab  <-- as the name says
│   ├── train.db                <-- all the prepared training data in a sqlite db
│   └── valid.tsv.gz            <-- and the validation data
├── githead       <-- the git HEAD hash at which this experiment was started
├── job.sh.bak    <-- job script used to submit this to grid. Just in case
├── models        <-- All checkpoints go inside this
│   ├── model_400_5.265583_4.977106.pkl
│   ├── model_800_4.478784_4.606745.pkl
│   ├── ...
│   └── scores.tsv <-- train and validation losses, in case you don't want to use tensorboard
├── rtg.log   <-- the python logs are redirected here
├── rtg.zip   <-- the source code used to run. just `export PYTHONPATH=rtg.zip` to
├── scripts -> /Users/tg/work/me/rtg/scripts  <-- link to some perl scripts for detok+BLEU
├── tensorboard    <-- Tensorboard stuff for visualizations
│   ├── events.out.tfevents.1552850552.hackb0x2
│   └── ....
└── test_step2000_beam4_ens5   <-- Tests after the end of training, BLEU scores
    ├── valid.ref -> /Users/tg/work/me/rtg/data/valid.ref
    ├── valid.src -> /Users/tg/work/me/rtg/data/valid.src
    ├── valid.out.tsv
    ├── valid.out.tsv.detok.tc.bleu
    └── valid.out.tsv.detok.lc.bleu

1.5. Credits / Thanks

2. RTG conf.yml File

The key component of the RTG toolkit is conf.yml. As the name suggests, it is a YAML file containing the configuration of an experiment. Before we try to understand what goes into a configuration file, let us review the high-level entities:

  • Experiment - the top level entity that wraps everything below, for the sake of reproducibility.

  • Data Preparation - NLP datasets require preparation of textual data, typically creating a vocabulary to map text into sequences of integers. Here we can specify the type of encoding scheme, such as BPE/char/words, and the vocabulary size.

  • Model - the neural net for NMT or LM tasks

  • Optimizer - Optimizer and optimization criteria

  • Trainer - training steps, batch size etc

  • Tester [Optional] - testing to do after training

    • Tuner [Optional] - to search for beam size, length penalty etc

    • Decoder - the beam decoder parameters; may be overwritten by Tuner

    • Suite - a set of source and reference file pairs, for computing BLEU scores

2.1. Config Example:

conf.yml
model_args: # model construction args
  ff_size: 2048
  hid_size: 512
  n_heads: 8
  attn_dropout: 0.1  # Use lower dropout rates for attention because it masks an entire timestep
  dropout: 0.2
  enc_layers: 6
  dec_layers: 6
  src_vocab: 8000
  tgt_vocab: 8000
  tied_emb: three-way  # choices: null, one-way, two-way, three-way
  # self_attn_rel_pos: 8  # enable relative pos self attention with window=8
model_type: tfmnmt  # model type. tfmnmt is the transformer NMT model
optimizer:
  name: adam
  args:
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-09
    lr: 0.1

schedule:
  name: noam
  args:
    constant: 2
    warmup: 8000
    model_dim: 512

criterion:
  name: smooth_kld    #options "cross_entropy", "smooth_kld", "binary_cross_entropy", "triplet_loss"
  args:
    label_smoothing: 0.1

prep: # data preparation
  max_types: 8000  # maximum number of types in vocab ; if shared_vocab=false, set max_src_types and max_tgt_types separately
  pieces: bpe   # choices: bpe, char, word, unigram  from google/sentencepiece
  shared_vocab: true  # true means same vocab for src and tgt, false means different vocabs
  src_len: 256   # longer sentences, decision is made as per 'truncate={true,false}'
  tgt_len: 256
  truncate: true  # what to do with long sentences: if true truncate at src_len or tgt_len; if false filter away
  train_src: wmt_data/data/de-en/europarl-v9.de-en.de.tok   # training data
  train_tgt: wmt_data/data/de-en/europarl-v9.de-en.en.tok
  valid_src: wmt_data/data/dev/newstest2013.de.tok
  valid_tgt: wmt_data/data/dev/newstest2013.en.tok
  valid_tgt_raw: wmt_data/data/dev/newstest2013.en  # unmodified; required for BLEU
tester:
  decoder:
   beam_size: 4
   batch_size: 18000   # effective size = batch_size/beam_size
  suit:  # suite of tests to run after training
    newstest2013:  # name of test and list of src.tok, ref file (ref should be unmodified)
      - wmt_data/data/dev/newstest2013.de.tok
      - wmt_data/data/dev/newstest2013.en
    newstest2014:  # name of test and list of src.tok, ref file (ref should be unmodified)
      - wmt_data/data/dev/newstest2014-deen-src.de.tok
      - wmt_data/data/dev/newstest2014-deen-ref.en
trainer:
  init_args:
    chunk_size: 10   # generation in chunks of time steps to reduce memory consumption
    grad_accum: 1     # How many batches to accumulate gradients
  batch_size: 4200   # not exceeding these many tokens (including paddings)
  check_point: 1000  # how often to checkpoint?
  keep_models: 10   # how many checkpoints to keep on disk (small enough to save disk, large enough for checkpoint averaging)
  steps: 200000      # how many steps to train; if early_stop is enabled, this is max steps
  keep_in_mem: true   # keep training data in memory
updated_at: '2019-03-09T21:15:33.707183'  # automatically updated by system
seed: 12345  # fix the manual seed of pytorch + cuda + numpy + python_stdlib RNGs. Remove/comment this to disable

2.2. Config options

Table 2. Summary of component choices

Component   Choices
model       tfmnmt, rnnmt, rnnlm, tfmlm, skptfmnmt, wvtfmnmt, wvskptfmnmt, tfmextembmt, robertamt, mtfmnmt, hybridmt, CBOW, tfmcls
optimizer   adam, sgd, adagrad, adam_w, adadelta, sparse_adam
schedule    noam, inverse_sqrt
criterion   sparse_cross_entropy, kl_divergence, focal_loss, binary_cross_entropy, smooth_kld, triplet_loss, smooth_kld_and_triplet_loss, dice_loss, squared_error

2.2.1. schedule options

  1. noam with args:

    • warmup

    • constant

    • model_dim

  2. inverse_sqrt with args:

    • warmup

    • peak_lr
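For intuition, the two schedules can be sketched in Python. This is a minimal sketch assuming the standard formulas: noam follows "Attention Is All You Need" scaled by constant, and inverse_sqrt uses a linear warmup to peak_lr followed by 1/sqrt(step) decay. RTG's own implementation may differ in details such as the starting step, and the function names are hypothetical.

```python
import math


def noam_lr(step: int, warmup: int, constant: float, model_dim: int) -> float:
    """Noam schedule ("Attention Is All You Need"), scaled by `constant`."""
    step = max(step, 1)  # guard against 0 ** -0.5
    return constant * model_dim ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


def inverse_sqrt_lr(step: int, warmup: int, peak_lr: float) -> float:
    """Linear warmup to `peak_lr`, then decay proportional to 1 / sqrt(step)."""
    step = max(step, 1)
    if step < warmup:
        return peak_lr * step / warmup
    return peak_lr * math.sqrt(warmup / step)


# Print a few points of each schedule, using values from the examples in this document.
for s in (1000, 4000, 8000, 16000, 64000):
    print(s, f"{noam_lr(s, 8000, 2, 512):.6f}", f"{inverse_sqrt_lr(s, 4000, 0.0005):.6f}")
```

Both schedules peak at step == warmup. With warmup: 8000, constant: 2 and model_dim: 512 (as in the conf.yml example), noam peaks at 2/sqrt(512 * 8000), roughly 0.00099.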

2.2.2. criterion options

  • smooth_kld (recommended; used since the first version of transformer)

Table 3. Args to smooth_kld

Name              Type    Range/Choices   Required   Default
label_smoothing   float   [0.0, 1.0)      Optional   0.1

  • sparse_cross_entropy

Table 4. Args to sparse_cross_entropy

Name               Type   Range/Choices                             Required   Default                    Comment
weight             str    {inv_freq, inv_sqrt_freq, inv_log_freq}   Optional   None ⇒ disable weighting
weight_calm_time   int    [0, )                                     Optional   0 ⇒ disable calming        Applicable when weight is enabled

  • kl_divergence (re-implementation of smooth_kld with some extra features)

Table 5. Args to kl_divergence

Name               Type    Range/Choices                             Required   Default
label_smoothing    float   [0.0, 1.0)                                Optional   0.0 ⇒ disable label smoothing
weight             str     {inv_freq, inv_sqrt_freq, inv_log_freq}   Optional   None ⇒ disable weighting
weight_calm_time   int     [0, )                                     Optional   0 ⇒ disable calming ⇒ weights applicable from step 0

  • focal_loss

Args to focal_loss

Name               Type    Range/Choices   Required   Default
gamma              float   [0.0, )         Optional   0.0 ⇒ disable ⇒ cross entropy
weight_calm_time   int     [0, )           Optional   0 ⇒ disable calming ⇒ weights applicable from step 0

  • Experimental loss functions:

    • dice_loss

    • binary_cross_entropy

    • triplet_loss

    • squared_error
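To make label smoothing and focal loss concrete, here is a pure-Python sketch for a single target token. It assumes the common formulation in which the label_smoothing mass is spread uniformly over all non-gold classes; RTG's smooth_kld may handle padding and normalization differently, and these helper names are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def smoothed_target(gold: int, vocab_size: int, label_smoothing: float) -> List[float]:
    """Put 1 - eps on the gold class and spread eps uniformly over the other classes."""
    off = label_smoothing / (vocab_size - 1)
    dist = [off] * vocab_size
    dist[gold] = 1.0 - label_smoothing
    return dist


def kl_divergence(target: List[float], log_probs: List[float]) -> float:
    """KL(target || model) = sum_i t_i * (log t_i - log p_i); terms with t_i = 0 contribute 0."""
    return sum(t * (math.log(t) - lp) for t, lp in zip(target, log_probs) if t > 0)


def focal_loss(gold: int, log_probs: List[float], gamma: float) -> float:
    """-(1 - p_gold) ** gamma * log p_gold; gamma = 0 reduces to cross entropy."""
    p = math.exp(log_probs[gold])
    return -((1.0 - p) ** gamma) * log_probs[gold]
```

With label_smoothing = 0 the KL term equals plain cross entropy, and a larger gamma down-weights tokens the model already predicts confidently.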

2.3. Early stop

Add the following block to the trainer config to enable early stopping on convergence.

trainer:
  ....           # other args
  steps: 100000      # steps is treated as max steps
  check_point: 1000  # validate every this many steps
  early_stop:       # remove this block to disable
    enabled: true   # or, alternatively flip this to disable;
    by: loss        # stop by validation loss (default); TODO: add BLEU
    patience: 5     # how many validations to wait, to be sure of stopping; each validation is per check_point steps
    min_steps: 8000  # minimum steps to wait before test for early stop;
    signi_round: 3   # decimal places of the 'by' value to compare, used as round(value, signi_round).
                     # e.g. round(1/3, 3) = 0.333; round(100/3, 0) = 33.0; round(100/3, -1) = 30.0
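The interplay of patience, min_steps and signi_round can be sketched as follows. should_stop is a hypothetical helper, not RTG's code, and it assumes that "no improvement" means the rounded validation loss has not dropped below its best earlier value for patience consecutive validations.

```python
from typing import List


def should_stop(val_losses: List[float], steps: List[int], patience: int,
                min_steps: int, signi_round: int) -> bool:
    """Hypothetical early-stop check over the history of validations so far.

    val_losses[i] is the validation loss measured at training step steps[i].
    """
    if not steps or steps[-1] < min_steps:
        return False                              # too early to decide
    rounded = [round(v, signi_round) for v in val_losses]
    # Index of the first occurrence of the best (lowest) rounded loss.
    best_idx = min(range(len(rounded)), key=lambda i: rounded[i])
    # Number of validations since the best one; stop once it reaches `patience`.
    return len(rounded) - 1 - best_idx >= patience
```

Rounding is what makes tiny fluctuations such as 4.2001 vs 4.1999 count as "no improvement" with signi_round: 3, while signi_round: 4 would treat 4.1999 as a new best.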

2.4. Optimizer

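By default, we use the Adam optimizer from "Adam: A Method for Stochastic Optimization". It is also possible to use AdamW from "Decoupled Weight Decay Regularization", which applies weight decay directly to the weights instead of adding an L2 term to the gradient. This matters because, in optimizers with adaptive per-parameter step sizes such as Adam, L2 regularization and weight decay are not equivalent.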

An alternative optimizer may look like:

optimizer:
  name: adamw
  args:
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-09
    lr: 0.0005   # this doesn't matter, see "schedule"
    weight_decay: 1e-3

schedule:
  name: inverse_sqrt
  args:
    warmup: 4000
    peak_lr: 0.0005
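The difference between Adam with an L2 penalty and AdamW shows up even on a single scalar weight. This sketch follows the update rules of the two papers named above, applying AdamW's decay before the Adam step as torch.optim.AdamW does. It is illustrative only; adam_step is a hypothetical name, and this is not how RTG builds its optimizers.

```python
import math


def adam_step(w, grad, m, v, t, lr, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=0.0, decoupled=False):
    """One Adam update of a scalar weight at step t (t >= 1); returns (w, m, v).

    decoupled=False: Adam with an L2 penalty folded into the gradient.
    decoupled=True:  AdamW, which shrinks the weight directly.
    """
    b1, b2 = betas
    if weight_decay and decoupled:
        w = w * (1 - lr * weight_decay)       # AdamW: decay bypasses the adaptive scaling
    elif weight_decay:
        grad = grad + weight_decay * w        # L2: decay goes through the adaptive scaling
    m = b1 * m + (1 - b1) * grad              # first-moment estimate
    v = b2 * v + (1 - b2) * grad * grad       # second-moment estimate
    m_hat = m / (1 - b1 ** t)                 # bias correction
    v_hat = v / (1 - b2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v
```

With a zero loss gradient, the L2 variant still moves the weight by about lr, because the penalty gradient is rescaled by Adam's adaptive denominator, whereas AdamW shrinks it by exactly a factor of (1 - lr * weight_decay).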

2.5. Fine Tuning

We define fine tuning as the act of changing the training data at a certain time step in the training process. To enable this feature, do the following.

Step 1. Specify finetune_src and finetune_tgt in the prep block as follows:

prep: # data preparation
  ....
  train_src: wmt_data/data/de-en/europarl-v9.de-en.de.tok   # training data
  train_tgt: wmt_data/data/de-en/europarl-v9.de-en.en.tok
  finetune_src: wmt_data/data/de-en/finetune.de-en.de.tok   # Finetuning data
  finetune_tgt: wmt_data/data/de-en/finetune.de-en.en.tok
  valid_src: wmt_data/data/dev/newstest2013.de.tok
  valid_tgt: wmt_data/data/dev/newstest2013.en.tok

Step 2. Tell the trainer to continue training by adding finetune_steps to the trainer block:

trainer:
  batch_size: 12000        # training batch size
  steps: 200000           # how many steps to train
  finetune_steps: 300000 # fine tuning steps.
  finetune_batch_size: 1024  # fine tuning batch_size; optional; default is training batch_size

This makes the trainer use train_{src,tgt} for steps 0-200k, followed by finetune_{src,tgt} for steps 200k-300k. Note that finetune_steps must be greater than steps, since it marks the step at which fine tuning ends.
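The step-based switch between the two corpora can be expressed as a small function. corpus_for_step is a hypothetical helper that mirrors the description above, assuming step counting starts at 0; it is not part of RTG's API.

```python
from typing import Optional, Tuple


def corpus_for_step(step: int, steps: int, batch_size: int,
                    finetune_steps: Optional[int] = None,
                    finetune_batch_size: Optional[int] = None) -> Optional[Tuple[str, int]]:
    """Hypothetical helper: (corpus, batch size) used at `step`, or None when training is done."""
    if finetune_steps is not None and finetune_steps <= steps:
        raise ValueError("finetune_steps must be greater than steps")
    if step < steps:
        return "train", batch_size                              # train_{src,tgt}
    if finetune_steps is not None and step < finetune_steps:
        return "finetune", finetune_batch_size or batch_size    # finetune_{src,tgt}
    return None
```

With the trainer block above (steps: 200000, finetune_steps: 300000), corpus_for_step(250000, 200000, 12000, 300000, 1024) returns ("finetune", 1024).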

2.6. Parent-Child Transfer

To initialize from another compatible model as parent, add parent: specification to conf.yml as shown below:

model_type: tfmnmt
model_args:
  # will be inherited from parent; see parent.model.args: true
parent:
  experiment: <path/to/experiment/dir>
  vocab:
    shared: shared       # for reusing the shared vocab
    #src: src            # for separate vocabs
    #tgt: tgt
  shrink: true        # shrink vocabularies and embeddings to child data
                      # specified in train_{src,tgt} and mono_{src,tgt}
  model:
    args: true          # update/overwrite the model_args of child with the parent
    ensemble: 5         # how many checkpoints of parent to ensemble, to obtain initial state
# ... rest of the config such as prep, trainer etc
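Ensembling checkpoints (ensemble: 5 above, and rtg-decode-pro, whose help text describes ensembling as averaging the param weights) amounts to an element-wise average of parameters. A minimal sketch, assuming every checkpoint has the same parameter names and shapes; plain lists stand in for the tensors of a PyTorch state_dict, and average_checkpoints is a hypothetical name.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # parameter name -> flattened values


def average_checkpoints(checkpoints: List[Params]) -> Params:
    """Element-wise mean of parameters that share names and shapes across checkpoints."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    avg: Params = {}
    for name in checkpoints[0]:
        vectors = [ckpt[name] for ckpt in checkpoints]
        # zip(*vectors) walks position by position across all checkpoints
        avg[name] = [sum(vals) / len(vectors) for vals in zip(*vectors)]
    return avg
```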

2.7. Freezing some parts of model

Freezing a part of the network means its weights remain unmodified during the course of training. This is useful when the model weights are initialized from a well-trained parent model. Since the optimizer is the component that modifies the model's parameters according to their gradients, freezing weights amounts to excluding them from the optimizer, or, equivalently, explicitly listing the parts of the model that need to be trained (i.e. updated by the optimizer).

Here is an example: comment out or remove the parts that you wish to freeze in the 6-layer network below.

optimizer:
  name: adam
  args:
    ....# the usual args for optimizer
  trainable:  # trainable parameter
    include: # only include these and exclude everything else not listed here
    - src_embed
    - tgt_embed
    - generator
    - 'encoder:0,1,2,3,4,5'  # the numbers are layer indices starting from 0
    - 'decoder:0,1,2,3,4,5'  # the numbers are layer indices starting from 0

TODO: add support for exclude logic i.e., include everything else except the mentioned.

This feature is supported only in AbstractTransformerNMT and all of its children. If you are adding a new NMTModel or customising this feature, please override get_trainable_parameters(self, include, exclude) function to support this feature.
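How an include list could map to parameters can be sketched as a name filter. This assumes PyTorch-style dotted parameter names such as encoder.layers.3.ff.weight; the naming scheme and select_trainable are hypothetical illustrations, not RTG's get_trainable_parameters.

```python
from typing import Iterable, List


def select_trainable(param_names: Iterable[str], include: List[str]) -> List[str]:
    """Keep only parameter names matched by an include spec.

    'src_embed' matches every parameter under that module;
    'encoder:0,2' matches only encoder layers 0 and 2.
    The '<module>.layers.<i>.' naming is an assumption for illustration.
    """
    prefixes: List[str] = []
    for spec in include:
        if ":" in spec:
            module, layer_ids = spec.split(":", 1)
            prefixes += [f"{module}.layers.{i.strip()}." for i in layer_ids.split(",")]
        else:
            prefixes.append(f"{spec}.")
    # The trailing '.' keeps 'encoder.layers.1.' from matching 'encoder.layers.10.'
    return [n for n in param_names if any(n.startswith(p) for p in prefixes)]
```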

2.8. Sharing Data between Experiments

In the new experiment's config, add same_data to the prep block to reference the parent experiment whose data should be reused for training and validation. Note that this also reuses the parent's vocabulary. The child experiment creates a symbolic link to the parent experiment's data (instead of copying it) to reduce disk usage.

Alternatively, you may use rtg-fork --data to fork an experiment with same data, where the forked experiment will have symbolic link to parent’s data.

prep:
  same_data: path/to/prior/experiment_dir

2.9. Vocabulary Preprocessing

Google’s sentencepiece is an awesome library for preprocessing text datasets. We’ve used sentencepiece’s Python API since day one of RTG, and it is the default library. However, since the core of sentencepiece is written in C++, it was hard to modify to explore new ideas on BPE (without knowing C++). So, we reimplemented BPE in pure Python, using data structures such as linked lists, prefix tries and a dirty max-heap to match its speed. Our reimplementation is named NLCodec. NLCodec can be enabled as:

prep:
  ....
  codec_lib: nlcodec  # default is sentpiece
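For readers new to BPE, the core learning loop is short. This naive sketch recounts all pairs after every merge, which is exactly the cost NLCodec avoids with linked lists, prefix tries and a dirty max-heap; it also omits word-boundary markers. It illustrates the technique and is not NLCodec's code; learn_bpe is a hypothetical name.

```python
from collections import Counter
from typing import Dict, List, Tuple


def learn_bpe(word_freqs: Dict[str, int], num_merges: int) -> List[Tuple[str, str]]:
    """Naive BPE: repeatedly merge the most frequent adjacent symbol pair."""
    vocab = {tuple(word): freq for word, freq in word_freqs.items()}  # words as char tuples
    merges: List[Tuple[str, str]] = []
    for _ in range(num_merges):
        pairs: Counter = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        best = max(pairs, key=lambda p: (pairs[p], p))  # deterministic tie-break: larger pair wins
        merges.append(best)
        merged: Dict[tuple, int] = {}
        for symbols, freq in vocab.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            merged[tuple(out)] = merged.get(tuple(out), 0) + freq
        vocab = merged
    return merges
```

For example, on {"low": 5, "lower": 2, "newest": 6, "widest": 3} the first three merges are ("s", "t"), ("e", "st") and ("o", "w").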

2.9.1. Vocabulary Types

Both sentpiece and nlcodec support pieces=bpe, char, and word.

prep:
  ....
  codec_lib: nlcodec  # other option: sentpiece
  pieces: bpe         # other options: char, word

As of now, only sentpiece supports pieces=unigram. For classification experiments, nlcodec supports pieces=class.

2.9.2. Character coverage

For bpe and char vocabulary types, a useful trick is to exclude low-frequency characters and map them to UNK. This is usually expressed as the percentage of character occurrences in the training corpus that the vocabulary covers. Sentencepiece’s default (when we last checked) is 99.95%, i.e. 0.9995. Here is how to set it to, e.g., 99.99% (0.9999) in nlcodec:

prep:
  ....
  codec_lib: nlcodec      # other option: sentpiece
  pieces: bpe             # other options: char, word
  char_coverage: 0.9999
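Character coverage can be sketched as keeping the most frequent characters until they account for the requested fraction of all character occurrences; the remaining characters would map to UNK. This is an assumption about how the threshold is applied (whitespace is ignored here), and covered_chars is a hypothetical helper, not the nlcodec or sentencepiece implementation.

```python
from collections import Counter
from typing import Iterable, Set


def covered_chars(lines: Iterable[str], coverage: float) -> Set[str]:
    """Most frequent characters that together cover `coverage` of all character occurrences."""
    counts = Counter(ch for line in lines for ch in line if not ch.isspace())
    total = sum(counts.values())
    keep: Set[str] = set()
    running = 0
    for ch, n in counts.most_common():
        if running >= coverage * total:
            break                  # enough occurrences covered; the rest become UNK
        keep.add(ch)
        running += n
    return keep
```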

2.10. Sub-Word Regularization

When using codec_lib: nlcodec and pieces: bpe, you have the option to add sub-word regularization to your training. Normally, text is split into the fewest tokens necessary to represent the sequence (greedy split). By occasionally splitting some tokens into their constituents (suboptimal split), we can represent the same sequence in many ways. This allows us to use limited training data more effectively.

trainer:
  ....
  split_ratio: 0.1        # 10% chance to suboptimally split (recursive)
  dynamic_epoch: true     # Recompute splits for each epoch
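The recursive suboptimal split can be sketched as follows. It assumes a hypothetical merges table that maps each merged BPE token to the two pieces it was built from, with ratio playing the role of split_ratio. With dynamic_epoch: true, one would call regularize afresh in each epoch so that the splits change. This is an illustration, not nlcodec's code.

```python
import random
from typing import Dict, List, Tuple

Merges = Dict[str, Tuple[str, str]]  # merged token -> (left, right) constituents


def maybe_split(token: str, merges: Merges, ratio: float, rng: random.Random) -> List[str]:
    """With probability `ratio`, replace a merged token by its two constituents, recursively."""
    if token in merges and rng.random() < ratio:
        left, right = merges[token]
        return maybe_split(left, merges, ratio, rng) + maybe_split(right, merges, ratio, rng)
    return [token]


def regularize(tokens: List[str], merges: Merges, ratio: float,
               rng: random.Random) -> List[str]:
    """Apply maybe_split to every token of a greedily segmented sequence."""
    out: List[str] = []
    for tok in tokens:
        out.extend(maybe_split(tok, merges, ratio, rng))
    return out
```

Whatever the random choices, the pieces always concatenate back to the original text; only the segmentation varies.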

3. Avoiding Out-of-Memory

Out-of-memory is pretty common, and we have worked out ways to avoid that situation as much as possible.

3.1. Trainer Memory

Let’s visualize the total required memory for training a model in the order of a 4D tensor: [ ModelDim x Batch x SequenceLength x Vocabulary]

ModelDim and Vocabulary are fixed once the model and its vocabulary are chosen, so we are left with Batch and SequenceLength as the two dimensions that we can manipulate.

For SequenceLength, set trainer.init_args.chunk_size to a smaller value to break down the whole sequence into smaller chunks. This operation does not affect gradients, but it affects training time: smaller chunk_size ⇒ less memory, but more chunks ⇒ more time. Also note that prep.src_len and prep.tgt_len let you decide the maximum length of source and target sequences. Combined with prep.truncate=True, all longer sequences are truncated; with prep.truncate=False, longer sequences are dropped.
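A rough estimate of why chunk_size helps: the generator produces a [Batch x TimeSteps x Vocabulary] tensor of scores, and chunking materializes only chunk_size time steps at a time. This back-of-the-envelope sketch assumes float32 values and counts only that one tensor (ignoring gradients, softmax buffers and other activations); the helper names are hypothetical.

```python
import math


def generator_memory_mib(sentences: int, seq_len: int, vocab: int,
                         chunk_size: int = 0, bytes_per_value: int = 4) -> float:
    """Size in MiB of one [sentences x time_steps x vocab] tensor.

    With chunk_size > 0, only `chunk_size` time steps are materialized at once.
    """
    steps = min(chunk_size, seq_len) if chunk_size > 0 else seq_len
    return sentences * steps * vocab * bytes_per_value / 2 ** 20


def num_chunks(seq_len: int, chunk_size: int) -> int:
    """How many sequential chunks (and hence extra passes) chunking introduces."""
    return math.ceil(seq_len / chunk_size)
```

For example, 70 sentences of 60 tokens with an 8000-type vocabulary give about 128 MiB for the full tensor, versus about 21 MiB with chunk_size=10, processed in 6 chunks.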

Regarding Batch, there are some things you can do.

  1. If you have GPUs with larger memory, use them. For example, V100 with 32GB is much better than 1080 Ti with 11GB.

  2. If you don't have a larger GPU, but you have many smaller GPUs, use several of them by setting the CUDA_VISIBLE_DEVICES variable to a comma-separated list of GPU IDs. The built-in DataParallel module divides batches across multiple GPUs ⇒ reduces the memory needed on each GPU.

  3. If you don't have multiple GPUs, use trainer.init_args.grad_accum. E.g., if you set grad_accum=2, the effective batch size is 2 * batch_size.
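Why grad_accum=2 behaves like doubling batch_size: averaging the gradients of equal-sized micro-batches gives the same gradient as one large batch with a mean loss. The toy sketch below uses a one-parameter least-squares model. With token-based batches of unequal size, the equivalence is only approximate unless each loss is normalized by the total token count. The names are hypothetical.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def grad_mse(w: float, batch: Batch) -> float:
    """Gradient of mean((w * x - y) ** 2) over the batch with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, micro_batches: List[Batch]) -> float:
    """Average of per-micro-batch gradients, taken before a single optimizer step."""
    total = 0.0
    for mb in micro_batches:
        total += grad_mse(w, mb)    # in PyTorch, each loss.backward() adds into .grad
    return total / len(micro_batches)
```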

In summary, to make the best use of your GPUs, adjust trainer.init_args.chunk_size, trainer.init_args.grad_accum, and trainer.batch_size. I suggest running gpustat -i 0.5, looking at the GPU RAM usage, and deciding whether you need to increase or decrease some parameters.

Regarding CPU RAM, we usually need about as much as the RAM of a single GPU. But if you have plenty of it, please enable trainer.keep_in_mem=True to reduce disk IO. The keep_in_mem parameter tells the trainer to load the training data once and hold it in CPU RAM during the course of training.

3.2. Decoder Memory

Since beam decoder is used, let’s visualize memory as [Batch x Beams x Vocabulary x SequenceLength]

  • tester.decoder.beam_size: number of beams to be used. You may reduce it; e.g. beam_size=4 is often a good value.

  • tester.decoder.batch_size: batch size for 1 beam. Internally, the effective batch size is batch_size/beam_size.

  • tester.decoder.max_len is a relative length. It decides how long the target sequence can grow in relation to source length. For example, max_len=50 ⇒ len(src) + 50

rtg-decode has a --max-src-len argument, which can be used to hard-limit the maximum length of source sentences. --max-src-len can degrade test performance, since it drops words from long inputs. The right thing to do for long sequences is to split long sentences in the input and merge the outputs after decoding.
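A minimal sketch of the split-and-merge approach, assuming a hypothetical translate function that wraps the decoder. It splits naively into fixed-length pieces; splitting at sentence or clause boundaries would usually preserve more context.

```python
from typing import Callable, List


def translate_long(tokens: List[str], max_src_len: int,
                   translate: Callable[[List[str]], List[str]]) -> List[str]:
    """Split a long token sequence into pieces of at most max_src_len tokens,
    translate each piece, and concatenate the outputs in order."""
    if max_src_len <= 0:
        raise ValueError("max_src_len must be positive")
    out: List[str] = []
    for start in range(0, len(tokens), max_src_len):
        out.extend(translate(tokens[start:start + max_src_len]))
    return out
```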

4. Migration

4.1. v0.5.0 or earlier to v0.6.0

The optimizer block got a big update in v0.6.0; as a result, it is not backward compatible.

Old config, prior to v0.6.0:
optim:
  args:
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-09
    label_smoothing: 0.1
    lr: 0.1
    warmup_steps: 4000
    amsgrad: false
    weight_decay: 0
    criterion: smooth_kld
    inv_sqrt: false
    constant: 2
  name: ADAM
New config in v0.6.0:
optimizer:
  name: adam
  args:
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-09
    lr: 0.1

schedule:
  name: noam
  args:
    constant: 2
    warmup: 4000
    model_dim: 512

criterion:
  name: smooth_kld
  args:
    label_smoothing: 0.1
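The mapping between the two layouts can be expressed as a small converter. migrate_optim is a hypothetical helper, not an RTG tool. It covers only the noam case (inv_sqrt: false) shown above, and it takes model_dim as an argument because the old block does not contain it (presumably it matches model_args.hid_size).

```python
from typing import Any, Dict


def migrate_optim(old_conf: Dict[str, Any], model_dim: int) -> Dict[str, Any]:
    """Hypothetical converter: pre-v0.6.0 `optim` block -> v0.6.0 blocks (noam case only)."""
    old = old_conf["optim"]
    args = dict(old["args"])                        # copy; do not mutate the input
    if args.pop("inv_sqrt", False):
        raise ValueError("inv_sqrt: true is not covered by this sketch")
    criterion = {"name": args.pop("criterion"),
                 "args": {"label_smoothing": args.pop("label_smoothing")}}
    schedule = {"name": "noam",
                "args": {"constant": args.pop("constant"),
                         "warmup": args.pop("warmup_steps"),
                         "model_dim": model_dim}}   # assumed to equal model_args.hid_size
    # Drop keys that only restate defaults, as the v0.6.0 example does.
    for key, default in (("amsgrad", False), ("weight_decay", 0)):
        if key in args and args[key] == default:
            del args[key]
    optimizer = {"name": old["name"].lower(), "args": args}
    return {"optimizer": optimizer, "schedule": schedule, "criterion": criterion}
```

Loading the old conf.yml (e.g. with ruamel.yaml), passing it through migrate_optim, and dumping the result should give blocks equivalent to the v0.6.0 example above.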

5. RTG CLI

The CLI tools below give you finer control, so you can go step by step if you want to test only a part of the pipeline. For end usage of the RTG toolkit, the workflow should be as simple as:

  1. Edit the conf.yml file

  2. Run the pipeline using python -m rtg.pipeline or rtg-pipe command

  3. Occasionally, to decode new test files that were not listed in conf.yml, use python -m rtg.decode or rtg-decode

5.1. Summary:

The following command line tools are added when rtg is installed using pip.

Table 6. Summary of CLI tools

Command          Purpose
rtg-pipe         Run rtg-prep, rtg-train and test case evaluation
rtg-decode       Decode new source files using the values set in conf.yml
rtg-export       Export an experiment
rtg-fork         Fork an experiment with/without the same conf, code, data, vocabularies, etc.
rtg-serve        Serve an RTG model over an HTTP API using a Flask server
rtg-decode-pro   Decode new source files using the values that you supply as CLI args
rtg-prep         Prepare an experiment. You should be using rtg-pipe
rtg-train        Train a model. You should be using rtg-pipe
rtg-syscomb      System combination. Don't bother about it for now.
rtg-launch       Launch data-distributed training
rtg-params       Show parameters in model

5.2. rtg-pipe: Pipeline

This is the CLI you will most likely use.

$ python -m rtg.pipeline -h

usage: rtg.prep [-h] [-G] exp [conf]

prepare NMT experiment

positional arguments:
  exp             Working directory of experiment
  conf            Config File. By default <work_dir>/conf.yml is used

optional arguments:
  -h, --help      show this help message and exit
  -G, --gpu-only  Crash if no GPU is available

5.3. rtg-prep: Prepare an experiment

    $ python -m rtg.prep -h
    usage: rtg.prep [-h] work_dir [conf_file]

    prepare NMT experiment

    positional arguments:
      work_dir    Working directory
      conf_file   Config File. By default <work_dir>/conf.yml is used

    optional arguments:
      -h, --help  show this help message and exit

5.4. rtg-train : Train a Model

    $ python -m rtg.train -h
    usage: rtg.train [-h] [-rs SEED] [-st STEPS] [-cp CHECK_POINT]
                     [-km KEEP_MODELS] [-bs BATCH_SIZE] [-op {ADAM,SGD}]
                     [-oa OPTIM_ARGS] [-ft]
                     work_dir

    Train NMT model

    positional arguments:
      work_dir              Working directory

    optional arguments:
      -h, --help            show this help message and exit
      -rs SEED, --seed SEED
                            Seed for random number generator. Set it to zero to
                            not touch this part. (default: 0)
      -st STEPS, --steps STEPS
                            Total steps (default: 128000)
      -cp CHECK_POINT, --check-point CHECK_POINT
                            Store model after every --check-point steps (default:
                            1000)
      -km KEEP_MODELS, --keep-models KEEP_MODELS
                            Number of checkpoints to keep. (default: 10)
      -bs BATCH_SIZE, --batch-size BATCH_SIZE
                            Mini batch size of training and validation (default:
                            256)
      -op {ADAM,SGD}, --optim {ADAM,SGD}
                            Name of optimizer (default: ADAM)
      -oa OPTIM_ARGS, --optim-args OPTIM_ARGS
                            Comma separated key1=val1,key2=val2 args to optimizer.
                            Example: lr=0.01,warmup_steps=1000 The arguments
                            depends on the choice of --optim (default: lr=0.001)
      -ft, --fine-tune      Use fine tune corpus instead of train corpus.
                            (default: False)

5.5. rtg-decode: Decoder

usage: rtg.decode [-h] [-if [INPUT [INPUT ...]]] [-of [OUTPUT [OUTPUT ...]]]
                  [-sc] [-b BATCH_SIZE] [-msl MAX_SRC_LEN] [-nb]
                  exp_dir

Decode using NMT model

positional arguments:
  exp_dir               Experiment directory

optional arguments:
  -h, --help            show this help message and exit
  -if [INPUT [INPUT ...]], --input [INPUT [INPUT ...]]
                        Input file path. default is STDIN (default:
                        [<_io.TextIOWrapper name='<stdin>' encoding='utf-8'>])
  -of [OUTPUT [OUTPUT ...]], --output [OUTPUT [OUTPUT ...]]
                        Output File path. default is STDOUT (default:
                        [<_io.TextIOWrapper name='<stdout>'
                        encoding='utf-8'>])
  -sc, --skip-check     Skip Checking whether the experiment dir is prepared
                        and trained (default: False)
  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        batch size for 1 beam. effective_batch =
                        batch_size/beam_size (default: None)
  -msl MAX_SRC_LEN, --max-src-len MAX_SRC_LEN
                        max source len; longer seqs will be truncated
                        (default: None)
  -nb, --no-buffer      Processes one line per batch followed by flush output
                        (default: False)

5.6. rtg-decode-pro: Pro Decoder

Note: for simple use with defaults from conf.yml, use rtg-decode or python -m rtg.decode.

    $ python -m rtg.decode_pro -h
    usage: rtg.decode [-h] [-if INPUT] [-of OUTPUT] [-bs BEAM_SIZE] [-ml MAX_LEN]
                      [-nh NUM_HYP] [--prepared]
                      [-bp {E1D1,E2D2,E1D2E2D1,E2D2E1D2,E1D2,E2D1}] [-it] [-sc]
                      [-en ENSEMBLE] [-cb SYS_COMB]
                      work_dir [model_path [model_path ...]]

    Decode using NMT model

    positional arguments:
      work_dir              Working directory
      model_path            Path to model's checkpoint. If not specified, a best
                            model (based on the score on validation set) from the
                            experiment directory will be used. If multiple paths
                            are specified, then an ensembling is performed by
                            averaging the param weights (default: None)

    optional arguments:
      -h, --help            show this help message and exit
      -if INPUT, --input INPUT
                            Input file path. default is STDIN (default:
                            <_io.TextIOWrapper name='<stdin>' mode='r'
                            encoding='UTF-8'>)
      -of OUTPUT, --output OUTPUT
                            Output File path. default is STDOUT (default:
                            <_io.TextIOWrapper name='<stdout>' mode='w'
                            encoding='UTF-8'>)
      -bs BEAM_SIZE, --beam-size BEAM_SIZE
                            Beam size. beam_size=1 is greedy, In theory: higher
                            beam is better approximation but expensive. But in
                            practice, higher beam doesnt always increase.
                            (default: 5)
      -ml MAX_LEN, --max-len MAX_LEN
                            Maximum output sequence length (default: 100)
      -nh NUM_HYP, --num-hyp NUM_HYP
                            Number of hypothesis to output. This should be smaller
                            than beam_size (default: 1)
      --prepared            Each token is a valid integer which is an index to
                            embedding, so skip indexifying again (default: False)
      -bp {E1D1,E2D2,E1D2E2D1,E2D2E1D2,E1D2,E2D1}, --binmt-path {E1D1,E2D2,E1D2E2D1,E2D2E1D2,E1D2,E2D1}
                            Sub module path inside BiNMT. applicable only when
                            model is BiNMT (default: None)
      -it, --interactive    Open interactive shell with decoder (default: False)
      -sc, --skip-check     Skip Checking whether the experiment dir is prepared
                            and trained (default: False)
      -en ENSEMBLE, --ensemble ENSEMBLE
                            Ensemble best --ensemble models by averaging them
                            (default: 1)
      -cb SYS_COMB, --sys-comb SYS_COMB
                            System combine models at the softmax layer using the
                            weights specified in this file. When this argument is
                            supplied, model_path argument is ignored. (default:
                            None)

5.7. rtg-fork: Fork an experiment

usage: rtg-fork [-h] [--conf | --no-conf] [--data | --no-data]
                [--vocab | --no-vocab] [--code | --no-code]
                EXP_DIR TO_DIR

fork an experiment.

positional arguments:
  EXP_DIR     From experiment. Should be valid experiment dir
  TO_DIR      To experiment. This will be created.

optional arguments:
  -h, --help  show this help message and exit
  --conf      Copy config: from/conf.yml → to/conf.yml (default: True)
  --no-conf   Negation of --conf (default: False)
  --data      Link data dir . This includes vocab. (default: True)
  --no-data   Negation of --data (default: False)
  --vocab     copy vocabularies. dont use it with --data (default: False)
  --no-vocab  Negation of --vocab (default: True)
  --code      copy source code. (default: True)
  --no-code   Negation of --code (default: False)
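To make the default behaviour concrete, here is a minimal Python sketch of what a fork does with the default --conf and --data flags: copy conf.yml and link the data dir. It is not RTG's implementation. The function name fork_experiment and the assumption that prepared data lives in a sub-directory named data are ours, and vocabulary and source-code copying are left out.

```python
import shutil
from pathlib import Path


def fork_experiment(exp_dir, to_dir, conf=True, data=True):
    """Hypothetical sketch of `rtg-fork EXP_DIR TO_DIR` with default flags."""
    src, dst = Path(exp_dir), Path(to_dir)
    if not (src / "conf.yml").exists():
        raise ValueError(f"{src} does not look like an experiment dir")
    dst.mkdir(parents=True, exist_ok=False)  # TO_DIR is created here
    if conf:
        # --conf: from/conf.yml -> to/conf.yml (a copy, so it can be edited freely)
        shutil.copy2(src / "conf.yml", dst / "conf.yml")
    if data and (src / "data").is_dir():
        # --data: link instead of copy, so the prepared data is shared, not duplicated
        (dst / "data").symlink_to((src / "data").resolve(), target_is_directory=True)
    return dst
```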

5.8. rtg-export: Export an experiment

Export an experiment:

    python -m rtg.export -h
    usage: export.py [-h] [-en ENSEMBLE] [-nm NAME] [--config | --no-config]
                     [--vocab | --no-vocab]
                     source target

    positional arguments:
      source                Path to experiment (source)
      target                Path to destination where the export should be

    optional arguments:
      -h, --help            show this help message and exit
      -en ENSEMBLE, --ensemble ENSEMBLE
                            Maximum number of checkpoints to average and export.
                            set 0 to disable (default: 5)
      -nm NAME, --name NAME
                            Name for the exported model (active when --ensemble >
                            0). Value should be a single word. This will be useful
                            if you are going to place multiple exports in a same
                            dir for system combination (default: None)
      --config              Copy config (default: True)
      --no-config           See --config (default: False)
      --vocab               Copy vocabulary files (such as sentence piece models)
                            (default: True)
      --no-vocab            See --vocab (default: False)
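The -en/--ensemble option averages up to N checkpoints into a single model. The idea of checkpoint (parameter) averaging is sketched below, with plain Python lists standing in for tensors. RTG itself works with PyTorch state dicts, and average_checkpoints is a hypothetical name.

```python
def average_checkpoints(state_dicts):
    """Element-wise mean of parameters across checkpoints with identical names/shapes.

    Each checkpoint is a dict: parameter name -> flat list of floats.
    """
    if not state_dicts:
        raise ValueError("need at least one checkpoint")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise ValueError("checkpoints have different parameter names")
    n = len(state_dicts)
    averaged = {}
    for name in keys:
        # zip(...) walks the same position of this parameter in every checkpoint
        averaged[name] = [sum(vals) / n for vals in zip(*(sd[name] for sd in state_dicts))]
    return averaged
```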

6. Other tools:

6.1. rtg-syscomb: System Combiner

    python -m rtg.syscomb -h
    usage: __main__.py [-h] [-b BATCH_SIZE] [-s STEPS]
                       experiment models [models ...]

    positional arguments:
      experiment            Path to experiment directory
      models                Path to models

    optional arguments:
      -h, --help            show this help message and exit
      -b BATCH_SIZE, --batch-size BATCH_SIZE
                            Batch size (default: 128)
      -s STEPS, --steps STEPS
                            Training steps (default: 2000)

6.2. Perplexity

Compute perplexity of a language model on a test set.

    $ python -m rtg.eval.perplexity -h
    usage: rtg.eval.perplexity [-h] [-t TEST] [-en ENSEMBLE]
                               work_dir [model_path [model_path ...]]

    positional arguments:
      work_dir              Working/Experiment directory
      model_path            Path to model's checkpoint. If not specified, a best
                            model (based on the score on validation set) from the
                            experiment directory will be used. If multiple paths
                            are specified, then an ensembling is performed by
                            averaging the param weights (default: None)

    optional arguments:
      -h, --help            show this help message and exit
      -t TEST, --test TEST  test file path. default is STDIN (default:
                            <_io.TextIOWrapper name='<stdin>' mode='r'
                            encoding='UTF-8'>)
      -en ENSEMBLE, --ensemble ENSEMBLE
                            Ensemble best --ensemble models by averaging them
                            (default: 1)
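Perplexity is the exponential of the average negative log-likelihood per token. Here is a minimal sketch, assuming you already have the natural-log probability the model assigned to each target token. How RTG batches and tokenizes the test set is not shown, and the function name is hypothetical.

```python
import math


def perplexity(token_log_probs):
    """exp(-(1/N) * sum(log p)) over all N tokens, with natural-log probabilities."""
    n = len(token_log_probs)
    if n == 0:
        raise ValueError("no tokens")
    return math.exp(-sum(token_log_probs) / n)
```

For example, a model that assigns probability 0.25 to every token has a perplexity of 4.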

6.3. Line BLEU

Compute BLEU per line.

    python -m rtg.eval.linebleu -h
    usage: linebleu.py [-h] [-c CANDS] [-r REFS] [-n N] [-nr] [-nc] [-o OUT] [-v]

    Computes BLEU score per record.

    optional arguments:
      -h, --help            show this help message and exit
      -c CANDS, --cands CANDS
                            Candidate (aka output from NLG system) file (default:
                            <_io.TextIOWrapper name='<stdin>' mode='r'
                            encoding='UTF-8'>)
      -r REFS, --refs REFS  Reference (aka human label) file (default:
                            <_io.TextIOWrapper name='<stdin>' mode='r'
                            encoding='UTF-8'>)
      -n N, --n N           maximum n as in ngram. (default: 4)
      -nr, --no-refs        Do not write references to --out (default: False)
      -nc, --no-cands       Do not write candidates to --out (default: False)
      -o OUT, --out OUT     Output file path to store the result. (default:
                            <_io.TextIOWrapper name='<stdout>' mode='w'
                            encoding='UTF-8'>)
      -v, --verbose         verbose mode (default: False)
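To give a rough picture of per-line scoring, the sketch below computes smoothed sentence-level BLEU against a single reference. It is not rtg.eval.linebleu. The add-one smoothing of the higher-order n-gram precisions and the whitespace tokenization are our choices, and sentence_bleu is a hypothetical name. max_n plays the role of -n.

```python
import math
from collections import Counter


def ngrams(toks, n):
    """Multiset of n-grams (as tuples) in a token list."""
    return Counter(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1))


def sentence_bleu(cand, ref, max_n=4):
    """BLEU of one candidate line against one reference line."""
    c, r = cand.split(), ref.split()
    if not c:
        return 0.0
    log_p = 0.0
    for n in range(1, max_n + 1):
        c_ngr, r_ngr = ngrams(c, n), ngrams(r, n)
        # clipped matches: a candidate n-gram counts at most as often as in the reference
        match = sum(min(cnt, r_ngr[g]) for g, cnt in c_ngr.items())
        total = max(len(c) - n + 1, 0)
        if n == 1:
            if match == 0:
                return 0.0
            p = match / total
        else:
            p = (match + 1) / (total + 1)  # add-one smoothing for n > 1
        log_p += math.log(p) / max_n  # geometric mean of the precisions
    # brevity penalty for candidates shorter than the reference
    bp = 1.0 if len(c) > len(r) else math.exp(1 - len(r) / len(c))
    return bp * math.exp(log_p)
```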

6.4. OOV

Compute Out-of-Vocabulary (OOV) rate.

    $ python -m rtg.tool.oov -h
    usage: oov.py [-h] -tr TRAIN [-ts [TESTS [TESTS ...]]]

    optional arguments:
      -h, --help            show this help message and exit
      -tr TRAIN, --train TRAIN
                            Train file path (default: None)
      -ts [TESTS [TESTS ...]], --test [TESTS [TESTS ...]]
                            Test file paths (default: [<_io.TextIOWrapper
                            name='<stdin>' mode='r' encoding='UTF-8'>])
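An OOV rate can be reported per token or per type. Here is a minimal sketch over whitespace tokens. oov_rate is a hypothetical name, and rtg.tool.oov may tokenize and report differently.

```python
from collections import Counter


def oov_rate(train_lines, test_lines):
    """Return (token_oov_rate, type_oov_rate) of the test data w.r.t. the training vocabulary."""
    vocab = {tok for line in train_lines for tok in line.split()}
    test_counts = Counter(tok for line in test_lines for tok in line.split())
    n_tokens = sum(test_counts.values())
    if n_tokens == 0:
        return 0.0, 0.0
    oov_types = [t for t in test_counts if t not in vocab]
    token_rate = sum(test_counts[t] for t in oov_types) / n_tokens
    type_rate = len(oov_types) / len(test_counts)
    return token_rate, type_rate
```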

6.5. Class imbalance, Sequence lengths

Computes class imbalance on the training data and reports mean and median sequence lengths. This produces the stats reported in Gowda and May's Neural Machine Translation with Imbalanced Classes.

$ python -m rtg.eval.imbalance -h
usage: imbalance.py [-h] exp

positional arguments:
  exp         Path to experiment directory

optional arguments:
  -h, --help  show this help message and exit

Example:

$ python -m rtg.eval.imbalance runs/001-tfm
Experiment: runs/001-tfm shared_vocab:True
src types: 500 toks: 2,062,912 len_mean: 15.8686 len_median: 15.0 imbalance: 0.4409
tgt types: 500 toks: 1,711,685 len_mean: 13.1668 len_median: 12.0 imbalance: 0.4632
n_segs: 130,000
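To show what the reported numbers mean, here is a minimal sketch that computes them for whitespace-tokenized lines. It assumes imbalance is half the L1 distance between the observed type distribution and the uniform distribution over K types, which is our reading of the paper's definition. rtg.eval.imbalance works on the experiment's prepared (subword) data and may take K to be the vocabulary size, so treat this as an approximation. corpus_stats and imbalance are hypothetical names.

```python
import statistics
from collections import Counter


def imbalance(tokens):
    """0.5 * sum_i |p_i - 1/K| over the K observed types (0 means perfectly balanced)."""
    counts = Counter(tokens)
    if not counts:
        return 0.0
    total, k = sum(counts.values()), len(counts)
    return 0.5 * sum(abs(c / total - 1 / k) for c in counts.values())


def corpus_stats(lines):
    """Types, tokens, mean/median length (in tokens per line) and imbalance."""
    lengths = [len(line.split()) for line in lines]
    toks = [t for line in lines for t in line.split()]
    return {
        "types": len(set(toks)),
        "toks": len(toks),
        "len_mean": statistics.mean(lengths),
        "len_median": statistics.median(lengths),
        "imbalance": imbalance(toks),
    }
```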

7. Environment Variables

7.1. GPUs

By default, RTG uses all GPUs specified by the CUDA_VISIBLE_DEVICES environment variable.

To check whether the GPUs are configured correctly, run:

python -c 'import torch; print(torch.cuda.is_available(), torch.cuda.device_count())'
# prints True and number_of_gpus

To use multiple GPUs, say devices with IDs 0 and 1:

export CUDA_VISIBLE_DEVICES=0,1

To disable GPU usage, set the variable to an empty string or unset it:

export CUDA_VISIBLE_DEVICES=
unset CUDA_VISIBLE_DEVICES

7.2. Fast Temporary FileSystem

When shared compute grids with network file systems (NFS) are used, disk IO can be too slow. It helps to move frequently read training data to a fast temporary file system such as TMPFS. To do so, export RTG_TMP to the desired path, such as $TMPDIR, before starting the rtg process.

export RTG_TMP=$TMPDIR

RTG_TMP does NOT have to be unique to each experiment, so you can use the same directory for all experiments.

Note: model checkpoints are not stored in the temporary directory as of now. Since checkpoints are written only once every 1000 steps or so, this should be okay for now. If it turns out to be a problem, we shall revisit this decision.

7.3. Number of CPU Cores

export RTG_CPUS=10     #$SLURM_CPUS_ON_NODE
export OMP_NUM_THREADS=$RTG_CPUS
export MKL_NUM_THREADS=$RTG_CPUS

For scaling to large datasets, see the "Scaling Big Using PySpark" section.

8. Distributed Data Parallel (DDP)

This is a new feature in RTG, and not all edge cases have been tested.

rtg.distrib.launch simplifies the use of torch.distributed.launch as follows:

$ python -m rtg.distrib.launch -h
usage: launch.py [-h] [-N NODES] [-r NODE_RANK] [-P PROCS_PER_NODE]
                 [-G GPUS_PER_PROC] [--master-addr MASTER_ADDR]
                 [--master-port MASTER_PORT] [-m | --no_python]
                 training_script ...

PyTorch distributed training launch helper utilty that will spawn up multiple
distributed processes

positional arguments:
  training_script       The full path to the single GPU training
                        program/script to be launched in parallel, followed by
                        all the arguments for the training script
  training_script_args

optional arguments:
  -h, --help            show this help message and exit
  -N NODES, --nodes NODES
                        The number of nodes to use for distributed training
                        (default: 1)
  -r NODE_RANK, --node-rank NODE_RANK
                        The rank of the node for multi-node distributed
                        training (default: 0)
  -P PROCS_PER_NODE, --procs-per-node PROCS_PER_NODE
                        The number of processes to launch on each node with
                        one gpu each, for GPU training, this is recommended to
                        be set to the number of GPUs in your system so that
                        each process can be bound to a single GPU. (default:
                        1)
  -G GPUS_PER_PROC, --gpus-per-proc GPUS_PER_PROC
                        Number of GPUs to assign to each process. (default: 0)
  --master-addr MASTER_ADDR
                        Master node (rank 0)'s address, should be either the
                        IP address or the hostname of node 0, for single node
                        multi-proc training, the --master_addr can simply be
                        127.0.0.1 (default: 127.0.0.1)
  --master-port MASTER_PORT
                        Master node (rank 0)'s free port that needs to be used
                        for communciation during distributed training
                        (default: 29500)
  -m, --module          Changes each process to interpret the launch script as
                        a python module, executing with the same behavior
                        as'python -m'. (default: False)
  --no_python           Do not prepend the training script with "python" -
                        just exec it directly. Useful when the script is not a
                        Python script. (default: False)

Examples

  1. Run two CPU processes (-P 2) on a single node (-N 1), for testing without GPUs (-G 0):

    python -m rtg.distrib.launch -N 1 -P 2 -G 0 -m rtg.pipeline runs/005-tfm-nldb
  2. Run on a single node, two processes, one GPU per process: -N 1 -P 2 -G 1

  3. Run on two nodes, two processes each, one GPU per process: -N 2 -P 2 -G 1.

    # on first node: rank 0
    python -m rtg.distrib.launch -N 2 -r 0 -P 2 -G 1 -m rtg.pipeline runs/005-tfm-nldb -G
    # on second node: rank 1
    python -m rtg.distrib.launch -N 2 -r 1 -P 2 -G 1 -m rtg.pipeline  runs/005-tfm-nldb -G
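The process counts in these examples follow the usual torch.distributed rank arithmetic. The world size is N * P, and the process with a given local_rank on node r gets global rank r * P + local_rank. The sketch below computes that layout. process_layout is a hypothetical name, and the contiguous GPU-id assignment per process is our assumption rather than documented RTG behaviour.

```python
def process_layout(nodes, node_rank, procs_per_node, gpus_per_proc):
    """Return (world_size, [(global_rank, local_rank, gpu_ids), ...]) for one node."""
    if not 0 <= node_rank < nodes:
        raise ValueError("node_rank must be in [0, nodes)")
    world_size = nodes * procs_per_node
    layout = []
    for local_rank in range(procs_per_node):
        rank = node_rank * procs_per_node + local_rank
        # assumed: a contiguous block of GPU ids per process; empty for CPU-only (-G 0)
        start = local_rank * gpus_per_proc
        gpus = list(range(start, start + gpus_per_proc))
        layout.append((rank, local_rank, gpus))
    return world_size, layout
```

For example 3 (-N 2 -P 2 -G 1), the second node (-r 1) hosts global ranks 2 and 3 out of a world size of 4.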

WARNING:

  1. Don’t ever use -G 2 or more (i.e., don’t use 2 or more GPUs per process). Instead, use a larger -P (i.e., more processes with 1 GPU each).

9. FP16, Mixed Precision Training

rtg-pipe has a -fp16, --fp16 CLI flag (see rtg-pipe -h) that enables mixed precision training:

$ rtg-pipe <experiment-dir> --fp16

10. Gradient Clipping

Gradient clipping is supported using torch.nn.utils.clip_grad_norm_.

trainer.init_args.clip_grad_norm is treated as the maximum L2 norm. If the total L2 norm of all gradients exceeds this value, the gradients are scaled down to it.

trainer:
  init_args:
    # grad_accum: 1   # other params for init_args are allowed
    clip_grad_norm: 8
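For intuition, with clip_grad_norm: 8, a batch whose gradients have a global L2 norm of 16 gets every gradient multiplied by roughly 0.5. The plain-Python sketch below mirrors the arithmetic of torch.nn.utils.clip_grad_norm_ with its default L2 norm, including its small 1e-6 epsilon. It is an illustration, not RTG's trainer code, and clip_grad_norm here is a hypothetical name.

```python
import math


def clip_grad_norm(grads, max_norm):
    """Scale gradients in place so their global L2 norm is at most max_norm.

    grads: list of lists of floats (one inner list per parameter).
    Returns the global norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + 1e-6)  # epsilon avoids division by zero
    if clip_coef < 1.0:  # only ever shrink, never grow, the gradients
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm
```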

11. Scaling Big Using PySpark

When dealing with big datasets, traditional tools such as multiprocessing and SQLite3 simply aren’t enough. In such a scenario, PySpark is a useful backend. When the PySpark backend is enabled, PySpark is used to:

  • compute term frequencies which help speed up BPE learning

  • encode data with BPE

  • store data in NLCodec MultipartDb

To enable the PySpark backend:

  1. Install pyspark, e.g., pip install "pyspark>=3.0.0". Make sure you have a JVM/JDK that is compatible with pyspark.

  2. For data preparation, only codec_lib: nlcodec supports the pyspark backend as of now. If you are using sentencepiece, switch to nlcodec.

  3. Add the spark block to the top level of conf.yml. See experiments/spark-bigdataprep.html for a full example.

prep:
  codec_lib: nlcodec        # only nlcodec supports pyspark backend
  max_part_size: 1000000    # part size (num of recs); divides the training data into multiple parts
  # ... other args
spark:  # add this block to enable spark backend
  # double quote the keys containing dot
  "spark.master": local[3]            # set it to local[*] to use all local CPUs
  "spark.app.name": RTG NMT on Spark  # Name for the App
  "spark.driver.memory": 6g
  #key1: value1    # any other spark configs you want to control
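max_part_size caps the number of records per part. Suppose parts are filled in order (our assumption; the MultipartDb layout itself is not shown here). Then 2,500,000 records with max_part_size: 1000000 give three parts of 1M, 1M, and 0.5M records. The hypothetical helper below shows that chunking:

```python
from itertools import islice


def split_into_parts(records, max_part_size):
    """Yield consecutive lists of at most max_part_size records."""
    if max_part_size <= 0:
        raise ValueError("max_part_size must be positive")
    it = iter(records)
    while True:
        part = list(islice(it, max_part_size))  # take the next chunk lazily
        if not part:
            return
        yield part
```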

You may use local[*] to use all CPUs. However, note that:

  1. If you have too many CPU cores (say 32 or more), the disk/storage may come under too much pressure and the overall performance could degrade.

  2. Remember to leave some CPUs out of Spark for other workloads, such as PyTorch. In the above example, I used local[3] because I had only 4 CPUs in total and excluded one from Spark.

Watch the Spark logs for any warning messages. The log also provides the Spark web UI address when the Spark session is initialized.

The multi-node Spark distributed mode is not tested, but it might work out of the box if "spark.master" is correctly set.

12. RTG Serve

RTG models can be served using a Flask server.

12.1. Flask Installation

$ pip install rtg[serve]

Flask has its own set of dependencies unrelated to the core functionality. Hence, it is not installed when installing rtg.

12.2. Running

$ python -m rtg.serve -h  # rtg-serve
usage: rtg.serve [-h] [-d] [-p PORT] [-ho HOST] [-msl MAX_SRC_LEN] exp_dir

Deploy an RTG model to a RESTful server

positional arguments:
  exp_dir               Experiment directory

optional arguments:
  -h, --help            show this help message and exit
  -d, --debug           Run Flask server in debug mode (default: False)
  -p PORT, --port PORT  port to run server on (default: 6060)
  -ho HOST, --host HOST
                        Host address to bind. (default: 0.0.0.0)
  -b BASE, --base BASE  Base prefix path for all the URLs (default: None)
  -msl MAX_SRC_LEN, --max-src-len MAX_SRC_LEN
                        max source len; longer seqs will be truncated
                        (default: 250)

To launch a service for the runs/001-tfm experiment, run python -m rtg.serve -d runs/001-tfm. To use a base path of /v1, run python -m rtg.serve -d runs/001-tfm -b /v1.

It prints: * Running on 0.0.0.0:6060/ (Press CTRL+C to quit)

Currently, only the /translate API is supported. It accepts both GET requests with query parameters and POST requests with form parameters.

Batch decoding is not yet supported. The current decoder decodes one sentence at a time.

An example POST request:

 curl --data "source=Comment allez-vous?" --data "source=Bonne journée" http://localhost:6060/translate
{
  "source": [
    "Comment allez-vous?",
    "Bonne journée"
  ],
  "translation": [
    "How are you?",
    "Have a nice day"
  ],
  "dec_args": {
    "beam_size": 4,
    "lp_alpha": 0.6,
    "max_len": 50,
    "num_hyp": 1
  },
  "score": [
    -6,
    -3
  ],
  "time": 4.5281,
  "time_unit": "s"
}

You can also send a GET request, such as localhost:6060/translate?source=text1&source=text2, after properly URL-encoding text1 and text2. This should only be used for quick testing in your web browser.
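For scripted access, the POST request can be built with Python's standard library. Here is a minimal sketch. build_translate_request and translate are hypothetical helper names, and the sketch assumes a server started as described above is listening on localhost:6060 (append /v1 to the base URL if it was started with -b /v1). Decoder arguments such as beam_size go into the query string (see Advanced Decoder Args), while the sentences are sent as repeated source form fields.

```python
import json
from urllib.parse import urlencode
from urllib.request import Request, urlopen


def build_translate_request(base_url, sources, **dec_args):
    """POST request for /translate: one `source` form field per sentence,
    optional decoder args (beam_size, num_hyp, max_len, lp_alpha) in the query string."""
    url = base_url.rstrip("/") + "/translate"
    if dec_args:
        url += "?" + urlencode(dec_args)
    body = urlencode([("source", s) for s in sources]).encode("utf-8")
    return Request(url, data=body, method="POST")


def translate(base_url, sources, **dec_args):
    """Send the request and decode the JSON response (needs a running rtg.serve)."""
    with urlopen(build_translate_request(base_url, sources, **dec_args)) as resp:
        return json.load(resp)
```

Usage: translate("http://localhost:6060", ["Comment allez-vous?", "Bonne journée"], beam_size=6) returns the response as a Python dict.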

12.3. Advanced Decoder Args

You may pass the following optional arguments to API:

  • beam_size - Number of beams to use for decoding

  • num_hyp - Number of hypotheses to return in response

  • max_len - Maximum length (relative to source length) to wait for end-of-seq token

  • lp_alpha - Length penalty

All these arguments take default values from conf.yml, but you may also set them at runtime via arguments to the REST API. Example:

curl --data "source=Comment allez-vous?" --data "source=Bonne journée" "http://localhost:6060/translate?beam_size=6&num_hyp=4&lp_alpha=0.0"
{
  "dec_args": {
    "beam_size": 6,
    "lp_alpha": 0,
    "max_len": 50,
    "num_hyp": 4
  },
  "source": [
    "Comment allez @-@ vous ?",
    "Bonne journée"
  ],
  "time": 6.4446,
  "time_unit": "s",
  "translation": [
    [
      "How do you do, sir?.",
      "- How are you? - Fine.",
      "How do you do?",
      "How do you do?"
    ],
    [
      "Have a nice day.",
      "Have a good day",
      "Good day",
      "Have a good day.."
    ]
  ],
  "score": [
    [
      -8.3406,
      -8.3871,
      -9.1363,
      -9.1478
    ],
    [
      -3.7928,
      -3.8259,
      -3.8653,
      -3.8789
    ]
  ]
}
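The two responses above differ in shape. With num_hyp: 1, translation and score hold one entry per source, while with num_hyp > 1 each source maps to a list. A small client-side helper can normalize both. pair_hypotheses is a hypothetical name and not part of RTG.

```python
def pair_hypotheses(response):
    """Return [(source, [(hypothesis, score), ...]), ...] from a /translate response."""
    pairs = []
    for src, hyps, scores in zip(response["source"], response["translation"], response["score"]):
        if isinstance(hyps, str):
            # num_hyp == 1: a single string and a single score per source
            hyps, scores = [hyps], [scores]
        pairs.append((src, list(zip(hyps, scores))))
    return pairs
```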

12.4. Google Analytics Integration

Google Analytics is supported on web pages but disabled by default. To enable it, set the GA_TAG environment variable before starting the rtg.serve process.

export GA_TAG="G-xxxxx"

Production Deployment

Please use uWSGI for production deployment. If you don't already have uWSGI, you may install it via conda by running conda install -c conda-forge uwsgi.

uwsgi --http 127.0.0.1:5000 --module rtg.serve.app:app --pyargv "<path-to-exp-dir>"

# or using a .ini file
uwsgi --ini examples/uwsgi.ini

The uwsgi.ini file has the following content:

[uwsgi]
http = 0.0.0.0:6060
module = rtg.serve.app:app
pyargv = /full/path/<path-to-exp-dir> -b /v1
master = true
processes = 1
stats = 127.0.0.1:9191

Note that <path-to-exp-dir> is expected to be a valid path to an experiment directory. It may be obtained using the rtg-export tool.

13. Pre-process and post-process

The input/source text given to the API must be pre-processed with the same settings as the preprocessing during the training phase. To match it, we offer the following configurations:

  • src_pre_proc: List of transformations to be used on source text before giving to model (e.g. tokenizer, lowercase)

  • tgt_pre_proc: List of transformations to be used on target text before giving to model (e.g. tokenizer, lowercase)

  • tgt_post_proc: List of transformations to be used on target text produced by model (e.g. detokenizer, removal of unk)

The following transformations are built into RTG, so you may simply use their name:

import html
from functools import partial

# assumes the Moses classes come from the sacremoses package
from sacremoses import (MosesDetokenizer, MosesPunctNormalizer,
                        MosesTokenizer, MosesTruecaser)

# name -> transform; these names can be listed in src_pre_proc, tgt_pre_proc, tgt_post_proc
transformers = {
    'no_op': lambda x: x,
    'space_tok': lambda x: ' '.join(x.strip().split()),  # removes extra white spaces
    'space_detok': lambda toks: ' '.join(toks),
    'moses_tok': partial(MosesTokenizer().tokenize, escape=False, return_str=True,
                         aggressive_dash_splits=True,
                         protected_patterns=MosesTokenizer.WEB_PROTECTED_PATTERNS),
    'moses_detok': partial(MosesDetokenizer().detokenize, return_str=True, unescape=True),
    'moses_truecase': partial(MosesTruecaser().truecase, return_str=True),
    'lowercase': lambda x: x.lower(),
    'drop_unk': lambda x: x.replace('<unk>', ''),
    'html_unescape': html.unescape,
    'punct_norm': MosesPunctNormalizer().normalize
}

When {src_pre,tgt_pre,tgt_post}_proc are missing, we use sensible defaults (the same as the ones used in aclanthology.org/2021.acl-demo.37/):

src_pre_proc:
  - html_unescape
  - punct_norm
  - moses_tok
tgt_post_proc:
  - moses_detok
  - drop_unk

You may also use shell command line, including unix pipes, by prefixing your command with "#!". In addition, you may mix shell commands with known (pythonic) transforms. Example:

prep:
  src_pre_proc:
    - "#!/path/to/normalizer.perl | /path/to/tokenizer.py --lang deu"
    - lowercase
  tgt_post_proc:
    - drop_unk
    - moses_detok
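The way a list of named steps and "#!" shell commands composes can be sketched as follows. This is an illustration, not RTG's code. make_pipeline, shell_transform, and the small BUILTIN registry are hypothetical names. Starting one subprocess per call keeps the sketch simple but is slow for many lines.

```python
import subprocess

# Hypothetical subset of the built-in pythonic transforms listed earlier.
BUILTIN = {
    "no_op": lambda x: x,
    "space_tok": lambda x: " ".join(x.strip().split()),
    "lowercase": lambda x: x.lower(),
    "drop_unk": lambda x: x.replace("<unk>", ""),
}


def shell_transform(cmd):
    """Wrap a shell command (pipes allowed) as a str -> str function:
    the text goes to STDIN, and STDOUT (minus the trailing newline) comes back."""
    def run(text):
        proc = subprocess.run(cmd, shell=True, input=text + "\n",
                              capture_output=True, text=True, check=True)
        return proc.stdout.rstrip("\n")
    return run


def make_pipeline(steps, registry=BUILTIN):
    """Compose steps left to right; entries starting with "#!" run through the shell."""
    funcs = []
    for step in steps:
        if step.startswith("#!"):
            funcs.append(shell_transform(step[2:]))
        else:
            funcs.append(registry[step])  # KeyError for an unknown name

    def apply(text):
        for func in funcs:
            text = func(text)
        return text
    return apply
```

For example, make_pipeline(["#!tr -s ' '", "lowercase"]) squeezes repeated spaces with tr and then lowercases the result.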
Disabling pre- and post- processing
  • You may permanently disable pre-processing and post-processing using:

prep:
  src_pre_proc:
    - no_op
  tgt_post_proc:
    - no_op

NOTE: {src,tgt}_pre_proc and tgt_post_proc are only used by the REST API as of now. rtg.decode and rtg.prep do not yet use pre- and post-processing text transformers.

14. Development Environment:

14.1. Run Tests

Tests use the pytest framework, which can be installed using pip install pytest.

All tests should be run from the root dir of the project.

  1. Run all tests, with minimal outputs: python -m pytest

  2. Run all tests and also see STDOUT/STDERR python -m pytest -s

  3. Run a specific test such as python -m pytest -s -k 'test_pipeline_transformer'

IDE Setup:

14.2. Adding a new model

  1. Go to rtg.module and create a new .py file

  2. Create a class, say MyModel, that extends NMTModel, LangModel, or Model, depending on its type.

    1. It should implement the required methods, including make_model(cls, …), make_trainer, and the model_type property. Generation models such as NMT and LangModel also need make_generator.

    2. make_model is a model factory, used to create an instance of the model. This is unique to your model.

    3. make_trainer is a trainer factory, used for creating a trainer. Often, one of the existing (e.g. parent) model’s trainer would be sufficient, but if your model requires a different training strategy or has different training API, then you may want to create a new one.

    4. make_generator creates the generators used by the decoder to run the model in eval mode.

    5. Refer to tfmnmt for an example.

  3. Register the model in the registry (registry.py), e.g., with the @register decorator:

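from rtg.registry import register, MODEL
from rtg.module import NMTModel

@register(kind=MODEL, name='newnmt')  # register this class under the name 'newnmt'
class NewNMTModel(NMTModel):

    @classmethod
    def make_model(cls, exp, *args, **kwargs):
        # model factory: build and return an instance of this model
        pass

    @classmethod
    def make_trainer(cls, *args, **kwargs):
        # trainer factory: an existing (e.g. parent) model's trainer is often enough
        pass

    @classmethod
    def make_generator(cls, *args, **kwargs):
        # generator used by the decoder to run the model in eval mode
        pass

If your model is similar to an existing model and matches most of its API, you should reuse that model's trainer and generator.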

15. PyPI Release Instructions

15.1. Steps:

  1. Update the version: __version__ in rtg/__init__.py

  2. Remove old builds (if any)

    rm -r build dist *.egg-info
  3. Build:

    python setup.py sdist bdist_wheel
  4. Upload to testpypi

    twine upload -r testpypi dist/*
  5. Make docs and link

    1. List a new version for docs: docs/versions.adoc

    2. Build docs

      docs/make-docs.sh      # docs
  6. Upload to pypi

    twine upload -r pypi dist/*

15.2. The .pypirc file

The rc file ~/.pypirc should have something like this

[distutils]
index-servers =
    pypi
    testpypi

[pypi]
repository: https://upload.pypi.org/legacy/
username:Thamme.Gowda
password:<password_here>

[testpypi]
repository: https://test.pypi.org/legacy/
username:Thamme.Gowda
password:<password_here>

First-time users of PyPI need to create an account at both pypi.org/ AND test.pypi.org/. Yes, they are two different accounts! Make your life easy by using the same user ID and password.

16. Docker Release

16.1. Docker for ARM64

With Apple moving to ARM chips, running AMD64 docker images on ARM is not efficient.

cd dockers
# find a suitable docker file to build; e.g., this one
docker build . -f rtg-0.7-py39_tr110_cu114.dockerfile -t tgowda/rtg:0.7-py39_tr110_cu114

Acknowledgements

The research is based upon work supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via AFRL Contract #FA8650-17-C-9116. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of the ODNI, IARPA, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright annotation thereon.

This material is based on research sponsored by Air Force Research Laboratory (AFRL) under agreement number FA8750-19-1-1000. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation therein.