In [None]:
from bertviz import model_view, head_view
from transformers import *

import numpy as np
import pprint

# Get the interactive Tools for Matplotlib
%matplotlib notebook
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from transformers import BertTokenizer, BertModel
import torch

In [None]:
model_path = "bert-base-uncased"  # Hugging Face checkpoint used for the tokenizer, config and model

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Load the tokenizer and the encoder. The config asks the model to also return
# every layer's hidden states and attention weights (used for visualisation).
config = AutoConfig.from_pretrained(
    model_path,
    output_hidden_states=True,
    output_attentions=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, config=config)
model = model.to(device)

# Create dummy (query, doc) pairs — replace them with your actual (query, doc) pairs

In [None]:
import random

# Sample queries
query_a = "Is throat cancer treatable nowadays?"
query_b = "How to deal with hypothermia?"

# Sample documents
sentence_a = "58-year-old woman with hypertension"
sentence_b = "BACKGROUND : Longitudinal studies hypertension"

number_documents = 3600

# Dedicated, seeded generator: the dummy pairs are identical on every
# Restart & Run All, and the global `random` state is left untouched.
SEED = 42
pair_rng = random.Random(SEED)

# Create dummy (query, document) pairs
# TODO: replace by your own (query, document) pairs
query_pairs = [
    (pair_rng.choice([query_a, query_b]), pair_rng.choice([sentence_a, sentence_b]))
    for _ in range(number_documents)
]

## CLS Embedding Extraction in Batches


In [None]:
import numpy as np
    
def extract_cls(query_pairs, embeddings, batch_size=32, *, max_length=512,
                tokenizer=None, model=None, device=None):
    """Write the L2-normalised [CLS] embedding of every (query, doc) pair into `embeddings`.

    Pairs are encoded in batches of `batch_size` as sentence pairs with the
    tokenizer's special tokens (CLS/SEP for BERT), forwarded through the model
    without gradient tracking, and the final-layer hidden state at token
    position 0 ([CLS]) is L2-normalised to unit norm. Pair `i` is written to
    row `i` of `embeddings`.

    Args:
        query_pairs: sequence of (query, document) string tuples.
        embeddings: pre-allocated 2-D array (e.g. `np.ndarray` or `np.memmap`)
            with at least `len(query_pairs)` rows and one column per hidden
            unit of the model. Filled in place; extra rows are left untouched.
        batch_size: number of pairs forwarded per step.
        max_length: maximum token length of an encoded pair; longer pairs are
            truncated.
        tokenizer, model, device: objects to use. Each defaults to the
            notebook-level variable of the same name when left as None.

    Returns:
        The same `embeddings` object. If it has a `flush` method (as
        `np.memmap` does), it is flushed so the data reaches the file on disk.

    Raises:
        ValueError: if `batch_size` < 1 or `embeddings` has fewer rows than
            there are query pairs.
    """
    # Fall back to the notebook-level tokenizer/model/device.
    namespace = globals()
    tokenizer = namespace["tokenizer"] if tokenizer is None else tokenizer
    model = namespace["model"] if model is None else model
    device = namespace["device"] if device is None else device

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(embeddings) < len(query_pairs):
        raise ValueError(
            f"embeddings has {len(embeddings)} rows but there are "
            f"{len(query_pairs)} query pairs"
        )

    # Iterate over all documents, in batches of size <batch_size>
    for start in range(0, len(query_pairs), batch_size):
        batch_data = query_pairs[start:start + batch_size]

        inputs = tokenizer.batch_encode_plus(
            batch_data,
            return_tensors='pt',      # PyTorch tensors
            add_special_tokens=True,  # add CLS and SEP tokens
            max_length=max_length,    # max sequence length
            truncation=True,          # truncate pairs longer than max_length
            padding=True,             # pad to the longest pair in the batch
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Final-layer hidden state of the first ([CLS]) token of each pair.
            # Same tensor as hidden_states[-1], but available even when the
            # model was not configured with output_hidden_states=True.
            batch_cls = outputs['last_hidden_state'][:, 0, :]
            # L2-normalise so every embedding has norm 1.
            batch_cls = torch.nn.functional.normalize(batch_cls, p=2, dim=1)

        # Slice by the real batch length: the last batch may be shorter, and
        # `embeddings` may have more rows than there are pairs.
        embeddings[start:start + len(batch_cls)] = batch_cls.cpu().numpy()

    # Persist memory-mapped output; plain ndarrays have no `flush`.
    flush = getattr(embeddings, 'flush', None)
    if callable(flush):
        flush()

    return embeddings

The cell below extracts the CLS embeddings of all `query_pairs` and keeps them in memory (RAM).

In [None]:
# NumPy array that holds (in RAM) one CLS embedding per (query, doc) pair.
# Width comes from the model config rather than a hard-coded 768; float32 is
# the model's default output precision and uses half the memory of float64.
hidden_size = model.config.hidden_size
embeddings = np.zeros((len(query_pairs), hidden_size), dtype=np.float32)

embeddings = extract_cls(query_pairs, embeddings=embeddings, batch_size=32)

# Sanity checks: one row per pair, and every row is unit-norm after L2-normalisation.
assert embeddings.shape == (len(query_pairs), hidden_size)
assert np.allclose(np.linalg.norm(embeddings, axis=1), 1.0, atol=1e-4)
embeddings.shape

## If you run into memory issues: store the NumPy array on disk

The code below extracts the CLS embeddings for all `query_pairs` and stores them on disk in a memory-mapped NumPy array (`np.memmap`). Unlike the previous option, the full array does not have to be held in your computer's RAM. Once the array is created as shown below, the fact that it lives on disk is abstracted away, and you can use it just like a standard NumPy array.

Because it writes to disk, this option is slower than the first one, but it needs dramatically less RAM.

Reference: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html

In [None]:
filename = "cls_embeddings.dat"
hidden_size = model.config.hidden_size

# Create a memory-mapped NumPy array backed by `filename` on disk instead of RAM.
# mode='w+' creates the file, overwriting any existing file with that name.
# The shape must be (number of query-doc pairs, CLS embedding size).
embeddings_persistent = np.memmap(filename, dtype='float32', mode='w+',
                                  shape=(len(query_pairs), hidden_size))

embeddings_persistent = extract_cls(query_pairs, embeddings=embeddings_persistent, batch_size=32)

# Make sure every written row has actually reached the file on disk.
embeddings_persistent.flush()

assert embeddings_persistent.shape == (len(query_pairs), hidden_size)
embeddings_persistent.shape