# Reranking with BERT 


## Install the libraries
First you need to install the following libraries:

    pip install transformers
    pip install ipywidgets
    pip install bertviz

Once everything is installed, you can download the bertviz repository and add it to the Python path:

In [3]:
import sys
!test -d bertviz_repo && echo "FYI: bertviz_repo directory already exists, to pull latest version uncomment this line: !rm -r bertviz_repo"
# !rm -r bertviz_repo # Uncomment if you need a clean pull from repo
!test -d bertviz_repo || git clone https://github.com/jessevig/bertviz bertviz_repo
if not 'bertviz_repo' in sys.path:
  sys.path += ['bertviz_repo']

FYI: bertviz_repo directory already exists, to pull latest version uncomment this line: !rm -r bertviz_repo


## Imports and definitions

In [4]:
from bertviz import model_view, head_view
from transformers import *

import numpy as np
import pprint

# Select the Matplotlib backend used to render figures in this notebook.
# NOTE(review): two backend magics run back to back; the later `inline` call
# presumably takes precedence, in which case figures are static images rather
# than the interactive `notebook` view — confirm which backend is intended.
%matplotlib notebook
%matplotlib inline
import matplotlib.pyplot as plt
# Apply the 'ggplot' style sheet to the figures created from here on.
plt.style.use('ggplot')

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from transformers import BertTokenizer, BertModel
import torch

In [5]:
# Name of the pretrained checkpoint used in this notebook.
# Alternative checkpoint, left disabled:
#model_path = 'nboost/pt-bert-base-uncased-msmarco'
model_path = 'bert-base-uncased'

# Literal text of the BERT classification and separator special tokens.
CLS_token, SEP_token = "[CLS]", "[SEP]"


## Load the required tokenizer, configuration and model

In [6]:
# Load tokenizer, configuration and encoder from the checkpoint named by
# `model_path`, so switching checkpoints only requires editing that variable
# instead of three hard-coded strings.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Request the hidden states and attention maps in the model outputs, in
# addition to the last hidden state.
config = AutoConfig.from_pretrained(model_path, output_hidden_states=True, output_attentions=True)
model = AutoModel.from_pretrained(model_path, config=config)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Neural ranking

The next sentence prediction task is closely related to the ranking task, where the first sentence is the query and the second sentence is the relevant document.

The next sentence prediction task uses the \[CLS\] output embedding to make the prediction.

In [7]:
# Toy (query, document) pair that is scored in the cells below.
query, document = (
    "What is covid 19 ?",
    "Covid 19 is an infectious disease caused by the SARS-CoV-2 virus.",
)

In [13]:
# Generate the input sequence: the query and the document are encoded as one
# sentence pair. The tokenizer is called directly because `encode_plus` is
# deprecated in transformers; the arguments and the returned encoding are the same.
inputs_qa = tokenizer(
    query,
    document,
    return_tensors='pt',       # PyTorch tensors, ready to feed to the model
    add_special_tokens=True,   # insert the [CLS] / [SEP] markers
    max_length=512,
    truncation=True,           # cut pairs that exceed max_length
)
print(inputs_qa)

{'input_ids': tensor([[  101,  2054,  2003,  2522, 17258,  2539,  1029,   102,  2522, 17258,
          2539,  2003,  2019, 16514,  4295,  3303,  2011,  1996, 18906,  2015,
          1011,  2522,  2615,  1011,  1016,  7865,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]])}


In [12]:
# Sanity check: turn the ids of the first (and only) sequence back into text
# to see how the query, the document and the special tokens are laid out.
first_sequence_ids = inputs_qa["input_ids"][0]
decoded_input_qa = tokenizer.decode(first_sequence_ids)
print(decoded_input_qa)

[CLS] what is covid 19? [SEP] covid 19 is an infectious disease caused by the sars - cov - 2 virus. [SEP]


In [14]:
# Forward pass of the encoded (query, document) pair through the model.
# Only the outputs are inspected, so autograd is switched off to avoid
# building a gradient graph that is never used.
with torch.no_grad():
    outputs_qa = model(**inputs_qa)

In [17]:
# Last-layer vector of the token at position 0 of the first sequence, i.e. the
# CLS embedding. It is designed to feed a sigmoid function (logistic regression).
cls_embedding = outputs_qa["last_hidden_state"][0][0]
print(cls_embedding.shape)

torch.Size([768])


In [18]:
# The CLS embedding of a (q,d) pair replaces the many individual retrieval models
# However, its values are not interpretable like the doc scores of classic retrieval models