"""Representations Extractor for ``transformers`` toolkit models.

Module that given a file with input sentences and a ``transformers``
model, extracts representations from all layers of the model. The script
supports aggregation over sub-words created due to the tokenization of
the provided model.

Can also be invoked as a script as follows:
    ``python -m neurox.data.extraction.transformers_extractor``
"""

import argparse
import sys

import numpy as np
import torch


from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from neurox.data.writer import ActivationsWriter


def get_model_and_tokenizer(model_desc, device="cpu", random_weights=False):
    """
    Automatically get the appropriate ``transformers`` model and tokenizer based
    on the model description

    Parameters
    ----------
    model_desc : str
        Model description; can either be a model name like ``bert-base-uncased``,
        a comma separated list indicating <model>,<tokenizer> (since 1.0.8),
        or a path to a trained model

    device : str, optional
        Device to load the model on, cpu or gpu. Default is cpu.

    random_weights : bool, optional
        Whether the weights of the model should be randomized. Useful for
        analyses where one needs an untrained model.

    Returns
    -------
    model : transformers model
        An instance of one of the transformers.modeling classes
    tokenizer : transformers tokenizer
        An instance of one of the transformers.tokenization classes
    """
    model_desc = model_desc.split(",")
    if len(model_desc) == 1:
        model_name = model_desc[0]
        tokenizer_name = model_desc[0]
    else:
        model_name = model_desc[0]
        tokenizer_name = model_desc[1]

    model = AutoModel.from_pretrained(model_name, output_hidden_states=True).to(device)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    if random_weights:
        print("Randomizing weights")
        model.init_weights()

    return model, tokenizer
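

# Illustrative usage of get_model_and_tokenizer above (a sketch, not part of
# the original module). "bert-base-uncased" is only an example model name
# taken from the docstring; any ``transformers`` checkpoint should work.
#
#     model, tokenizer = get_model_and_tokenizer("bert-base-uncased")
#     # Separate model and tokenizer names (supported since 1.0.8):
#     model, tokenizer = get_model_and_tokenizer(
#         "bert-base-uncased,bert-base-uncased", device="cpu", random_weights=False
#     )

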
def aggregate_repr(state, start, end, aggregation):
    """
    Function that aggregates activations/embeddings over a span of subword
    tokens. This function will usually be called once per word. For example,
    if we had the sentence::

        This is an example

    which is tokenized by BPE into::

        this is an ex @@am @@ple

    The function should be called 4 times::

        aggregate_repr(state, 0, 0, aggregation)
        aggregate_repr(state, 1, 1, aggregation)
        aggregate_repr(state, 2, 2, aggregation)
        aggregate_repr(state, 3, 5, aggregation)

    Returns a zero vector if end is less than start, i.e. the request is to
    aggregate over an empty slice.

    Parameters
    ----------
    state : numpy.ndarray
        Matrix of size [NUM_LAYERS x NUM_SUBWORD_TOKENS_IN_SENT x LAYER_DIM]
    start : int
        Index of the first subword of the word being processed
    end : int
        Index of the last subword of the word being processed
    aggregation : {'first', 'last', 'average'}
        Aggregation method for combining subword activations

    Returns
    -------
    word_vector : numpy.ndarray
        Matrix of size [NUM_LAYERS x LAYER_DIM]
    """
    if end < start:
        sys.stderr.write(
            "WARNING: An empty slice of tokens was encountered. "
            + "This probably implies a special unicode character or text "
            + "encoding issue in your original data that was dropped by the "
            + "transformer model's tokenizer.\n"
        )
        return np.zeros((state.shape[0], state.shape[2]))
    if aggregation == "first":
        return state[:, start, :]
    elif aggregation == "last":
        return state[:, end, :]
    elif aggregation == "average":
        return np.average(state[:, start : end + 1, :], axis=1)
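

# Worked example for aggregate_repr above (a sketch with made-up numbers):
# one layer, three subword positions, hidden size 2.
#
#     state = np.arange(6, dtype=np.float32).reshape(1, 3, 2)
#     # state[0] == [[0, 1], [2, 3], [4, 5]]
#     aggregate_repr(state, 0, 2, "first")    # -> [[0., 1.]]
#     aggregate_repr(state, 0, 2, "last")     # -> [[4., 5.]]
#     aggregate_repr(state, 0, 2, "average")  # -> [[2., 3.]]

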
def extract_sentence_representations(
    sentence,
    model,
    tokenizer,
    device="cpu",
    include_embeddings=True,
    aggregation="last",
    tokenization_counts={},
):
    """
    Get representations for one sentence
    """
    # this follows the HuggingFace API for transformers
    special_tokens = [
        x for x in tokenizer.all_special_tokens if x != tokenizer.unk_token
    ]
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)

    original_tokens = sentence.split(" ")
    # Add a letter and space before each word since some tokenizers are space sensitive
    tmp_tokens = [
        "a" + " " + x if x_idx != 0 else x for x_idx, x in enumerate(original_tokens)
    ]
    assert len(original_tokens) == len(tmp_tokens)

    with torch.no_grad():
        # Get tokenization counts if not already available
        for token_idx, token in enumerate(tmp_tokens):
            tok_ids = [
                x for x in tokenizer.encode(token) if x not in special_tokens_ids
            ]
            if token_idx != 0:
                # Ignore the first token (added letter)
                tok_ids = tok_ids[1:]

            if token in tokenization_counts:
                assert tokenization_counts[token] == len(
                    tok_ids
                ), "Got different tokenization for already processed word"
            else:
                tokenization_counts[token] = len(tok_ids)

        ids = tokenizer.encode(sentence, truncation=True)
        input_ids = torch.tensor([ids]).to(device)

        # Hugging Face format: tuple of torch.FloatTensor of shape
        # (batch_size, sequence_length, hidden_size). Tuple has 13 elements for
        # base model: embedding outputs + hidden states at each layer
        all_hidden_states = model(input_ids)[-1]

        if include_embeddings:
            all_hidden_states = [
                hidden_states[0].cpu().numpy() for hidden_states in all_hidden_states
            ]
        else:
            all_hidden_states = [
                hidden_states[0].cpu().numpy()
                for hidden_states in all_hidden_states[1:]
            ]
        all_hidden_states = np.array(all_hidden_states)

    print('Sentence : "%s"' % (sentence))
    print("Original (%03d): %s" % (len(original_tokens), original_tokens))
    print(
        "Tokenized (%03d): %s"
        % (
            len(tokenizer.convert_ids_to_tokens(ids)),
            tokenizer.convert_ids_to_tokens(ids),
        )
    )

    # Remove special tokens
    ids_without_special_tokens = [x for x in ids if x not in special_tokens_ids]
    idx_without_special_tokens = [
        t_i for t_i, x in enumerate(ids) if x not in special_tokens_ids
    ]
    filtered_ids = [ids[t_i] for t_i in idx_without_special_tokens]

    assert all_hidden_states.shape[1] == len(ids)
    all_hidden_states = all_hidden_states[:, idx_without_special_tokens, :]
    assert all_hidden_states.shape[1] == len(filtered_ids)

    print(
        "Filtered (%03d): %s"
        % (
            len(tokenizer.convert_ids_to_tokens(filtered_ids)),
            tokenizer.convert_ids_to_tokens(filtered_ids),
        )
    )

    segmented_tokens = tokenizer.convert_ids_to_tokens(filtered_ids)

    # Perform actual subword aggregation/detokenization
    counter = 0
    detokenized = []
    final_hidden_states = np.zeros(
        (all_hidden_states.shape[0], len(original_tokens), all_hidden_states.shape[2])
    )
    inputs_truncated = False
    for token_idx, token in enumerate(tmp_tokens):
        current_word_start_idx = counter
        current_word_end_idx = counter + tokenization_counts[token]

        # Check for truncated hidden states in the case where the
        # original word was actually tokenized
        if (
            tokenization_counts[token] != 0
            and current_word_start_idx >= all_hidden_states.shape[1]
        ) or current_word_end_idx > all_hidden_states.shape[1]:
            final_hidden_states = final_hidden_states[:, : len(detokenized), :]
            inputs_truncated = True
            break

        final_hidden_states[:, len(detokenized), :] = aggregate_repr(
            all_hidden_states,
            current_word_start_idx,
            current_word_end_idx - 1,
            aggregation,
        )
        detokenized.append(
            "".join(segmented_tokens[current_word_start_idx:current_word_end_idx])
        )
        counter += tokenization_counts[token]

    print("Detokenized (%03d): %s" % (len(detokenized), detokenized))
    print("Counter: %d" % (counter))

    if inputs_truncated:
        print("WARNING: Input truncated because of length, skipping check")
    else:
        assert counter == len(ids_without_special_tokens)
        assert len(detokenized) == len(original_tokens)
    print("===================================================================")

    return final_hidden_states, detokenized
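

# Illustrative usage of extract_sentence_representations above (a sketch; the
# model name and sentence are placeholders). For a 12-layer base model with
# embeddings included, hidden_states has shape [13 x NUM_WORDS x HIDDEN_DIM].
#
#     model, tokenizer = get_model_and_tokenizer("bert-base-uncased")
#     hidden_states, detokenized = extract_sentence_representations(
#         "Hello world !", model, tokenizer, aggregation="average"
#     )
#     # len(detokenized) == 3, one entry per space-separated input word

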
def extract_representations(
    model_desc,
    input_corpus,
    output_file,
    device="cpu",
    aggregation="last",
    output_type="json",
    random_weights=False,
    ignore_embeddings=False,
    decompose_layers=False,
    filter_layers=None,
):
    print(f"Loading model: {model_desc}")
    model, tokenizer = get_model_and_tokenizer(
        model_desc, device=device, random_weights=random_weights
    )

    print("Reading input corpus")

    def corpus_generator(input_corpus_path):
        with open(input_corpus_path, "r") as fp:
            for line in fp:
                yield line.strip()
            return

    print("Preparing output file")
    writer = ActivationsWriter.get_writer(
        output_file,
        filetype=output_type,
        decompose_layers=decompose_layers,
        filter_layers=filter_layers,
    )

    print("Extracting representations from model")
    tokenization_counts = {}  # Cache for tokenizer rules
    for sentence_idx, sentence in enumerate(corpus_generator(input_corpus)):
        hidden_states, extracted_words = extract_sentence_representations(
            sentence,
            model,
            tokenizer,
            device=device,
            include_embeddings=(not ignore_embeddings),
            aggregation=aggregation,
            tokenization_counts=tokenization_counts,
        )

        print("Hidden states: ", hidden_states.shape)
        print("# Extracted words: ", len(extracted_words))

        writer.write_activations(sentence_idx, extracted_words, hidden_states)

    writer.close()
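

# Illustrative usage of extract_representations above (a sketch; the file
# paths are placeholders, and "json" is the default output type handled by
# ActivationsWriter).
#
#     extract_representations(
#         "bert-base-uncased",
#         "sentences.txt",        # one sentence per line
#         "activations.json",
#         aggregation="average",
#         output_type="json",
#     )

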
HDF5_SPECIAL_TOKENS = {".": "__DOT__", "/": "__SLASH__"}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_desc", help="Name of model")
    parser.add_argument(
        "input_corpus", help="Text file path with one sentence per line"
    )
    parser.add_argument(
        "output_file",
        help="Output file path where extracted representations will be stored",
    )
    parser.add_argument(
        "--aggregation",
        help="first, last or average aggregation for word representation in the case of subword segmentation",
        default="last",
    )
    parser.add_argument("--disable_cuda", action="store_true")
    parser.add_argument("--ignore_embeddings", action="store_true")
    parser.add_argument(
        "--random_weights",
        action="store_true",
        help="generate representations from randomly initialized model",
    )

    ActivationsWriter.add_writer_options(parser)

    args = parser.parse_args()

    assert args.aggregation in [
        "average",
        "first",
        "last",
    ], "Invalid aggregation option, please specify first, average or last."
    assert not (
        args.filter_layers is not None and args.ignore_embeddings is True
    ), "--filter_layers and --ignore_embeddings cannot be used at the same time"

    if not args.disable_cuda and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    extract_representations(
        args.model_desc,
        args.input_corpus,
        args.output_file,
        device=device,
        aggregation=args.aggregation,
        output_type=args.output_type,
        random_weights=args.random_weights,
        ignore_embeddings=args.ignore_embeddings,
        decompose_layers=args.decompose_layers,
        filter_layers=args.filter_layers,
    )
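

# Example invocation (illustrative; writer-specific flags such as
# --output_type are registered by ActivationsWriter.add_writer_options and
# depend on that class):
#
#     python -m neurox.data.extraction.transformers_extractor \
#         bert-base-uncased sentences.txt activations.json \
#         --aggregation average

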
if __name__ == "__main__":
    main()