# Source code for gatelfdata.vocabs
"""Module for the Vocabs class."""
import logging
from collections import defaultdict
from gatelfdata.vocab import Vocab
import sys
# Module logger: level INFO, with one stderr handler using a timestamped
# "time name level message" layout.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
streamhandler = logging.StreamHandler(stream=sys.stderr)
streamhandler.setFormatter(fmt=formatter)
logger.addHandler(hdlr=streamhandler)
class Vocabs(object):
    """A class for managing all the vocab instances that are needed by features.

    Vocab instances are keyed by embedding id (``emb_id``), so attributes that
    share an embedding id share one Vocab and have their counts merged.
    """

    def __init__(self, remove_counts=True, remove_embs=True):
        """
        Create a vocabs instance and set the default behaviour when finishing each
        vocab.

        :param remove_counts: remove the counts per word data when finishing
        :param remove_embs: remove the embs per word data when finishing
        """
        self.remove_counts = remove_counts
        self.remove_embs = remove_embs
        # map from embedding id to vocab instance
        self.vocabs = {}

    def setup_vocab(self, attrinfo, featurestats):
        """Create or update the temporary Vocab instances so that the counts from different attributes get merged.

        Nothing happens if ``featurestats`` has no (or an empty) ``"stringCounts"``
        entry. If a Vocab already exists for the attribute's ``emb_id``, the
        counts are added to it via ``add_counts`` and this attribute's other
        ``emb_*`` settings are not used. Otherwise a new Vocab is created from
        the counts and the attribute's ``emb_*`` settings.

        :param attrinfo: dict of attribute info; the keys read are ``emb_id``,
            ``emb_train``, ``emb_file``, ``emb_dims`` and ``emb_minfreq``
        :param featurestats: dict of feature statistics; only ``"stringCounts"`` is read
        """
        logger.debug("Pre-initialising vocab for %r", attrinfo)
        counts = featurestats.get("stringCounts")
        if not counts:
            return
        emb_id = attrinfo.get("emb_id")
        if emb_id in self.vocabs:
            # Another attribute already uses this embedding id: merge counts.
            self.vocabs[emb_id].add_counts(counts)
            return
        self.vocabs[emb_id] = Vocab(counts,
                                    emb_id=emb_id,
                                    emb_train=attrinfo.get("emb_train"),
                                    emb_file=attrinfo.get("emb_file"),
                                    emb_dims=attrinfo.get("emb_dims"),
                                    emb_minfreq=attrinfo.get("emb_minfreq"))

    def finish(self):
        """Once all the counts have been gathered, create the final instances.

        Calls ``finish`` on every managed Vocab, passing this instance's
        ``remove_counts`` and ``remove_embs`` settings.
        """
        for vocab in self.vocabs.values():
            vocab.finish(remove_counts=self.remove_counts, remove_embs=self.remove_embs)
        # Lazy logging args: the repr of all vocabs is only built when DEBUG
        # output is actually enabled.
        logger.debug("Finished vocabs: %s", self.vocabs)

    def get_vocab(self, attrinfo_or_embid):
        """Return the vocab instance for an attribute info dict or an embedding id.

        :param attrinfo_or_embid: either an attribute info dict (its ``"emb_id"``
            entry is used) or the embedding id itself
        :return: the Vocab registered for that embedding id
        :raises Exception: if no vocab has been set up for the embedding id
        """
        if isinstance(attrinfo_or_embid, dict):
            emb_id = attrinfo_or_embid.get("emb_id")
        else:
            emb_id = attrinfo_or_embid
        if emb_id not in self.vocabs:
            raise Exception("No vocab for emb_id: %s got %s" % (emb_id, self.vocabs.keys()))
        return self.vocabs[emb_id]

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # One-element tuple so %-formatting always treats the keys view as a
        # single value.
        return "Vocabs(vocabs=%r)" % (self.vocabs.keys(),)