Skip to content

Commit

Permalink
fix wrapper constructor in ErrorAnalysisManager and RAIVisionInsights…
Browse files Browse the repository at this point in the history
… load call for RAI Vision Dashboard (microsoft#2560)
  • Loading branch information
imatiach-msft authored May 1, 2024
1 parent 6026326 commit 93eacce
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,21 @@ def _load(path, rai_insights):
feature_names = list(dataset.columns)
inst.__dict__['_feature_names'] = feature_names
task_type = rai_insights.task_type
wrapped_model = wrap_model(rai_insights.model, dataset,
task_type,
classes=rai_insights._classes,
device=rai_insights.device)
classes = rai_insights._classes
device = rai_insights.device

test = rai_insights.test
image_mode = rai_insights.image_mode
transformations = rai_insights._transformations
sample = test.iloc[0:2]
sample = get_images(sample, image_mode, transformations)
wrapped_model = wrap_model(
rai_insights.model, sample, task_type, classes=classes,
device=device)

inst.__dict__['_task_type'] = task_type
index_classes = rai_insights._classes
index_dataset = rai_insights.test
index_classes = classes
index_dataset = test
if isinstance(target_column, list):
# create copy of dataset as we will make modifications to it
index_dataset = index_dataset.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,9 @@ def load(path):
# load current state
RAIBaseInsights._load(
path, inst, manager_map, RAIVisionInsights._load_metadata)
inst._wrapped_model = wrap_model(inst.model, inst.test, inst.task_type,
sample = inst.test.iloc[0:2]
sample = get_images(sample, inst.image_mode, inst._transformations)
inst._wrapped_model = wrap_model(inst.model, sample, inst.task_type,
classes=inst._classes,
device=inst.device)
inst.automl_image_model = is_automl_image_model(inst._wrapped_model)
Expand Down
17 changes: 17 additions & 0 deletions responsibleai_vision/tests/common_vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ def create_dummy_model(df):
return DummyFlowersClassifier()


def create_raw_torchvision_classification_model():
"""Creates a dummy torchvision model for testing purposes.
:return: dummy torchvision model
:rtype: torchvision.models.resnet.ResNet
"""
return torchvision_models.vgg16(pretrained=False, num_classes=2)


def retrieve_unzip_file(download_url, data_file):
fetch_dataset(download_url, data_file)
# extract files
Expand Down Expand Up @@ -486,6 +495,14 @@ def _get_model_path(self, path):
return os.path.join(path, 'image-classification-model')


class TorchvisionDummyPipelineSerializer(object):
def save(self, model, path):
pass

def load(self, path):
return create_raw_torchvision_classification_model()


class ObjectDetectionPipelineSerializer(object):
def save(self, model, path):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from common_vision_utils import (DummyFlowersPipelineSerializer,
ImageClassificationPipelineSerializer,
ObjectDetectionPipelineSerializer,
TorchvisionDummyPipelineSerializer,
create_dummy_model,
create_image_classification_pipeline,
create_raw_torchvision_classification_model,
load_flowers_dataset,
load_fridge_object_detection_dataset,
load_imagenet_dataset, load_imagenet_labels,
Expand Down Expand Up @@ -49,6 +51,22 @@ def test_rai_insights_empty_save_load_save(self):
run_and_validate_serialization(
pred, test, task_type, class_names, label, serializer)

def test_rai_insights_pytorch_empty_save_load_save(self):
data = load_flowers_dataset(upscale=False)
data = data[0:1]
# stack two of the same image since we need same
# image sizes for pytorch model
data = data.append(data).reset_index(drop=True)
pred = create_raw_torchvision_classification_model()
test = data
class_names = data[ImageColumns.LABEL.value].unique()
task_type = ModelTask.IMAGE_CLASSIFICATION
label = ImageColumns.LABEL
serializer = TorchvisionDummyPipelineSerializer()

run_and_validate_serialization(
pred, test, task_type, class_names, label, serializer)

@pytest.mark.skip("Insufficient memory on test machines to load images")
def test_rai_insights_large_images_save_load_save(self):
PIL.Image.MAX_IMAGE_PIXELS = None
Expand Down

0 comments on commit 93eacce

Please sign in to comment.