diff --git a/responsibleai_vision/responsibleai_vision/utils/feature_extractors.py b/responsibleai_vision/responsibleai_vision/utils/feature_extractors.py index 33e66b317a..3080ecbc96 100644 --- a/responsibleai_vision/responsibleai_vision/utils/feature_extractors.py +++ b/responsibleai_vision/responsibleai_vision/utils/feature_extractors.py @@ -12,13 +12,15 @@ from tqdm import tqdm from responsibleai.feature_metadata import FeatureMetadata -from responsibleai_vision.common.constants import ExtractedFeatures +from responsibleai_vision.common.constants import (ExtractedFeatures, + ImageColumns) from responsibleai_vision.utils.image_reader import ( get_all_exif_feature_names, get_image_from_path, get_image_pointer_from_path) MEAN_PIXEL_VALUE = ExtractedFeatures.MEAN_PIXEL_VALUE.value MAX_CUSTOM_LEN = 100 +IMAGE_DETAILS = ImageColumns.IMAGE_DETAILS.value def extract_features(image_dataset: pd.DataFrame, @@ -58,6 +60,8 @@ def extract_features(image_dataset: pd.DataFrame, start_meta_index = 2 if isinstance(target_column, list): start_meta_index = len(target_column) + 1 + if IMAGE_DETAILS in column_names: + start_meta_index += 1 for j in range(start_meta_index, image_dataset.shape[1]): if has_dropped_features and column_names[j] in dropped_features: continue diff --git a/responsibleai_vision/tests/rai_vision_insights_validator.py b/responsibleai_vision/tests/rai_vision_insights_validator.py index 5af4af1b84..a0832706d9 100644 --- a/responsibleai_vision/tests/rai_vision_insights_validator.py +++ b/responsibleai_vision/tests/rai_vision_insights_validator.py @@ -31,6 +31,13 @@ def validate_rai_vision_insights( pd.testing.assert_frame_equal(rai_vision_test, test_data) assert rai_vision_insights.target_column == target_column assert rai_vision_insights.task_type == task_type + # make sure label column not in _ext_test extracted features data + assert target_column not in rai_vision_insights._ext_features + # also not in last column of _ext_test, which is prone to happen + # if incorrect number of metadata columns specified in + # feature_extractors call + first_row = rai_vision_insights._ext_test[0] + assert not isinstance(first_row[len(first_row) - 1], list) def run_and_validate_serialization( diff --git a/responsibleai_vision/tests/test_feature_extractors.py b/responsibleai_vision/tests/test_feature_extractors.py index 94d72f0780..46849af024 100644 --- a/responsibleai_vision/tests/test_feature_extractors.py +++ b/responsibleai_vision/tests/test_feature_extractors.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation # Licensed under the MIT License. +import pytest from common_vision_utils import (load_flowers_dataset, load_fridge_dataset, load_fridge_object_detection_dataset, load_imagenet_dataset) @@ -39,8 +40,10 @@ def extract_dataset_features(data, feature_metadata=None): class TestFeatureExtractors(object): - def test_extract_features_fridge_object_detection(self): - data = load_fridge_object_detection_dataset(automl_format=False) + @pytest.mark.parametrize("automl_format", [True, False]) + def test_extract_features_fridge_object_detection(self, automl_format): + data = load_fridge_object_detection_dataset( + automl_format=automl_format) extracted_features, feature_names = extract_dataset_features(data) expected_feature_names = [MEAN_PIXEL_VALUE] + FRIDGE_METADATA_FEATURES validate_extracted_features(extracted_features, feature_names,