Python Class for RNN Encoder
import torch.nn as nn

from fairseq.models import FairseqEncoder


class RNNEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens

        self.embed_dim = args.encoder_embed_dim
        self.hidden_dim = args.encoder_ffn_embed_dim
        self.num_layers = args.encoder_layers

        # modules used in forward() below; a GRU (rather than an LSTM) matches
        # the single hidden-state tensor handled there
        self.dropout_in_module = nn.Dropout(args.dropout)
        self.rnn = nn.GRU(
            self.embed_dim,
            self.hidden_dim,
            self.num_layers,
            dropout=args.dropout,
            batch_first=False,
            bidirectional=True,
        )
        self.dropout_out_module = nn.Dropout(args.dropout)

        self.padding_idx = dictionary.pad()
    def combine_bidir(self, outs, bsz: int):
        # (num_layers * 2, bsz, hid) -> (num_layers, bsz, 2 * hid)
        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(self.num_layers, bsz, -1)
    def forward(self, src_tokens, **unused):
        bsz, seqlen = src_tokens.size()
        # get embeddings
        x = self.embed_tokens(src_tokens)
        x = self.dropout_in_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        # pass through the bidirectional RNN
        h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)
        x, final_hiddens = self.rnn(x, h0)
        outputs = self.dropout_out_module(x)
        # outputs = [sequence len, batch size, hid dim * directions]
        # hidden  = [num_layers * directions, batch size, hid dim]
        # since the encoder is bidirectional, we need to concatenate the
        # hidden states of the two directions
        final_hiddens = self.combine_bidir(final_hiddens, bsz)
        # hidden = [num_layers x batch x num_directions*hidden]

        # mark padding positions and return the triple that
        # reorder_encoder_out() below expects
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
        return tuple(
            (
                outputs,  # seq_len x batch x num_directions*hidden
                final_hiddens,  # num_layers x batch x num_directions*hidden
                encoder_padding_mask,  # seq_len x batch
            )
        )
    def reorder_encoder_out(self, encoder_out, new_order):
        # This is used by fairseq's beam search. How and why is not
        # particularly important here.
        return tuple(
            (
                encoder_out[0].index_select(1, new_order),
                encoder_out[1].index_select(1, new_order),
                encoder_out[2].index_select(1, new_order),
            )
        )
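To sanity-check the output shapes, here is a minimal smoke test for the encoder above. It assumes PyTorch and fairseq are installed; the hyperparameter values (256/512, 2 layers, dropout 0.3) are illustrative placeholders, not values taken from the original configuration.

from argparse import Namespace

import torch
import torch.nn as nn
from fairseq.data import Dictionary

# illustrative hyperparameters (assumed, not from the original config)
args = Namespace(
    encoder_embed_dim=256,
    encoder_ffn_embed_dim=512,
    encoder_layers=2,
    dropout=0.3,
)
dictionary = Dictionary()  # only the special symbols; enough for a shape check
embed_tokens = nn.Embedding(
    len(dictionary), args.encoder_embed_dim, padding_idx=dictionary.pad()
)
encoder = RNNEncoder(args, dictionary, embed_tokens)

src_tokens = torch.zeros(8, 20, dtype=torch.long)  # batch of 8, sequence length 20
outputs, final_hiddens, encoder_padding_mask = encoder(src_tokens)
print(outputs.shape)               # torch.Size([20, 8, 1024]) = T x B x 2*hidden
print(final_hiddens.shape)         # torch.Size([2, 8, 1024])  = layers x B x 2*hidden
print(encoder_padding_mask.shape)  # torch.Size([20, 8])       = T x B

Note that the bidirectional GRU doubles the feature dimension of both the per-step outputs and the combined hidden states, which is why the decoder consuming this encoder output must expect 2 * hidden_dim features.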