# BERT Tutorial


## Install the libraries
First you need to install the following libraries:

    pip install transformers
    pip install ipywidgets
    pip install bertviz

Once everything is installed, you can download the bertviz repository and add it to the Python path:

In [None]:
import sys
!test -d bertviz_repo && echo "FYI: bertviz_repo directory already exists, to pull latest version uncomment this line: !rm -r bertviz_repo"
# !rm -r bertviz_repo # Uncomment if you need a clean pull from repo
!test -d bertviz_repo || git clone https://github.com/jessevig/bertviz bertviz_repo
# Make the cloned repository importable, adding it to the path only once.
if 'bertviz_repo' not in sys.path:
  sys.path.append('bertviz_repo')

## Imports and definitions

In [None]:
from bertviz import model_view, head_view
from transformers import *

import numpy as np
import pprint

# Get the interactive Tools for Matplotlib
%matplotlib notebook
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from transformers import BertTokenizer, BertModel
import torch

In [None]:
def call_html():
  """Display the RequireJS bootstrap snippet in the notebook output.

  The snippet loads require.js from the notebook's static files and
  configures the module paths for ``base``, ``d3`` and ``jquery``.
  Relies on the ``display`` function that IPython provides in notebooks.
  """
  import IPython

  requirejs_setup = '''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''
  snippet = IPython.core.display.HTML(requirejs_setup)
  display(snippet)

In [None]:
# Checkpoint identifier passed to the from_pretrained() calls below.
# Alternative checkpoint, kept for reference:
#model_path = 'nboost/pt-bert-base-uncased-msmarco'
model_path = 'bert-base-uncased'

# Text of the special tokens that delimit the BERT input sequence.
CLS_token = "[CLS]"
SEP_token = "[SEP]"


# Load the required tokenizer, configuration and model

In [None]:
# Load the tokenizer that matches the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Ask the model to also return per-layer hidden states and attention
# matrices; later cells read outputs['hidden_states'] and outputs['attentions'].
config = AutoConfig.from_pretrained(model_path,  output_hidden_states=True, output_attentions=True)  
model = AutoModel.from_pretrained(model_path, config=config)


# Tokenization

See here for details: https://huggingface.co/docs/transformers/tokenizer_summary

In [None]:
# Two example sentence pairs. Only the second pair is actually used:
# the assignments below overwrite the throat/lung sentences.
sentence_a = "Is throat cancer treatable nowadays?"
sentence_b = "Tell me about lung cancer."
sentence_a = "58-year-old woman with hypertension"
sentence_b = "BACKGROUND : Longitudinal studies hypertension"
# Encode both sentences as a single input sequence, returned as PyTorch
# tensors ('pt'), with special tokens added and truncation at 512 tokens.
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True, max_length = 512, truncation = True)
pprint.pprint(inputs)

In [None]:
print(tokenizer.decode(inputs["input_ids"][0].tolist()))

In [None]:
input_ids = inputs['input_ids']
pprint.pprint(input_ids[0].tolist())

In [None]:
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
pprint.pprint(tokens)

# Model inference output

In [None]:
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
outputs.keys()

## Layer embeddings

In [None]:
# the last layer is the output embedding layer
output_embeddings = outputs['last_hidden_state']

In [None]:
# Positions of the tokens of interest in the input sequence.
# NOTE(review): the names suggest these indices were picked for the
# throat/lung example sentences, but the tokenization cell overwrites those
# with the hypertension sentences -- check them against the printed token
# list before interpreting the results.
token_throat = 2
token_lung = 11
# out[0][token]  (index 0 selects the only sequence in the batch)
throat_output_embedding = output_embeddings[0][token_throat]
throat_output_embedding

In [None]:
hidden_states = outputs['hidden_states']

In [None]:
# This is the embedding of the token at position `token_throat` taken from
# the first entry of hidden_states (layer 0), i.e. the input side of the
# model rather than its final output.
# NOTE(review): presumably hidden_states[0] is the embedding-layer output,
# as described in the Hugging Face model output docs -- confirm.
# hidden_states[layer][0][token])
layer = 0
throat_input_embedding = hidden_states[layer][0][token_throat]
throat_input_embedding

## Self-attention matrices

In [None]:
attention = outputs['attentions']
# The format of the attention tensor is:
# attention[layer][0][head][token1][token2]
layer = 3
head = 3

In [None]:
# This gives the attention weight between one token and another token
# for the selected layer and head (batch index 0).
attention[layer][0][head][token_throat][token_lung]

In [None]:
# There's a softmax, so, the sum should be 1 
attention[layer][0][head][token_throat].sum()

In [None]:
attention[layer][0][head][token_throat].sum()

# Extract Word embeddings



In [None]:
def get_word_idx(sent: str, word: str):
    """Return the position of the first occurrence of `word` in `sent`.

    The sentence is split on single space characters, so consecutive
    spaces produce empty entries that still count as positions.
    Raises ValueError when `word` is not one of the resulting entries.
    """
    words = sent.split(" ")
    return words.index(word)

def get_word_vector(inputs, outputs, idx, layer):
    """Average the hidden states of the tokens that belong to word `idx`.

    Every token position whose entry in `inputs.word_ids()` equals `idx`
    is selected from `outputs.hidden_states[layer]` (batch index 0), and
    the mean of those token vectors is returned.

    NOTE(review): for sentence-pair inputs the same word id may appear in
    both sentences, in which case tokens from both would be averaged --
    confirm against the tokenizer in use.
    """
    # Word id of each token, as an array so it can be compared elementwise.
    word_id_per_token = np.array(inputs.word_ids())
    token_positions = np.where(word_id_per_token == idx)

    layer_states = outputs.hidden_states[layer]
    selected_tokens = layer_states[0][token_positions]

    return selected_tokens.mean(dim=0)

# The code below converts the tokens into a space-delimited string.
# This allows computing the position of a given word in the BERT input sequence.
# Note that sentence_a is overwritten here with the decoded text of the whole
# input sequence (the "[CLS] " and " [SEP]" markers are stripped out), so it
# no longer holds only the first sentence.
sentence_a = tokenizer.decode(inputs["input_ids"][0].tolist()).replace("[CLS] ", '').replace(" [SEP]", '')
word = "hypertension"
idx = get_word_idx(sentence_a, word)
print("Input sequence:", sentence_a)
print("The word \"", word, "\" occurs in position", idx, "of the BERT input sequence.")

# Embedding of the word taken from hidden-state layer 4.
word_embedding = get_word_vector(inputs, outputs, idx, 4)


In [None]:
import torch
import re
from transformers import AutoTokenizer, AutoModel

def get_word_vector_from_ab(inputs, outputs, word, layer=-1, ab='A', details=True):
    """
    Extract a word embedding from the requested layer for sentence_a or
    sentence_b. If the word is divided into several tokens, the word
    embedding is the average of the corresponding token embeddings.

    NOTE: If the same word occurs multiple times in the sentence,
    this method returns the word embedding of the first occurrence.

    This function uses the notebook-level `tokenizer` and the helper
    `get_word_idx`, and assumes a single, unpadded sequence of the form
    [CLS] sentence_a [SEP] sentence_b [SEP] at batch index 0.

    Keyword arguments:
        inputs -- input passed to the transformer
        outputs -- output of the transformer (must expose `hidden_states`)
        word -- target word
        layer -- index of the hidden-state layer from which the embedding
        is extracted (default -1, the last entry). The string form of an
        integer, e.g. '-1', is also accepted.
        ab -- 'A' to extract the word embedding from sentence_a; any other
        value selects sentence_b, i.e., query or document.
        details -- when True (default), print inspection details.

    Returns:
        The mean of the hidden states of the tokens that form the word.

    Raises:
        ValueError -- if the [SEP] token or the word cannot be found, or
        if no token of the selected sentence maps to the word.
    """
    # hidden_states is indexed by integer; the old default was the string '-1'.
    layer = int(layer)

    input_id_list = inputs["input_ids"][0].tolist()  # Batch index 0

    # Ask the tokenizer for the [SEP] id, since other vocabularies may use a
    # different id; fall back to 102, the value previously hard-coded here.
    sep_id = getattr(tokenizer, "sep_token_id", None)
    if sep_id is None:
        sep_id = 102
    # Position of the first [SEP]; list.index raises ValueError if absent.
    sep_token = input_id_list.index(sep_id)

    if ab == 'A':
        # Skip the leading [CLS] and stop before the first [SEP].
        sent_token_ids = input_id_list[1:sep_token]
    else:
        # Start after the first [SEP] and drop the final token ([SEP]).
        sent_token_ids = input_id_list[sep_token + 1:-1]
    sent = tokenizer.decode(sent_token_ids)

    word_idx = get_word_idx(sent, word)

    # Keep the tokens that belong to the word of interest AND lie on the
    # requested side of the first [SEP]; the same word id can be present in
    # both sentences, so the side check is what tells them apart.
    in_sentence_a = ab == 'A'
    token_ids_word = [
        pos
        for pos, word_id in enumerate(inputs.word_ids())
        if word_id == word_idx
        and ((pos < sep_token) if in_sentence_a else (pos > sep_token))
    ]
    if not token_ids_word:
        # Averaging an empty selection would silently produce NaNs.
        raise ValueError(
            "No token of sentence %s maps to the word %r" % (ab, word)
        )

    word_tokens_output = outputs.hidden_states[layer][0][token_ids_word]

    if details:
        tokens = tokenizer.convert_ids_to_tokens(input_id_list)
        str1 = " "

        print("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ")
        print("INPUT SEQUENCE TOKENS: ", str1.join(tokens))
        print("TARGET WORD:", word)
        print("TARGET SENTENCE:", ab)
        print("TARGET SENTENCE WORDS [", sent, "]")
        print("The word [", word, "] occurs in position", word_idx, "of the BERT input sentence", ab)
        print("The word [", word, "] corresponds to the token(s)", token_ids_word, "of the BERT input sequence", ab)

    return word_tokens_output.mean(dim=0)


word_embedding = get_word_vector_from_ab(inputs, outputs, "woman", 4, 'A')

word_embedding = get_word_vector_from_ab(inputs, outputs, "hypertension", 4, 'B')

In [None]:
word_embedding.shape

# Attention visualization

More details are available here: https://github.com/jessevig/bertviz

In [None]:
call_html()
head_view(attention, tokens)

In [None]:
model_view(attention, tokens)

## Other pre-trained BERT models

There are many other models available for download (https://huggingface.co/models).

BioBERT is a popular BERT model trained on biomedical literature (https://academic.oup.com/bioinformatics/article/36/4/1234/5566506):

    model_path = 'dmis-lab/biobert-v1.1'

Another popular BERT variant is SciBERT, which is trained on scientific literature (https://arxiv.org/abs/1903.10676):

    model_path = 'allenai/scibert_scivocab_uncased'

See above where the variable 'model_path' is defined.