top of page
Search

# Image Registration Using Perspective Transformation | Homography Transformation

This notebook shows the use of differentiable mutual information for image registration. Your main task is to change the registration cost function to the normalized gradient loss.

Import Libraries

# Numerical / plotting stack plus PyTorch for the differentiable
# registration pipeline.
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage.transform import pyramid_gaussian

# Run on the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

```#from google.colab import drive
#drive.mount('/content/drive')

I = io.imread("knee1.bmp").astype(np.float32)/255.0 # fixed image
J = io.imread("knee2.bmp").astype(np.float32)/255.0 # moving image

nChannel = 1 # gray scale image

%matplotlib inline
fig=plt.figure()
plt.imshow(I,cmap="gray")
plt.title("Fixed Image")
plt.imshow(J,cmap="gray")
plt.title("Moving Image")
plt.show()```

Output:

Regularization using multi-resolution decomposition of the motion vectors using Laplacian of Gaussian (LoG) pyramids

L = 6          # number of Gaussian pyramid levels
downscale = 2.0  # downscale factor between successive pyramid levels
pyramid_I = tuple(pyramid_gaussian(I, downscale=downscale, multichannel=False))
pyramid_J = tuple(pyramid_gaussian(J, downscale=downscale, multichannel=False))

nFraction = 0.1  # fraction of pixels used in MINE calculation

# Build, once per level, everything the optimization loop reuses:
# image tensors on the device, spatial sizes, normalized sampling grids,
# and a fixed random subset of pixel indices for the MINE estimator.
I_lst, J_lst, h_lst, w_lst, xy_lst, ind_lst = [], [], [], [], [], []
for s in range(L):
    I_ = torch.tensor(pyramid_I[s].astype(np.float32)).to(device)
    J_ = torch.tensor(pyramid_J[s].astype(np.float32)).to(device)
    I_lst.append(I_)
    J_lst.append(J_)

    h_, w_ = I_.shape[0], I_.shape[1]
    # random subset of flattened pixel indices used at this level
    ind_ = torch.randperm(h_*w_)[0:int(nFraction*h_*w_)].to(device)
    ind_lst.append(ind_)

    print(h_, w_, len(ind_))
    h_lst.append(h_)
    w_lst.append(w_)

    # sampling grid in normalized [-1, 1] coordinates (grid_sample convention)
    y_, x_ = torch.meshgrid([torch.arange(0, h_).float().to(device),
                             torch.arange(0, w_).float().to(device)])
    y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0
    xy_lst.append(torch.stack([x_, y_], 2))

Output:

```512 512 26214
256 256 6553
128 128 1638
64 64 409
32 32 102
16 16 25```

### The following cell implements homography transformation using matrix exponential.

class HomographyNet(nn.Module):
    """Trainable 3x3 homography parameterized through the matrix exponential.

    The transform is exp(sum_k v[k] * B[k]) where B is a fixed basis of
    generator matrices and v holds the 8 trainable coefficients.
    """
    def __init__(self):
        super(HomographyNet, self).__init__()
        # perspective transform basis matrices
        self.B = torch.zeros(8, 3, 3).to(device)
        self.B[0, 0, 2] = 1.0                        # x translation
        self.B[1, 1, 2] = 1.0                        # y translation
        self.B[2, 0, 1] = 1.0                        # shear
        self.B[3, 1, 0] = 1.0                        # shear
        self.B[4, 0, 0], self.B[4, 1, 1] = 1.0, -1.0 # scale
        self.B[5, 1, 1], self.B[5, 2, 2] = -1.0, 1.0 # scale
        self.B[6, 2, 0] = 1.0                        # perspective
        self.B[7, 2, 1] = 1.0                        # perspective
        # BUGFIX: forward()/inverse() read self.v, but it was never defined,
        # and without an nn.Parameter the optimizer has nothing to train.
        # Zeros give the identity transform at initialization (exp(0) = I).
        self.v = nn.Parameter(torch.zeros(8, 1, 1).to(device))

    # This function computes the forward transform matrix H = exp(+v.B)
    def forward(self):
        return MatrixExp(self.B, self.v)

    # This function computes the inverse transform matrix H^-1 = exp(-v.B)
    def inverse(self):
        return MatrixExp(self.B, -self.v)

def MatrixExp(B, v):
    """Matrix exponential of C = sum(v*B) via a 9-term truncated Taylor series."""
    C = torch.sum(B*v, 0)
    A = torch.eye(3, device=C.device)  # running term C^i / i!
    H = torch.eye(3, device=C.device)  # partial sum of the series
    for i in torch.arange(1, 10):
        A = torch.mm(A/i, C)
        H = H + A
    return H

n_neurons = 100

