Sidharth Shah

Honours Computing & Biomedical Engineering @ the University of Waterloo

Predicting Drug-Protein Binding Affinity with Graph Neural Networks

Introduction

Drug discovery is expensive and time-consuming. Before a drug candidate reaches clinical trials, researchers must screen thousands of molecules to find those that bind strongly to target proteins. This project explores using machine learning, specifically Graph Neural Networks (GNNs), to predict binding affinity between proteins and small molecules, potentially accelerating this screening process.

The core question: given a protein sequence and a drug molecule, can we predict how tightly they'll bind?

The Dataset: BindingDB

I used BindingDB, a public database containing millions of measurements of protein-ligand binding affinities. Each entry includes:

  • Protein sequence: The amino acid chain (e.g., "MKTAYIAKQR...")
  • Ligand (drug) SMILES: A text representation of the molecule's structure
  • Kd value: Dissociation constant measuring binding strength (in nanomolar)

The raw dataset is massive (~2GB), so I processed it in chunks of 100,000 rows to avoid memory issues. After filtering for human proteins with valid Kd measurements and protein sequences, I had ~50,000 protein-ligand pairs to work with.

Important preprocessing detail: I converted Kd from nanomolar to molar, then took the negative log to get pKd = -log10(Kd). Raw Kd values span several orders of magnitude, and the log transform compresses them onto a compact scale that is easier to regress on. A higher pKd means stronger binding.

# From src/inspect_BindingDB.py
chunk['Kd_M'] = pd.to_numeric(chunk['Kd (nM)'], errors='coerce') * 1e-9
chunk['pKd'] = -np.log10(chunk['Kd_M'])
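
As a quick sanity check on the conversion, here is a minimal standalone sketch using only the standard library. The helper name kd_nm_to_pkd is hypothetical (it is not part of the project code), and returning None for non-numeric or non-positive values is an assumption about how invalid entries should be treated.

```python
import math

def kd_nm_to_pkd(kd_nm):
    """Convert a Kd in nanomolar to pKd = -log10(Kd in molar).

    Hypothetical helper for illustration. Returns None when the value is
    missing, non-numeric, or not strictly positive (log10 is undefined there).
    """
    try:
        kd_nm = float(kd_nm)
    except (TypeError, ValueError):
        return None
    if not math.isfinite(kd_nm) or kd_nm <= 0:
        return None
    kd_molar = kd_nm * 1e-9          # nM -> M
    return -math.log10(kd_molar)

# 10 nM is a tight binder; 10,000 nM (10 uM) is a weak one.
for kd in (10, 10_000, ">10000", 0):
    print(kd, kd_nm_to_pkd(kd))
```

A Kd of 10 nM maps to a pKd of about 8 and 10,000 nM to about 5, so each unit of pKd corresponds to a tenfold change in Kd. Entries that cannot be parsed as a positive number are skipped rather than turned into infinities.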

Critical Decision: How to Split the Data

Here's where many drug-binding prediction projects fail: the train/test split strategy.

The naive approach: Randomly split all protein-ligand pairs into train (70%), validation (15%), and test (15%). This seems reasonable, right?

The problem: Proteins appear multiple times in the dataset bound to different ligands. If the same protein appears in both training and test sets (with different drugs), your model learns that specific protein's binding properties during training, then "cheats" by recognizing it at test time. Your test performance looks great, but the model won't generalize to truly novel proteins.

The solution: Protein-level splitting. I split by unique proteins, ensuring any protein in the test set never appeared during training.

# From src/split_BindingDB.py
unique_proteins = df['BindingDB Target Chain Sequence 1'].unique()

# Split unique proteins (not individual rows)
train_proteins, temp_proteins = train_test_split(
    unique_proteins,
    test_size=0.3,
    random_state=42
)

val_proteins, test_proteins = train_test_split(
    temp_proteins,
    test_size=0.5,
    random_state=42
)

# Then filter dataframe by protein sets
train_df = df[df['BindingDB Target Chain Sequence 1'].isin(train_proteins)]

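This is harder because the model must generalize to completely unseen proteins, but it's the only honest evaluation. This is called a group split (also known as a cold-target split), and it's standard practice in bioinformatics ML. One caveat: splitting on exact sequences still lets close homologs end up on both sides, so clustering proteins by sequence identity before splitting would be stricter.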
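
To make the idea concrete, here is a minimal standard-library sketch of a protein-level split with an explicit leakage check. The function names and the toy rows are hypothetical; the project itself splits with scikit-learn's train_test_split as shown above.

```python
import random

def split_by_protein(rows, key, fractions=(0.7, 0.15, 0.15), seed=42):
    """Split rows into train/val/test so that no protein spans two sets.

    rows: list of dicts; key: field holding the protein sequence.
    Hypothetical helper for illustration only.
    """
    proteins = sorted({row[key] for row in rows})   # sort for determinism
    random.Random(seed).shuffle(proteins)

    n_train = int(fractions[0] * len(proteins))
    n_val = int(fractions[1] * len(proteins))
    groups = {
        "train": set(proteins[:n_train]),
        "val": set(proteins[n_train:n_train + n_val]),
        "test": set(proteins[n_train + n_val:]),
    }
    # Assign each row to the split that owns its protein.
    splits = {name: [r for r in rows if r[key] in members]
              for name, members in groups.items()}
    return splits, groups

def assert_no_leakage(groups):
    """Raise AssertionError if any protein appears in more than one split."""
    names = list(groups)
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            shared = groups[a] & groups[b]
            assert not shared, f"{len(shared)} proteins shared by {a} and {b}"

# Toy data: 20 proteins with 5 ligands each.
rows = [{"protein": f"SEQ{p}", "ligand": f"LIG{l}", "pKd": 6.0}
        for p in range(20) for l in range(5)]
splits, groups = split_by_protein(rows, key="protein")
assert_no_leakage(groups)
print({name: len(part) for name, part in splits.items()})
```

Because the split is over proteins rather than rows, the row counts per set only approximate 70/15/15: a protein with many measured ligands brings all of them into whichever set it lands in.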

Representing Proteins: ESM-2 Embeddings

Proteins are sequences of amino acids (e.g., "MKTAYIAKQR..."). Neural networks need fixed-size numerical vectors, not variable-length strings.

Enter ESM-2 (Evolutionary Scale Modeling), a protein language model from Meta AI. It's like BERT, but trained on millions of protein sequences instead of text. ESM-2 learned to predict masked amino acids, capturing biological patterns in the process.

I used the facebook/esm2_t30_150M_UR50D model (150 million parameters) to convert each protein sequence into a 640-dimensional embedding vector:

# From src/generate_protein_embeddings.py
from transformers import AutoTokenizer, AutoModel
import torch

modelVariation = "facebook/esm2_t30_150M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(modelVariation)
model = AutoModel.from_pretrained(modelVariation)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

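def getProteinEmbedding(sequence, model, tokenizer, device):
    # Tokenize one sequence; inputs longer than 1024 tokens are truncated
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)

    # last_hidden_state has shape (1, num_tokens, 640)
    hidden_states = outputs.last_hidden_state
    # Mean-pool over the token axis (this includes the special start/end tokens)
    embedding = hidden_states.mean(dim=1)
    # Move to CPU and drop the batch axis -> NumPy array of shape (640,)
    embedding = embedding.cpu().numpy().squeeze()
    return embedding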

Each protein now becomes a 640-dimensional vector that summarizes what ESM-2 learned about its sequence, and indirectly about its structure and function. These embeddings are cached to disk, so I only compute them once.
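
One simple way to cache per-sequence results on disk is sketched below. The hashing scheme, file layout, and the cached_embedding helper are assumptions for illustration, not the project's actual caching code; compute_fn stands in for any function that maps a sequence to an embedding, such as a wrapper around getProteinEmbedding.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path

def cached_embedding(sequence, compute_fn, cache_dir="embedding_cache"):
    """Return the embedding for `sequence`, computing it at most once.

    Hypothetical helper: one pickle file per unique sequence.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Hash the sequence so long sequences still give short, safe file names.
    key = hashlib.sha256(sequence.encode("utf-8")).hexdigest()
    path = cache_dir / f"{key}.pkl"

    if path.exists():                      # cache hit: load from disk
        with path.open("rb") as fh:
            return pickle.load(fh)

    embedding = compute_fn(sequence)       # cache miss: compute and store
    with path.open("wb") as fh:
        pickle.dump(embedding, fh)
    return embedding

# Stand-in for the real model call, so the example runs without a GPU.
calls = []
def fake_embed(seq):
    calls.append(seq)
    return [float(len(seq))] * 4           # placeholder for a 640-dim vector

with tempfile.TemporaryDirectory() as tmp:
    first = cached_embedding("MKTAYIAKQR", fake_embed, cache_dir=tmp)
    second = cached_embedding("MKTAYIAKQR", fake_embed, cache_dir=tmp)
    print(first == second, len(calls))
```

The second call reads from disk instead of invoking the model. Since many ligands share the same protein, keying the cache on the sequence means each unique protein is embedded only once.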

Representing Ligands: Two Approaches

Ligands (drug molecules) are trickier. They're represented as SMILES strings (Simplified Molecular Input Line Entry System), a compact text notation for chemical structures.

Example SMILES: CC(=O)Oc1ccccc1C(=O)O (this is aspirin!)

I explored two representations:

1. Morgan Fingerprints (Baseline)

Morgan fingerprints are traditional cheminformatics features (2048-bit vectors) indicating the presence of molecular substructures:

# From src/generate_ligand_features.py
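from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import numpy as np

def getMFP(smiles, radius=2, n_bits=2048):
    mol = Chem.MolFromSmiles(smiles)

    # Unparseable SMILES fall back to an all-zero fingerprint
    if mol is None:
        return np.zeros(n_bits)

    # Note: recent RDKit releases recommend rdFingerprintGenerator.GetMorganGenerator instead
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)

    # Copy the bit vector into a NumPy array in one call
    arr = np.zeros(n_bits)
    DataStructs.ConvertToNumpyArray(fp, arr)

    return arr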

This gives a fixed 2048-dimensional vector per molecule. Simple and fast.

2. Molecular Graphs (GNN Approach)

GNNs can work directly with molecular graphs:

  • Nodes: Atoms (with features like element type, charge, hybridization)
  • Edges: Chemical bonds (with features like bond type, aromaticity)

Each molecule becomes a graph with variable numbers of atoms and bonds. This preserves the actual molecular structure rather than just substructure fingerprints.

# From src/generate_ligand_graphs.py
from rdkit import Chem

def getAtomFeatures(atom):
    features = []
    # One-hot element type, with a final "other" slot for anything not listed
    atomTypes = ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'I']
    atomSymbol = atom.GetSymbol()
    for atomType in atomTypes:
        features.append(1 if atomSymbol == atomType else 0)
    features.append(1 if atomSymbol not in atomTypes else 0)

    features.append(atom.GetDegree())
    features.append(atom.GetFormalCharge())

    hybridizationTypes = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ]
    hyb = atom.GetHybridization()
    for hybType in hybridizationTypes:
        features.append(1 if hyb == hybType else 0)

    features.append(1 if atom.GetIsAromatic() else 0)
    features.append(atom.GetTotalNumHs())

    return features


def getBondFeatures(bond):
    features = []
    bondTypes = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ]
    for bondType in bondTypes:
        features.append(1 if bond.GetBondType() == bondType else 0)

    features.append(1 if bond.GetIsConjugated() else 0)
    features.append(1 if bond.IsInRing() else 0)

    return features


def smilesToGraph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    nodeFeatures = []
    for atom in mol.GetAtoms():
        nodeFeatures.append(getAtomFeatures(atom))
    edgeIndices = []
    edgeFeatures = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edgeIndices.append((i, j))
        edgeIndices.append((j, i))
        bondFeatures = getBondFeatures(bond)
        edgeFeatures.append(bondFeatures)
        edgeFeatures.append(bondFeatures)

    return {
        'nodeFeatures': nodeFeatures,
        'edgeIndex': edgeIndices,
        'edgeFeatures': edgeFeatures,
        'numNodes': len(nodeFeatures),
    }
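
The feature functions above produce 19 values per atom (10 element slots, degree, formal charge, 5 hybridization slots, aromaticity, hydrogen count) and 6 per bond (4 bond-type slots, conjugation, ring membership). A small check like the following can catch shape mistakes before training. The check_graph helper and the hand-built toy graph are hypothetical; the toy graph uses placeholder zero features rather than real RDKit output.

```python
def check_graph(graph, node_dim=19, edge_dim=6):
    """Validate a graph dict shaped like the output of smilesToGraph.

    Hypothetical helper. Returns (number of atoms, number of bonds).
    """
    num_nodes = graph["numNodes"]
    edges = graph["edgeIndex"]

    assert len(graph["nodeFeatures"]) == num_nodes
    assert all(len(f) == node_dim for f in graph["nodeFeatures"])
    assert len(graph["edgeFeatures"]) == len(edges)
    assert all(len(f) == edge_dim for f in graph["edgeFeatures"])

    # Every index must refer to an existing atom.
    assert all(0 <= i < num_nodes and 0 <= j < num_nodes for i, j in edges)

    # Each bond is stored in both directions.
    edge_set = set(edges)
    assert all((j, i) in edge_set for i, j in edges)

    return num_nodes, len(edges) // 2

# Toy 3-atom chain (the heavy-atom skeleton of ethanol, "CCO").
toy_graph = {
    "nodeFeatures": [[0] * 19 for _ in range(3)],
    "edgeIndex": [(0, 1), (1, 0), (1, 2), (2, 1)],
    "edgeFeatures": [[0] * 6 for _ in range(4)],
    "numNodes": 3,
}
print(check_graph(toy_graph))  # 3 atoms, 2 bonds
```

Storing each bond twice, once per direction, is what lets a message-passing layer send information both ways along an undirected chemical bond.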

Model 1: XGBoost Baseline

Before diving into GNNs, I established a baseline using XGBoost (a popular gradient boosting algorithm).

Architecture: Concatenate protein embedding (640-dim) + ligand fingerprint (2048-dim) = 2688-dimensional input vector → XGBoost regressor → pKd prediction

Results:

  • Test RMSE: 1.441
  • Test R²: 0.163

This gives us a reference point. The model explains ~16% of the variance in binding affinity, which is a usable baseline but leaves plenty of room for improvement.
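
For reference, the two reported metrics can be computed as in the sketch below. It uses the standard definitions (RMSE as the root of the mean squared error, R² as one minus the ratio of residual to total sum of squares); the function names and toy numbers are illustrative and not taken from the project.

```python
import math

def rmse(y_true, y_pred):
    """Root mean squared error, in the same units as the target (pKd)."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Toy pKd values, made up for illustration.
y_true = [5.0, 6.0, 7.0, 8.0]
y_pred = [5.5, 6.0, 6.5, 8.0]
print(rmse(y_true, y_pred), r_squared(y_true, y_pred))

# A model that always predicts the mean pKd scores R^2 = 0.
mean_baseline = [sum(y_true) / len(y_true)] * len(y_true)
print(r_squared(y_true, mean_baseline))
```

This framing helps when reading the results: an R² of 0.163 means the model's squared error is about 16% lower than that of always predicting the mean pKd of the test set.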

Model 2: Graph Neural Network

GNNs can learn directly from molecular graphs instead of fixed fingerprints. The architecture has two parts:

Part 1: Ligand Graph Encoder

Converts variable-size molecular graph → 128-dim embedding

# From models/gnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConvolutionLayer(nn.Module):

    def __init__(self, inFeatures, outFeatures):
        super(GraphConvolutionLayer, self).__init__()
        self.linear = nn.Linear(inFeatures, outFeatures)
        self.edgeLinear = nn.Linear(6, outFeatures)  # 6-dim edge features

    def forward(self, nodeFeatures, edgeIndex, edgeFeatures):
        numNodes = nodeFeatures.size(0)
        transformed = self.linear(nodeFeatures)
        aggregated = torch.zeros(numNodes, transformed.size(1), device=nodeFeatures.device)
        for idx, (src, dst) in enumerate(edgeIndex):
            edgeMessage = self.edgeLinear(edgeFeatures[idx])
            aggregated[dst] += transformed[src] + edgeMessage
        return F.relu(aggregated)

The GNN encoder stacks multiple graph convolution layers, letting information flow through the molecular graph. After 3-4 layers, we pool all atom embeddings (via mean pooling) to get a single 128-dim vector representing the entire molecule.
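
The GNNEncoder class itself is not shown in this post, so the following is a hypothetical sketch of what such an encoder could look like, based only on the description above: stacked graph convolutions followed by mean pooling. The final linear projection and the tensor layout of edgeIndex (a LongTensor of shape [numEdges, 2]) are assumptions. The convolution layer is restated in a vectorized form using index_add_ so the block runs on its own; it computes the same sums as the loop version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConvolutionLayer(nn.Module):
    """Same message passing as the loop version, vectorized with index_add_."""

    def __init__(self, inFeatures, outFeatures, edgeFeaturesDimension=6):
        super().__init__()
        self.linear = nn.Linear(inFeatures, outFeatures)
        self.edgeLinear = nn.Linear(edgeFeaturesDimension, outFeatures)

    def forward(self, nodeFeatures, edgeIndex, edgeFeatures):
        # Assumed layout: edgeIndex is a LongTensor [numEdges, 2] of (src, dst)
        src, dst = edgeIndex[:, 0], edgeIndex[:, 1]
        messages = self.linear(nodeFeatures)[src] + self.edgeLinear(edgeFeatures)
        aggregated = torch.zeros(
            nodeFeatures.size(0), messages.size(1),
            dtype=messages.dtype, device=messages.device,
        )
        aggregated.index_add_(0, dst, messages)  # sum incoming messages per atom
        return F.relu(aggregated)


class GNNEncoder(nn.Module):
    """Hypothetical encoder: stacked graph convolutions, then mean pooling."""

    def __init__(self, nodeFeaturesDimension=19, hiddenDimension=128,
                 outputDimension=128, numLayers=3):
        super().__init__()
        dims = [nodeFeaturesDimension] + [hiddenDimension] * numLayers
        self.layers = nn.ModuleList(
            [GraphConvolutionLayer(dims[i], dims[i + 1]) for i in range(numLayers)]
        )
        self.project = nn.Linear(hiddenDimension, outputDimension)

    def forward(self, nodeFeatures, edgeIndex, edgeFeatures):
        h = nodeFeatures
        for layer in self.layers:
            h = layer(h, edgeIndex, edgeFeatures)
        pooled = h.mean(dim=0)          # average over atoms -> [hiddenDimension]
        return self.project(pooled)     # -> [outputDimension]


# Toy 3-atom chain with random features, just to check shapes.
torch.manual_seed(0)
nodeFeatures = torch.rand(3, 19)
edgeIndex = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]])
edgeFeatures = torch.rand(4, 6)
encoder = GNNEncoder()
print(encoder(nodeFeatures, edgeIndex, edgeFeatures).shape)
```

Mean pooling is what turns a variable number of atoms into a fixed-size vector: a 3-atom fragment and a 60-atom drug both come out as 128 numbers. With 3 layers, each atom's final embedding can only reflect atoms up to 3 bonds away, which is one reason depth matters here.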

Part 2: Binding Affinity Predictor

Concatenate protein embedding (640-dim) + ligand graph embedding (128-dim) → MLP → pKd prediction

# From models/gnn.py
class BindingAffinityGNN(nn.Module):

    def __init__(self, proteinDimension=640, ligandGnnOutput=128, hiddenDimension=256):
        super(BindingAffinityGNN, self).__init__()
        self.ligandEncoder = GNNEncoder(
            nodeFeaturesDimension=19,
            hiddenDimension=128,
            outputDimension=ligandGnnOutput,
            numLayers=3
        )
        combinedDimension = proteinDimension + ligandGnnOutput
        self.mlp = nn.Sequential(
            nn.Linear(combinedDimension, hiddenDimension),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hiddenDimension, hiddenDimension // 2),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hiddenDimension // 2, 1)
        )

    def forward(self, proteinEmbedding, nodeFeatures, edgeIndex, edgeFeatures):
        ligandEmbedding = self.ligandEncoder(nodeFeatures, edgeIndex, edgeFeatures)
        # Concatenate along the feature axis; dim=-1 works for single vectors and batches
        combined = torch.cat([proteinEmbedding, ligandEmbedding], dim=-1)
        prediction = self.mlp(combined)

        return prediction

Model Iterations & Results

I trained several GNN variants, tweaking architecture and regularization:

Model       Test RMSE   Test R²   Key Changes
XGBoost     1.441       0.163     Baseline (concatenated features)
GNN         1.510       0.081     Initial 4-layer GNN
GANN        1.506       0.086     Added attention pooling (overfitted)
GNN Iter3   1.425       0.182     4 layers + residual connections + batch norm
GNN Iter4   1.464       0.137     3 layers + residuals
GNN Iter5   1.495       0.100     3 layers + gradient clipping

Best performer: GNN Iter3 with RMSE of 1.425 and R² of 0.182.

What Worked:

  • Residual connections: Help gradients flow through deep GNN layers
  • Batch normalization: Stabilized training
  • Dropout (0.2-0.3): Prevented overfitting

What Didn't Work:

  • Attention pooling (GANN): Overfitted, with validation loss diverging from training loss
  • Too many layers: Diminishing returns beyond 4 GNN layers
  • Aggressive gradient clipping: Slowed learning without improving generalization
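
To illustrate the three techniques named in these lists, here is a minimal PyTorch sketch of a residual block with batch normalization and dropout, followed by one training step with gradient clipping. It is not the project's Iter3 architecture: the ResidualNodeBlock class is hypothetical, and a plain linear layer stands in for the graph convolution to keep the example short.

```python
import torch
import torch.nn as nn

class ResidualNodeBlock(nn.Module):
    """Hypothetical block: node update + batch norm + dropout + skip connection."""

    def __init__(self, dimension, dropout=0.2):
        super().__init__()
        self.update = nn.Linear(dimension, dimension)  # stand-in for a graph conv
        self.norm = nn.BatchNorm1d(dimension)          # normalizes each feature over atoms
        self.dropout = nn.Dropout(dropout)

    def forward(self, h):
        # h: [numAtoms, dimension]
        out = self.dropout(torch.relu(self.norm(self.update(h))))
        return h + out                                  # residual (skip) connection


torch.manual_seed(0)
block = ResidualNodeBlock(128)
optimizer = torch.optim.Adam(block.parameters(), lr=1e-3)

h = torch.rand(5, 128)                  # 5 atoms, 128 features each
loss = block(h).pow(2).mean()           # dummy loss, only to produce gradients

optimizer.zero_grad()
loss.backward()
# Rescale gradients so their global L2 norm is at most max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(block.parameters(), max_norm=1.0)
optimizer.step()
print(float(total_norm))                # gradient norm before clipping
```

The skip connection means each block only has to learn a correction to its input, which keeps gradients flowing through deeper stacks. A small max_norm caps the size of every update, so setting it too low can slow learning, which matches the behaviour described for aggressive clipping. Note that BatchNorm1d in training mode needs more than one row, so single-atom inputs would need special handling.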

Visualizing Results

Model Comparison

[Figure: RMSE comparison. Lower RMSE is better; GNN Iter3 achieves the lowest error.]

[Figure: R² comparison. R² is the fraction of variance explained; higher is better.]

Data Distribution

[Figure: Data distribution]

Key Takeaways

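  1. GNNs can be effective for molecular data: GNN Iter3 narrowly outperformed the XGBoost baseline (RMSE 1.425 vs. 1.441), suggesting that graph structure carries useful information beyond molecular fingerprints. The other GNN variants did not beat the baseline, so the gain depends on the architecture.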

  2. Proper data splitting is crucial: Protein-level splitting prevents data leakage but makes the problem harder. Random splitting would give inflated performance that doesn't generalize.

  3. Architecture choices matter: Residual connections and batch normalization made the difference between GNN Iter3 (best) and the original GNN (worst performer).

  4. Modest improvements, hard problem: Even the best model (R² = 0.182) explains less than 20% of variance. Binding affinity depends on many factors beyond sequence and structure, such as 3D protein conformation, solvent effects, temperature, etc. This is a genuinely difficult prediction task.

  5. Transfer learning helps: Pre-trained ESM-2 embeddings gave every model a ready-made protein representation. Training a protein encoder from scratch would require far more data.

Future Directions

  • 3D protein structure: Incorporate AlphaFold-predicted structures instead of just sequences
  • Attention mechanisms: Properly implemented cross-attention between protein and ligand
  • Larger datasets: BindingDB has millions of entries (I only used ~50k)
  • Multi-task learning: Predict other properties (Ki, IC50) jointly
  • Uncertainty quantification: Add probabilistic predictions to flag uncertain cases

Conclusion

Predicting drug-protein binding affinity is challenging, but Graph Neural Networks show promise for learning from molecular structure. The project reinforced that successful ML in drug discovery requires both sophisticated models AND rigorous experimental design. Fancy architectures won't save you from data leakage.

The code and visualizations are available on GitHub.