Skip to content

Commit

Permalink
Merge pull request #212 from ssciwr/fix_195_central_target_display_du…
Browse files Browse the repository at this point in the history
…ration

Add `central_target_duration` to trial options
  • Loading branch information
lkeegan authored Jun 7, 2023
2 parents a8ffb0d + 6ab6a0e commit 07926fd
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/vstt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
"__version__",
]

__version__ = "0.28.0"
__version__ = "0.29.0"
4 changes: 2 additions & 2 deletions src/vstt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import logging
from typing import Any
from typing import Dict
from typing import Mapping
from typing import Type
from typing import TypeVar

Expand All @@ -30,7 +30,7 @@ def _has_valid_type(var: Any, correct_type: Type) -> bool:


def import_typed_dict(
input_dict: Dict, default_typed_dict: VsttTypedDict
input_dict: Mapping[str, Any], default_typed_dict: VsttTypedDict
) -> VsttTypedDict:
# start with a valid typed dict with default values
output_dict = copy.deepcopy(default_typed_dict)
Expand Down
19 changes: 8 additions & 11 deletions src/vstt/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from vstt.stat import append_stats_data_to_excel
from vstt.stat import stats_dataframe
from vstt.trial import default_trial
from vstt.trial import import_trial
from vstt.trial import validate_trial
from vstt.trial import import_and_validate_trial


class Experiment:
Expand All @@ -35,8 +34,8 @@ def __init__(self, filename: Optional[str] = None):
self.load_file(filename)

