Masked Word Prediction with BERT

This project develops an AI system that predicts masked words in a text sequence using BERT (Bidirectional Encoder Representations from Transformers) and analyzes the attention diagrams generated by the model.

Introduction

Masked Language Models (MLMs) like BERT are trained to predict masked words within a text sequence. BERT uses a transformer architecture and relies on self-attention to model language. In this project, we use the Hugging Face transformers library to implement a program that predicts masked words with BERT and generates attention diagrams.

Files

  • mask.py: Contains the implementation of functions to predict masked words and generate attention diagrams using BERT.
  • analysis.md: Documentation for analyzing attention diagrams and identifying relationships between words learned by attention heads.
  • requirements.txt: Lists the dependencies required for the project.

Usage

  1. Download the distribution code from the provided link and unzip it.
  2. Navigate to the attention directory.
  3. Run the command pip3 install -r requirements.txt to install the required dependencies.
  4. Execute the command python mask.py to predict masked words and generate attention diagrams.

Background

The project leverages the BERT model to predict masked words in text sequences. The main function prompts the user for input text containing the mask token [MASK], which stands in for the word to be predicted. BERT predicts the most likely replacements for the masked word using its self-attention mechanisms, and the program generates attention diagrams that visualize the attention scores each token in the input sequence assigns to the others.
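As a point of reference, the sketch below shows masked word prediction with a pretrained BERT model using the Hugging Face fill-mask pipeline. It is not the project's mask.py (which works with the model and tokenizer directly so that attention weights can be visualized); the model name and top_k value are illustrative choices.

```python
# Minimal sketch: predict a masked word with a pretrained BERT model.
from transformers import pipeline

unmasker = pipeline("fill-mask", model="bert-base-uncased")

text = "Then I picked up a [MASK] from the table."
for prediction in unmasker(text, top_k=3):
    # Each prediction contains the filled-in sentence and its score.
    print(f"{prediction['sequence']}  (score: {prediction['score']:.3f})")
```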

Implementation Details

  • Get Mask Token Index: The get_mask_token_index function returns the index of the mask token in the sequence of tokens produced by the tokenizer.
  • Get Color for Attention Score: The get_color_for_attention_score function converts an attention score into an RGB color, mapping higher scores to lighter colors and lower scores to darker colors.
  • Visualize Attentions: The visualize_attentions function extends the existing implementation so that an attention diagram is generated for every attention head in every layer of the model (see the sketch after this list).
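
The following sketch illustrates one plausible shape for these helpers; the exact signatures in mask.py may differ, and generate_diagram is assumed to be supplied by the distribution code.

```python
# Illustrative sketches only; the exact signatures in mask.py may differ.

def get_mask_token_index(mask_token_id, inputs):
    """Return the index of the mask token in the tokenized input, or None if absent."""
    for i, token_id in enumerate(inputs.input_ids[0]):
        if int(token_id) == mask_token_id:
            return i
    return None


def get_color_for_attention_score(attention_score):
    """Map an attention score in [0, 1] to a shade of gray:
    0 -> black (0, 0, 0), 1 -> white (255, 255, 255)."""
    value = round(float(attention_score) * 255)
    return (value, value, value)


def visualize_attentions(tokens, attentions):
    """Generate a diagram for every attention head in every layer.

    `attentions` is the tuple returned by the model when called with
    output_attentions=True: one tensor per layer, each shaped
    (batch, heads, tokens, tokens).
    """
    for layer_index, layer in enumerate(attentions):
        for head_index in range(layer[0].shape[0]):
            # generate_diagram is assumed to come from the distribution code;
            # layer and head numbers are 1-indexed in the diagram titles.
            generate_diagram(
                layer_index + 1,
                head_index + 1,
                tokens,
                layer[0][head_index],
            )
```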

Analysis

The analysis.md file describes and gives examples of attention heads that appear to have learned specific relationships between words. Two attention heads are analyzed, each demonstrating a distinct relationship between tokens observed across multiple example sentences.

Conclusion

The project demonstrates the use of BERT for masked word prediction and attention visualization. By analyzing attention diagrams, we gain insights into the relationships learned by attention heads, contributing to our understanding of language comprehension in AI models.
