Skip to content

Commit

Permalink
Version 0.3.2
Browse files Browse the repository at this point in the history
Actual hotfix.

Version 0.3.1 got yanked because I misspecified the version of
nanobind in pyproject.toml
  • Loading branch information
galv committed Apr 3, 2023
1 parent e01c7cb commit cb1a00d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = [
"wheel",
"cmake>=3.25",
"ninja",
"nanobind@git+https://github.com/galv/nanobind#egg=fix-batch-size-1-type-cast",
"nanobind@git+https://github.com/galv/nanobind@fix-batch-size-1-type-cast",
]

build-backend = "setuptools.build_meta"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def build_extension(self, ext: setuptools.extension.Extension):
setuptools.setup(
python_requires='>=3.7',
name='riva-asrlib-decoder',
version='0.3.1',
version='0.3.2',
author='NVIDIA',
author_email='dgalvez@nvidia.com',
keywords='ASR, CUDA, WFST, Decoder',
Expand Down
27 changes: 27 additions & 0 deletions src/riva/asrlib/decoder/test_graph_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,33 @@ def test_vanilla_ctc_topo_wer_mbr(self, nemo_model_name, dataset, expected_wer,
assert my_wer <= expected_wer + ERROR_MARGIN


def test_batch_size_1(self):
    """
    Integration test for https://github.com/wjakob/nanobind/pull/162
    """
    model_name = "stt_en_conformer_ctc_small"
    ctc_dir = os.path.join(self.temp_dir, "ctc")

    # Load the acoustic model onto the GPU; its vocabulary size determines
    # the decoder's token dimension below.
    acoustic_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name, map_location=torch.device("cuda")
    )

    self.create_TLG("ctc", ctc_dir, model_name)

    # Vocabulary plus one extra entry for the CTC blank symbol.
    vocabulary = acoustic_model.to_config_dict()["decoder"]["vocabulary"]
    num_tokens_including_blank = len(vocabulary) + 1

    graph_dir = os.path.join(ctc_dir, "graph/graph_ctc_3-gram.pruned.3e-7")
    decoder = BatchedMappedDecoderCuda(
        self.create_decoder_config(),
        os.path.join(graph_dir, "TLG.fst"),
        os.path.join(graph_dir, "words.txt"),
        num_tokens_including_blank,
    )

    # A batch of exactly one utterance: dummy all-ones logits on the GPU,
    # sequence lengths on the CPU, as decode_mbr expects.
    logits = torch.ones((1, 100, num_tokens_including_blank), dtype=torch.float32).cuda()
    lengths = torch.tensor([100], dtype=torch.int64).cpu()
    decoder.decode_mbr(logits.detach(), lengths.detach())

# Note that nbest decoding tends to produce a worse WER than mbr
# decoding. This is expected.
@pytest.mark.parametrize(
Expand Down

0 comments on commit cb1a00d

Please sign in to comment.