top of page

Image Registration Using a Perspective (Homography) Transformation

This notebook shows how to use a differentiable mutual-information estimate (MINE) as the cost function for image registration. Your main task is to replace that registration cost function with a normalized gradient (normalized gradient fields, NGF) loss.

Import Libraries

import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage import io
from skimage.transform import pyramid_gaussian

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

Load Images (from the working directory; on Colab, mount Google Drive first)

#from google.colab import drive

I = io.imread("knee1.bmp").astype(np.float32)/255.0 # fixed image
J = io.imread("knee2.bmp").astype(np.float32)/255.0 # moving image

nChannel = 1 # gray scale image

%matplotlib inline
plt.title("Fixed Image")
plt.title("Moving Image")


Regularization via multi-resolution (coarse-to-fine) registration using Gaussian image pyramids

L = 6 # Gaussian pyramid level
downscale = 2.0 # downscale factor for the gaussian pyramid
# max_layer=L-1 stops after the L levels actually used. The deprecated
# `multichannel` argument is omitted; its old default (False) and the newer
# default (channel_axis=None) both treat the image as single-channel.
pyramid_I = tuple(pyramid_gaussian(I, max_layer=L-1, downscale=downscale))
pyramid_J = tuple(pyramid_gaussian(J, max_layer=L-1, downscale=downscale))

nFraction=0.1 # fraction of pixels used in MINE calculation

# Per-level tensors used by the loss functions, all committed to `device`:
#   I_lst / J_lst : fixed / moving image at level s (s = 0 is full resolution)
#   ind_lst       : random subset of flattened pixel indices (MINE samples)
#   xy_lst        : (h, w, 2) grid of normalized (x, y) pixel-centre coordinates
I_lst, J_lst, ind_lst, xy_lst = [], [], [], []
for s in range(L):
  I_, J_ = torch.tensor(pyramid_I[s].astype(np.float32)).to(device), torch.tensor(pyramid_J[s].astype(np.float32)).to(device)
  I_lst.append(I_)
  J_lst.append(J_)

  h_, w_ = I_.shape[0], I_.shape[1]
  ind_ = torch.randperm(h_*w_)[0:int(nFraction*h_*w_)].to(device)
  ind_lst.append(ind_)
  print(h_, w_, len(ind_))

  # Pixel centres in grid_sample's align_corners=False convention
  # ((2*i + 1)/n - 1), which PerspectiveWarping uses. With this grid the
  # identity homography reproduces the image instead of slightly stretching
  # it, and single-pixel levels no longer divide by zero.
  y_ = (2.0*torch.arange(0,h_).float().to(device) + 1.0)/h_ - 1.0
  x_ = (2.0*torch.arange(0,w_).float().to(device) + 1.0)/w_ - 1.0
  y_, x_ = torch.meshgrid([y_, x_])
  xy_ = torch.stack([x_,y_],2)
  xy_lst.append(xy_)


512 512 26214
256 256 6553
128 128 1638
64 64 409
32 32 102
16 16 25

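The following cell implements the homography transformation as a matrix exponential. The transform is exp(C), where C is a linear combination of eight fixed 3×3 basis matrices with learnable coefficients v. The inverse transform is therefore simply exp(−C).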

class HomographyNet(nn.Module):
  """Learnable homography expressed as a matrix exponential.

  The transform is MatrixExp(B, v), built from eight fixed 3x3 basis matrices
  B[k] and eight learnable coefficients v. v starts at zero, so the initial
  transform is the identity.
  """

  def __init__(self):
    super().__init__()
    # Non-zero basis entries as (basis index, row, column, value):
    # 0-1 translations, 2-3 shears, 4-5 traceless diagonal scalings,
    # 6-7 projective terms.
    basis = torch.zeros(8, 3, 3)
    for k, row, col, val in (
        (0, 0, 2, 1.0),
        (1, 1, 2, 1.0),
        (2, 0, 1, 1.0),
        (3, 1, 0, 1.0),
        (4, 0, 0, 1.0), (4, 1, 1, -1.0),
        (5, 1, 1, -1.0), (5, 2, 2, 1.0),
        (6, 2, 0, 1.0),
        (7, 2, 1, 1.0),
    ):
      basis[k, row, col] = val
    self.B = basis.to(device)

    self.v = torch.nn.Parameter(torch.zeros(8, 1, 1).to(device), requires_grad=True)

  def forward(self):
    """Return the forward 3x3 transform matrix, MatrixExp(B, v)."""
    return MatrixExp(self.B, self.v)

  def inverse(self):
    """Return the inverse 3x3 transform matrix, MatrixExp(B, -v)."""
    return MatrixExp(self.B, -self.v)

def MatrixExp(B,v,n_terms=10):
    """Approximate exp(C), with C = sum_k v[k] * B[k], by a truncated Taylor series.

    Args:
      B: stack of basis matrices, e.g. shape (8, 3, 3) as built by HomographyNet.
      v: coefficients broadcastable against B, e.g. shape (8, 1, 1).
      n_terms: number of series terms including the identity. The default of
        10 gives I + C + C^2/2! + ... + C^9/9!.

    Returns:
      The square matrix sum_{i < n_terms} C^i / i!, on C's device and dtype.
    """
    C = torch.sum(B*v, 0)
    # Use C's device/dtype (not a global) so the function works wherever B and v live.
    A = torch.eye(C.shape[-1], device=C.device, dtype=C.dtype)
    H = A.clone()
    for i in range(1, n_terms):
        A = torch.mm(A/i, C)  # A now holds C^i / i!, built incrementally
        H = H + A
    return H

n_neurons = 100  # width of the two hidden layers of the statistics network


class MINE(nn.Module):
  """Statistics network T for a mutual-information neural estimate (MINE).

  forward() returns the Donsker-Varadhan lower bound
      E_joint[T] - log E_marginal[exp(T)],
  where marginal samples pair each fixed-image value with a moving-image value
  taken from a randomly shuffled pixel.
  """

  def __init__(self, n_channels=None, n_shuffles=1):
    """
    Args:
      n_channels: channels per image; defaults to the module-level nChannel.
      n_shuffles: number of independent shuffles whose bounds are averaged.
        The default of 1 matches the previous behaviour, which returned the
        bound from a single (the last) shuffle.
    """
    super(MINE, self).__init__()
    if n_channels is None:
      n_channels = nChannel
    self.n_channels = n_channels
    self.n_shuffles = max(1, int(n_shuffles))
    self.fc1 = nn.Linear(2*n_channels, n_neurons)
    self.fc2 = nn.Linear(n_neurons, n_neurons)
    self.fc3 = nn.Linear(n_neurons, 1)

  def _statistic(self, x):
    """Evaluate T on rows of x, returning a column of shape (N, 1)."""
    return self.fc3(F.relu(self.fc2(F.relu(self.fc1(x)))))

  def forward(self, x, ind):
    """Estimate MI between the two halves of x's last axis.

    Args:
      x: tensor of shape (h, w, 2*n_channels); the first n_channels hold the
        fixed image and the rest the moving image.
      ind: 1-D tensor of flattened pixel indices used as samples.

    Returns:
      0-d tensor with the MI lower bound averaged over n_shuffles shuffles.
    """
    c = self.n_channels
    x = x.reshape(x.size(0)*x.size(1), x.size(2))
    z1 = self._statistic(x[ind, :])
    # log(mean(exp(z2))) == logsumexp(z2) - log(N), without overflowing exp().
    log_n = math.log(len(ind))
    mi = 0.0
    for _ in range(self.n_shuffles):
      ind_perm = ind[torch.randperm(len(ind))]
      z2 = self._statistic(torch.cat((x[ind, 0:c], x[ind_perm, c:2*c]), 1))
      mi = mi + torch.mean(z1) - (torch.logsumexp(z2, dim=0).squeeze() - log_n)
    return mi / self.n_shuffles

def PerspectiveWarping(I, H, xv, yv):
  """Resample image I at the points obtained by mapping (xv, yv) through H.

  Args:
    I: image tensor accepted by F.grid_sample; callers here pass (1, 1, h, w).
    H: 3x3 homography applied to homogeneous coordinates (xv, yv, 1).
    xv, yv: 2-D tensors of normalized sampling coordinates (x and y).

  Returns:
    The resampled image with all singleton dimensions squeezed out.
  """
  # Projective denominator, shared by both output coordinates.
  w = xv*H[2,0] + yv*H[2,1] + H[2,2]
  x_src = (xv*H[0,0] + yv*H[0,1] + H[0,2]) / w
  y_src = (xv*H[1,0] + yv*H[1,1] + H[1,2]) / w
  grid = torch.stack([x_src, y_src], 2).unsqueeze(0)
  return F.grid_sample(I, grid, align_corners=False).squeeze()

def multi_resolution_MINE_loss():
  """Negative MINE mutual-information estimate averaged over pyramid levels.

  Reads the globals L, I_lst, J_lst, xy_lst and ind_lst (pyramid cell), plus
  homography_net and mine_net (presumably created in a later cell; confirm).
  Levels run from s = L-1 (coarsest) down to 0. At each level the moving image
  J_lst[s] is warped with the forward homography and paired with I_lst[s].

  Returns:
    -(1/L) * sum_s MI_s, a scalar tensor to be minimized.
  """
  loss = 0.0  # accumulator; must exist before the first `loss - ...` below
  H = homography_net()  # identical for every level, so evaluate it once
  for s in reversed(range(L)):
    Jw_ = PerspectiveWarping(J_lst[s].unsqueeze(0).unsqueeze(0), H, xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()
    mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])
    loss = loss - (1./L)*mi
  return loss

Complete the multi-resolution loss function in the cell below using a normalized gradient loss. Make sure that your loss function uses both the forward and the inverse transforms.

# Multi-resolution normalized-gradient-field (NGF) loss.
# Assignment constraints (respected below):
# Do not change function definition 'def multi_resolution_loss():'
# Do not change the last line 'return loss'


def _normalized_gradient_field(img, eps):
  """Return the NGF components (nx, ny) of a 2-D image.

  Forward differences are cropped to a common (h-1, w-1) grid. Each gradient
  is divided by sqrt(|grad|^2 + eps^2), so regions whose gradient is much
  smaller than eps map to ~0 instead of amplifying noise.
  """
  gx = img[:-1, 1:] - img[:-1, :-1]
  gy = img[1:, :-1] - img[:-1, :-1]
  norm = torch.sqrt(gx*gx + gy*gy + eps*eps)
  return gx/norm, gy/norm


def _ngf_distance(a, b, eps=1e-2):
  """Mean NGF distance 1 - <n_a, n_b>^2 between two same-sized 2-D images.

  The value is smaller where edges in a and b are parallel. Squaring the
  inner product makes the measure insensitive to contrast inversion.
  """
  ax, ay = _normalized_gradient_field(a, eps)
  bx, by = _normalized_gradient_field(b, eps)
  return torch.mean(1.0 - (ax*bx + ay*by)**2)


def multi_resolution_loss():
  """Symmetric multi-resolution NGF registration loss.

  Reads the globals L, I_lst, J_lst and xy_lst (pyramid cell), plus
  homography_net (presumably created in a later cell; confirm). At every level:
    - the moving image J is warped with the forward homography and compared
      with the fixed image I;
    - I is warped with the inverse homography and compared with J.
  The two NGF distances are averaged, and levels are weighted by 1/L, as in
  multi_resolution_MINE_loss.

  NOTE(review): grid_sample pads with zeros, so regions warped in from outside
  the image add artificial edges near the border.
  """
  loss = 0.0
  H_fwd = homography_net()          # maps fixed-image grid into J
  H_inv = homography_net.inverse()  # maps moving-image grid into I
  for s in reversed(range(L)):
    xv, yv = xy_lst[s][:,:,0], xy_lst[s][:,:,1]
    Jw = PerspectiveWarping(J_lst[s].unsqueeze(0).unsqueeze(0), H_fwd, xv, yv).squeeze()
    Iw = PerspectiveWarping(I_lst[s].unsqueeze(0).unsqueeze(0), H_inv, xv, yv).squeeze()
    level_loss = 0.5*(_ngf_distance(I_lst[s], Jw) + _ngf_distance(J_lst[s], Iw))
    loss = loss + (1./L)*level_loss
  return loss

Contact Us!

And get instant help with our image processing expert.


bottom of page