class MINE(nn.Module):  # https://arxiv.org/abs/1801.04062
    """Mutual Information Neural Estimation via the Donsker-Varadhan bound."""
    def __init__(self):
        super(MINE, self).__init__()
        self.fc1 = nn.Linear(2*nChannel, n_neurons)
        self.fc2 = nn.Linear(n_neurons, n_neurons)
        self.fc3 = nn.Linear(n_neurons, 1)

    def forward(self, x, ind):
        # x: stacked fixed/warped intensities — assumed (H, W, 2*nChannel);
        # flatten spatial dimensions so pixels index rows.
        x = x.view(x.size()[0]*x.size()[1], x.size()[2])
        # statistics network on joint samples (paired pixels)
        z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind, :])))))
        # marginal samples: pair each fixed pixel with a randomly permuted
        # moving pixel.  BUGFIX: the original looped 500 times re-assigning
        # z2, so only the last permutation was ever used; a single random
        # permutation is statistically equivalent and ~500x cheaper.
        ind_perm = ind[torch.randperm(len(ind))]
        z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(
            torch.cat((x[ind, 0:nChannel], x[ind_perm, nChannel:2*nChannel]), 1))))))
        # Donsker-Varadhan lower bound on mutual information
        MI = torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))
        return MI

def PerspectiveWarping(I, H, xv, yv):
    """Warp image batch I with homography H over the grid (xv, yv).

    xv, yv hold normalized coordinates in [-1, 1] (grid_sample convention);
    points mapped outside that range receive grid_sample's zero padding.
    """
    # apply the transform in homogeneous coordinates and de-homogenize
    denom = xv*H[2, 0] + yv*H[2, 1] + H[2, 2]
    xvt = (xv*H[0, 0] + yv*H[0, 1] + H[0, 2]) / denom
    yvt = (xv*H[1, 0] + yv*H[1, 1] + H[1, 2]) / denom
    grid = torch.stack([xvt, yvt], 2).unsqueeze(0)
    return F.grid_sample(I, grid, align_corners=False).squeeze()

def multi_resolution_MINE_loss():
    """Negated MINE mutual-information estimate averaged over all pyramid levels."""
    loss = 0.0
    for s in np.arange(L-1, -1, -1):  # coarse to fine
        moving = J_lst[s].unsqueeze(0).unsqueeze(0)
        Jw_ = PerspectiveWarping(moving, homography_net(),
                                 xy_lst[s][:, :, 0], xy_lst[s][:, :, 1]).squeeze()
        mi = mine_net(torch.stack([I_lst[s], Jw_], 2), ind_lst[s])
        loss = loss - (1./L)*mi  # maximizing MI == minimizing -MI
    return loss

Complete the multi-resolution loss function in the cell below using the Normalized Gradient loss. Make sure that your loss function makes use of both the forward and the inverse transforms.

# Edit this cell to complete multi-resolution_loss function

def _normalized_gradient(img, eps=1e-8):
    """Forward-difference gradient of a 2-D image, normalized to unit length.

    eps keeps the magnitude strictly positive so flat regions do not divide
    by zero (and stay differentiable).
    """
    gx = img[:-1, 1:] - img[:-1, :-1]
    gy = img[1:, :-1] - img[:-1, :-1]
    mag = torch.sqrt(gx*gx + gy*gy + eps)
    return gx/mag, gy/mag

def _ng_distance(A, B):
    """Normalized-gradient-field distance: 1 - mean squared cosine between
    the gradient directions of A and B (0 when edges align perfectly)."""
    ax, ay = _normalized_gradient(A)
    bx, by = _normalized_gradient(B)
    return 1.0 - torch.mean((ax*bx + ay*by)**2)

# Complete this function by editing only the indicated area
# Do not change function definition 'def multi_resolution_loss():'
# Do not change the last line 'return loss'
def multi_resolution_loss():
    loss = 0.0
    for s in np.arange(L-1, -1, -1):  # coarse to fine
        # forward transform: warp the moving image J toward the fixed image I
        Jw_ = PerspectiveWarping(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(),
                                 xy_lst[s][:, :, 0], xy_lst[s][:, :, 1]).squeeze()
        # inverse transform: warp the fixed image I toward the moving image J
        Iw_ = PerspectiveWarping(I_lst[s].unsqueeze(0).unsqueeze(0), homography_net.inverse(),
                                 xy_lst[s][:, :, 0], xy_lst[s][:, :, 1]).squeeze()
        # symmetric normalized-gradient loss, averaged across pyramid levels
        loss = loss + (0.5/L)*(_ng_distance(I_lst[s], Jw_) + _ng_distance(J_lst[s], Iw_))
    return loss