Coherent Point Drift (CPD) implementation in PyTorch
git clone https://github.com/mikami520/CPD-Pytorch.git
cd CPD-Pytorch
pip install -e .
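from functools import partial
import matplotlib.pyplot as plt
from torchcpd import RigidRegistration
import numpy as np
import torch as th

def visualize(iteration, error, X, Y, ax, fig, save_fig=False):
    # Assumes pycpd-style keyword arguments; X and Y may arrive as GPU tensors.
    X = X.detach().cpu().numpy() if th.is_tensor(X) else X
    Y = Y.detach().cpu().numpy() if th.is_tensor(Y) else Y
    ax.cla()
    ax.scatter(X[:, 0], X[:, 1], X[:, 2], color='red', label='Target')
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], color='blue', label='Source')
    ax.set_title(f'Iteration: {iteration}, Q: {float(error):.4f}')
    ax.legend(loc='upper left')
    if save_fig:
        fig.savefig(f'rigid_{iteration:03d}.png')
    plt.draw()
    plt.pause(0.001)

device = 'cuda:0' if th.cuda.is_available() else 'cpu'
X = np.loadtxt('data/bunny_target.txt')
# synthetic data, equivalent to X + 1
Y = np.loadtxt('data/bunny_source.txt')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
callback = partial(visualize, ax=ax, fig=fig, save_fig=False)
reg = RigidRegistration(X=X, Y=Y, device=device)
reg.register(callback)
plt.show()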
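For readers curious what a rigid registration computes on each iteration, the following is a minimal sketch of the rigid CPD EM updates from the original CPD paper (Myronenko and Song, 2010), written directly in PyTorch. It is not torchcpd's source code. The function name `rigid_cpd`, its arguments (`w` for the outlier weight, `max_iter`, `tol`) and its return values are hypothetical choices for illustration. The posterior matrix `P` is M×N (source by target), and the transform maps source points `Y` onto target points `X` as `s * Y @ R.T + t`.

```python
import math

import torch


def rigid_cpd(X, Y, w=0.0, max_iter=200, tol=1e-10):
    """Hypothetical rigid CPD sketch: estimate s, R, t so that s * Y @ R.T + t aligns Y to X.

    X: (N, D) target points, Y: (M, D) source points, w: outlier weight in [0, 1).
    """
    N, D = X.shape
    M = Y.shape[0]
    R = torch.eye(D, dtype=X.dtype, device=X.device)
    t = torch.zeros(D, dtype=X.dtype, device=X.device)
    s = torch.tensor(1.0, dtype=X.dtype, device=X.device)
    # Initial isotropic variance from all pairwise squared distances.
    sigma2 = ((X[None, :, :] - Y[:, None, :]) ** 2).sum() / (D * N * M)
    for _ in range(max_iter):
        TY = s * Y @ R.T + t
        # E-step: P[m, n] = posterior that target x_n was generated by centroid TY[m].
        d2 = ((X[None, :, :] - TY[:, None, :]) ** 2).sum(dim=-1)  # (M, N)
        num = torch.exp(-d2 / (2 * sigma2))
        c = (2 * math.pi * sigma2) ** (D / 2) * w / (1 - w) * M / N
        den = (num.sum(dim=0, keepdim=True) + c).clamp(min=torch.finfo(X.dtype).eps)
        P = num / den
        # M-step: weighted centroids, then closed-form R (via SVD), s and t.
        P1, Pt1 = P.sum(dim=1), P.sum(dim=0)  # (M,), (N,)
        Np = P1.sum()
        mu_x = Pt1 @ X / Np
        mu_y = P1 @ Y / Np
        Xh, Yh = X - mu_x, Y - mu_y
        A = Xh.T @ P.T @ Yh  # (D, D)
        U, _, Vh = torch.linalg.svd(A)
        C = torch.eye(D, dtype=X.dtype, device=X.device)
        C[-1, -1] = torch.det(U @ Vh)  # forces det(R) = +1, i.e. no reflections
        R = U @ C @ Vh
        trAR = torch.trace(A.T @ R)
        s = trAR / (P1 @ (Yh ** 2).sum(dim=1))
        t = mu_x - s * R @ mu_y
        # Variance update; clamp to a small positive value if roundoff drives it to <= 0.
        new_sigma2 = (Pt1 @ (Xh ** 2).sum(dim=1) - s * trAR) / (Np * D)
        if new_sigma2 <= 0:
            new_sigma2 = torch.tensor(tol / 10, dtype=X.dtype, device=X.device)
        converged = abs(sigma2 - new_sigma2) < tol
        sigma2 = new_sigma2
        if converged:
            break
    return s * Y @ R.T + t, s, R, t
```

On synthetic data like the example's (source equal to target plus 1), the recovered transform should approach `s ≈ 1`, `R ≈ I` and `t ≈ -1` in every coordinate. The `C` matrix is the standard trick that keeps `R` a proper rotation rather than a reflection, and `w > 0` adds a uniform outlier component to the mixture.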
More tutorials can be found in the /examples folder.