# Source code for gatelfpytorchjson.takefromtuple
import torch.nn
import sys
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
streamhandler = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
    '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
streamhandler.setFormatter(formatter)
# Attach the stderr handler only once. getLogger(__name__) returns the same
# logger object every time, so re-executing this module (e.g. via
# importlib.reload) would otherwise stack another handler and print every
# record multiple times.
if not logger.handlers:
    logger.addHandler(streamhandler)
class TakeFromTuple(torch.nn.Module):
    """Wrap a module and return only one element of its tuple output.

    Some modules return a tuple from ``forward``; for example
    ``torch.nn.LSTM`` returns ``(output, (h_n, c_n))``. Wrapping such a
    module with this class yields just the selected element. This is handy
    wherever a single value is expected, e.g. as one stage of
    ``torch.nn.Sequential``.
    """

    def __init__(self, moduletowrap, which=0):
        """Wrap ``moduletowrap``.

        Args:
            moduletowrap: the module whose forward result gets filtered.
            which: subscript (normally an int index) of the tuple element
                to return. Defaults to 0, the first element.
        """
        super().__init__()
        self.module = moduletowrap
        self.which = which

    def forward(self, vals, *args, **kwargs):
        """Run the wrapped module and select the configured tuple element.

        Args:
            vals: first positional input, passed to the wrapped module.
            *args, **kwargs: any further inputs (e.g. an initial hidden
                state for a recurrent layer), passed through unchanged.

        Returns:
            ``ret[which]`` if the wrapped module returns a tuple; otherwise
            the wrapped module's result, unchanged.
        """
        ret = self.module(vals, *args, **kwargs)
        if not isinstance(ret, tuple):
            return ret
        # getattr fallback: instances pickled before ``which`` was stored
        # lack the attribute and always selected element 0.
        return ret[getattr(self, "which", 0)]