Skip to content

Commit

Permalink
Merge pull request yl4579#13 from fumiama/mono
Browse files Browse the repository at this point in the history
feat: drop cython monotonic_align
  • Loading branch information
Stardust-minus authored Sep 2, 2023
2 parents 6173b3e + 508fae1 commit a4f9710
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 26,600 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

/models
1 change: 1 addition & 0 deletions bert/chinese-roberta-wwm-ext-large/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.bin
34 changes: 15 additions & 19 deletions monotonic_align/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import numpy as np
import torch
from .monotonic_align.core import maximum_path_c


def maximum_path(neg_cent, mask):
""" Cython optimized version.
neg_cent: [b, t_t, t_s]
mask: [b, t_t, t_s]
"""
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
path = np.zeros(neg_cent.shape, dtype=np.int32)

t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
maximum_path_c(path, neg_cent, t_t_max, t_s_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)
from numpy import zeros, int32, float32
from torch import from_numpy

from .core import maximum_path_jit

def maximum_path(neg_cent, mask):
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
path = zeros(neg_cent.shape, dtype=int32)

t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
return from_numpy(path).to(device=device, dtype=dtype)
Loading

0 comments on commit a4f9710

Please sign in to comment.