Skip to content

Commit

Permalink
Fix bugs and name
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshitomo-matsubara committed Oct 28, 2024
1 parent d004695 commit 530489d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
12 changes: 5 additions & 7 deletions sc2bench/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,7 @@ def __init__(self, bottleneck_layer, short_module_names, inception_v3_model, ski
module_dict['maxpool2'] = nn.MaxPool2d(kernel_size=3, stride=2)
child_name_list.append('maxpool2')
elif child_name == 'fc':
module_dict['adaptive_avgpool'] = nn.AdaptiveAvgPool2d((1, 1))
module_dict['dropout'] = nn.Dropout()
module_dict['flatten'] = nn.Flatten(1)
break

module_dict[child_name] = child_module
child_name_list.append(child_name)
Expand All @@ -436,7 +434,7 @@ def forward(self, x):
x = self.bottleneck_layer(x)

x = self.inception_modules(x)
if self.adaptive_avgpool is None:
if self.avgpool is None:
return x

x = self.avgpool(x)
Expand Down Expand Up @@ -727,9 +725,9 @@ def splittable_densenet(bottleneck_config, densenet_name='densenet169', short_fe


@register_backbone_func
def splittable_inception3(bottleneck_config, short_module_names=None, skips_avgpool=True, skips_dropout=True,
skips_fc=True, pre_transform=None, analysis_config=None,
org_model_ckpt_file_path_or_url=None, org_ckpt_strict=True, **inception_v3_kwargs):
def splittable_inception_v3(bottleneck_config, short_module_names=None, skips_avgpool=True, skips_dropout=True,
skips_fc=True, pre_transform=None, analysis_config=None,
org_model_ckpt_file_path_or_url=None, org_ckpt_strict=True, **inception_v3_kwargs):
"""
Builds InceptionV3-based splittable image classification model optionally containing neural encoder,
entropy bottleneck, and decoder.
Expand Down
4 changes: 2 additions & 2 deletions sc2bench/models/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def larger_densenet_bottleneck(bottleneck_channel=12, bottleneck_idx=7,


@register_layer_func
def larger_inception_v3_bottleneck(bottleneck_channel=12, bottleneck_idx=7,
compressor_transform=None, decompressor_transform=None):
def inception_v3_bottleneck(bottleneck_channel=12, bottleneck_idx=7,
compressor_transform=None, decompressor_transform=None):
"""
Builds a bottleneck layer InceptionV3-based encoder and decoder (17 layers in total).
Expand Down

0 comments on commit 530489d

Please sign in to comment.