def create_trialhandler(self) -> TrialHandlerExt:
for trial in self.trial_list:
validate_trial(trial)
for index, trial in enumerate(self.trial_list):
self.trial_list[index] = import_and_validate_trial(trial)
return TrialHandlerExt(
self.trial_list,
nReps=1,
Expand Down Expand Up @@ -147,11 +146,9 @@ def import_and_validate_dicts(
) -> None:
self.metadata = import_metadata(metadata_dict)
self.display_options = import_display_options(display_options_dict)
self.trial_list = []
for trial_dict in trial_dict_list:
trial = import_trial(trial_dict)
validate_trial(trial)
self.trial_list.append(trial)
self.trial_list = [
import_and_validate_trial(trial_dict) for trial_dict in trial_dict_list
]
self.trial_handler_with_results = None
self.stats = None
self.has_unsaved_changes = True
Expand All @@ -161,8 +158,8 @@ def import_and_validate_trial_handler(self, trial_handler: TrialHandlerExt) -> N
# psychopy trial handler converts empty trial list [] -> [None]
if trial_handler.trialList == [None]:
trial_handler.trialList = []
for trial in trial_handler.trialList:
validate_trial(trial)
for index, trial in enumerate(trial_handler.trialList):
trial_handler.trialList[index] = import_and_validate_trial(trial)
self.trial_list = trial_handler.trialList
if not trial_handler.extraInfo:
trial_handler.extraInfo = {}
Expand Down
4 changes: 2 additions & 2 deletions src/vstt/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def _get_dat(
) -> Any:
ar = data.get(key)
if ar is None:
logging.warning(
logging.debug(
f"Key '{key}' not found in data, using default value {default_value}"
)
return default_value
try:
return ar[index][i_target]
except IndexError:
logging.warning(
logging.debug(
f"Index error for key '{key}', index '{index}', i_target '{i_target}', using default value {default_value}"
)
return default_value
Expand Down
5 changes: 4 additions & 1 deletion src/vstt/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non
if is_central_target:
target_size = trial["central_target_size"]
if not trial["fixed_target_intervals"]:
stop_target_time = t0 + trial["target_duration"]
if is_central_target:
stop_target_time = t0 + trial["central_target_duration"]
else:
stop_target_time = t0 + trial["target_duration"]
dist_correct = 1.0
# ensure we get at least one flip
should_continue_target = True
Expand Down
14 changes: 8 additions & 6 deletions src/vstt/trial.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import copy
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -32,6 +34,7 @@ def default_trial() -> Trial:
"target_labels": "0 1 2 3 4 5 6 7",
"fixed_target_intervals": False,
"target_duration": 5.0,
"central_target_duration": 5.0,
"inter_target_duration": 1.0,
"target_distance": 0.4,
"target_size": 0.04,
Expand Down Expand Up @@ -67,6 +70,7 @@ def trial_labels() -> Dict:
"target_labels": "Target labels",
"fixed_target_intervals": "Fixed target display intervals",
"target_duration": "Target display duration (secs)",
"central_target_duration": "Central target display duration (secs)",
"inter_target_duration": "Delay between targets (secs)",
"target_distance": "Distance to targets (screen height fraction)",
"target_size": "Target size (screen height fraction)",
Expand Down Expand Up @@ -108,17 +112,15 @@ def get_trial_from_user(
)
if not dialog.OK:
return None
return validate_trial(trial)
return import_and_validate_trial(trial)


def import_trial(trial_dict: dict) -> Trial:
return import_typed_dict(trial_dict, default_trial())


def validate_trial(trial: Trial) -> Trial:
def import_and_validate_trial(trial_or_dict: Mapping[str, Any]) -> Trial:
trial = import_typed_dict(trial_or_dict, default_trial())
# make any negative time durations zero
for duration in [
"target_duration",
"central_target_duration",
"inter_target_duration",
"post_trial_delay",
"post_block_delay",
Expand Down
1 change: 1 addition & 0 deletions src/vstt/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Trial(TypedDict):
target_labels: str
fixed_target_intervals: bool
target_duration: float
central_target_duration: float
inter_target_duration: float
target_distance: float
target_size: float
Expand Down
17 changes: 15 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def experiment_no_results() -> Experiment:
trial0 = default_trial()
trial0["num_targets"] = 4
trial0["play_sound"] = False
trial0["target_duration"] = 30.0
trial0["target_duration"] = 60.0
trial0["central_target_duration"] = 60.0
trial0["inter_target_duration"] = 0.0
trial0["post_block_display_results"] = False
trial0["post_block_delay"] = 0.1
Expand All @@ -95,7 +96,7 @@ def experiment_with_results() -> Experiment:
experiment = Experiment()
# trial without auto-move to center, 3 reps, 8 targets
trial0 = default_trial()
# disable sounds due to issues with sounds within tests on linux
# disable sounds due to issues with sounds within tests on linux CI
trial0["play_sound"] = False
trial0["weight"] = 3
trial0["automove_cursor_to_center"] = False
Expand All @@ -114,7 +115,9 @@ def experiment_with_results() -> Experiment:
trial_handler = experiment.create_trialhandler()
for trial in trial_handler:
to_target_timestamps = []
to_target_num_timestamps_before_visible = []
to_center_timestamps = []
to_center_num_timestamps_before_visible = []
to_target_mouse_positions = []
to_center_mouse_positions = []
to_target_success = []
Expand All @@ -125,13 +128,15 @@ def experiment_with_results() -> Experiment:
t0 = 0.0
for pos in target_pos:
to_target_timestamps.append(make_timestamps(t0))
to_target_num_timestamps_before_visible.append(0)
to_target_mouse_positions.append(
make_mouse_positions(pos, to_target_timestamps[-1])
)
to_target_success.append(True)
t0 = to_target_timestamps[-1][-1] + 1.0 / 60.0
if not trial["automove_cursor_to_center"]:
to_center_timestamps.append(make_timestamps(t0))
to_center_num_timestamps_before_visible.append(0)
to_center_mouse_positions.append(
list(reversed(make_mouse_positions(pos, to_center_timestamps[-1])))
)
Expand All @@ -142,6 +147,10 @@ def experiment_with_results() -> Experiment:
trial_handler.addData(
"to_target_timestamps", np.array(to_target_timestamps, dtype=object)
)
trial_handler.addData(
"to_target_num_timestamps_before_visible",
np.array(to_target_num_timestamps_before_visible),
)
trial_handler.addData(
"to_target_mouse_positions",
np.array(to_target_mouse_positions, dtype=object),
Expand All @@ -154,6 +163,10 @@ def experiment_with_results() -> Experiment:
"to_center_mouse_positions",
np.array(to_center_mouse_positions, dtype=object),
)
trial_handler.addData(
"to_center_num_timestamps_before_visible",
np.array(to_target_num_timestamps_before_visible),
)
if trial["automove_cursor_to_center"]:
to_center_success = [True] * trial["num_targets"]
trial_handler.addData("to_center_success", np.array(to_center_success))
Expand Down
25 changes: 15 additions & 10 deletions tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_describe_trials() -> None:

def test_default_trial() -> None:
trial = vstt.trial.default_trial()
assert len(trial) == 30
assert len(trial) == len(vstt.trial.trial_labels())
assert isinstance(trial["target_indices"], str)
assert len(trial["target_indices"].split(" ")) == trial["num_targets"]

Expand All @@ -55,6 +55,7 @@ def test_import_trial() -> None:
"target_labels": "0 1 2 3 4 5",
"fixed_target_intervals": False,
"target_duration": 3,
"central_target_duration": 3,
"inter_target_duration": 0,
"target_distance": 0.3,
"target_size": 0.03,
Expand All @@ -78,7 +79,7 @@ def test_import_trial() -> None:
"enter_to_skip_delay": True,
}
# all valid keys are imported
trial = vstt.trial.import_trial(trial_dict)
trial = vstt.trial.import_and_validate_trial(trial_dict)
for key in trial:
assert trial[key] == trial_dict[key] # type: ignore
# if any keys are missing, default values are used instead
Expand All @@ -93,7 +94,7 @@ def test_import_trial() -> None:
# unknown keys are ignored
trial_dict["unknown_key1"] = "ignore me"
trial_dict["unknown_key2"] = False
trial = vstt.trial.import_trial(trial_dict)
trial = vstt.trial.import_and_validate_trial(trial_dict)
for key in trial:
if key in missing_keys:
assert trial[key] == default_trial[key] # type: ignore
Expand All @@ -105,21 +106,25 @@ def test_validate_trial_durations() -> None:
trial = vstt.trial.default_trial()
# positive durations are not modified
trial["target_duration"] = 1
trial["central_target_duration"] = 1
trial["inter_target_duration"] = 0.1
trial["post_trial_delay"] = 0.2
trial["post_block_delay"] = 0.7
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert vtrial["target_duration"] == 1
assert vtrial["central_target_duration"] == 1
assert vtrial["inter_target_duration"] == 0.1
assert vtrial["post_trial_delay"] == 0.2
assert vtrial["post_block_delay"] == 0.7
# negative durations are cast to zero
trial["target_duration"] = -1
trial["central_target_duration"] = -0.8
trial["inter_target_duration"] = -0.1
trial["post_trial_delay"] = -0.2
trial["post_block_delay"] = -0.7
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert vtrial["target_duration"] == 0
assert vtrial["central_target_duration"] == 0
assert vtrial["inter_target_duration"] == 0
assert vtrial["post_trial_delay"] == 0
assert vtrial["post_block_delay"] == 0
Expand All @@ -130,28 +135,28 @@ def test_validate_trial_target_order() -> None:
assert isinstance(trial["target_indices"], str)
# clockwise
trial["target_order"] = "clockwise"
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert isinstance(vtrial["target_indices"], str)
assert vtrial["target_indices"] == "0 1 2 3 4 5 6 7"
# anti-clockwise
trial["target_order"] = "anti-clockwise"
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert isinstance(vtrial["target_indices"], str)
assert vtrial["target_indices"] == "7 6 5 4 3 2 1 0"
# random
trial["target_order"] = "random"
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert isinstance(vtrial["target_indices"], str)
assert len(set(vtrial["target_indices"].split(" "))) == 8
# fixed & valid
trial["target_order"] = "fixed"
trial["target_indices"] = "0 1 2 3 4 5 6 7"
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert isinstance(vtrial["target_indices"], str)
assert vtrial["target_indices"] == "0 1 2 3 4 5 6 7"
# fixed & invalid - clipped to nearest valid indices
trial["target_order"] = "fixed"
trial["target_indices"] = "-2 8 1 5 12 -5"
vtrial = vstt.trial.validate_trial(trial)
vtrial = vstt.trial.import_and_validate_trial(trial)
assert isinstance(vtrial["target_indices"], str)
assert vtrial["target_indices"] == "0 7 1 5 7 0"

0 comments on commit 07926fd

Please sign in to comment.