Skip to content

Commit

Permalink
Code changes to maintain the lms_segment table up to date
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Nov 14, 2024
1 parent ade3996 commit 8c7a4af
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 18 deletions.
4 changes: 4 additions & 0 deletions lms/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from lms.services.organization_usage_report import OrganizationUsageReportService
from lms.services.roster import RosterService
from lms.services.rsa_key import RSAKeyService
from lms.services.segment import SegmentService
from lms.services.user import UserService
from lms.services.user_preferences import UserPreferencesService
from lms.services.vitalsource import VitalSourceService
Expand Down Expand Up @@ -156,6 +157,9 @@ def includeme(config): # noqa: PLR0915
config.register_service_factory(
"lms.services.group_set.factory", iface=GroupSetService
)
config.register_service_factory(
"lms.services.segment.factory", iface=SegmentService
)
config.register_service_factory(
"lms.services.auto_grading.factory", iface=AutoGradingService
)
Expand Down
2 changes: 2 additions & 0 deletions lms/services/grouping/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lms.services.grouping.service import GroupingService
from lms.services.segment import SegmentService


def service_factory(_context, request):
Expand All @@ -8,4 +9,5 @@ def service_factory(_context, request):
request.lti_user.application_instance if request.lti_user else None
),
plugin=request.product.plugin.grouping,
segment_service=request.find_service(SegmentService),
)
42 changes: 36 additions & 6 deletions lms/services/grouping/service.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from sqlalchemy import func
from sqlalchemy.orm import aliased

from lms.models import Course, Grouping, GroupingMembership, LTIUser, User
from lms.models import Course, Grouping, GroupingMembership, LTIRole, LTIUser, User
from lms.models._hashed_id import hashed_id
from lms.product.plugin.grouping import GroupingPlugin
from lms.services.segment import SegmentService
from lms.services.upsert import bulk_upsert


class GroupingService:
def __init__(self, db, application_instance, plugin: GroupingPlugin):
def __init__(
self,
db,
application_instance,
plugin: GroupingPlugin,
segment_service: SegmentService,
):
self._db = db
self.application_instance = application_instance
self.plugin = plugin
self.segment_service = segment_service

def get_authority_provided_id(
self, lms_id, type_: Grouping.Type, parent: Grouping | None = None
Expand Down Expand Up @@ -170,7 +178,13 @@ def get_sections(
else:
groupings = self.plugin.get_sections_for_instructor(self, course)

return self._to_groupings(user, groupings, course, self.plugin.sections_type)
return self._to_groupings(
user,
groupings,
course,
self.plugin.sections_type,
lti_roles=lti_user.lti_roles,
)

def get_groups( # noqa: PLR0913
self,
Expand Down Expand Up @@ -202,7 +216,13 @@ def get_groups( # noqa: PLR0913
self, course, group_set_id
)

return self._to_groupings(user, groupings, course, self.plugin.group_type)
return self._to_groupings(
user,
groupings,
course,
self.plugin.group_type,
lti_roles=lti_user.lti_roles,
)

def get_launch_grouping_type(self, request, course, assignment) -> Grouping.Type:
"""
Expand All @@ -224,7 +244,9 @@ def get_launch_grouping_type(self, request, course, assignment) -> Grouping.Type

return Grouping.Type.COURSE

def _to_groupings(self, user, groupings, course, type_):
def _to_groupings( # noqa: PLR0913
self, user, groupings, course, type_, lti_roles: list[LTIRole]
):
if groupings and not isinstance(groupings[0], Grouping):
groupings = [
{
Expand All @@ -241,7 +263,15 @@ def _to_groupings(self, user, groupings, course, type_):
for grouping in groupings
]
groupings = self.upsert_groupings(groupings, parent=course, type_=type_)
segments = self.segment_service.upsert_segments(
course=course.lms_course,
type_=groupings[0].type,
groupings=groupings,
lms_group_set_id=groupings[0].extra["group_set_id"],
)
self.segment_service.upsert_segment_memberships(
lms_user=user.lms_user, segments=segments, lti_roles=lti_roles
)

self.upsert_grouping_memberships(user, groupings)

return groupings
86 changes: 86 additions & 0 deletions lms/services/segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import TypedDict

from sqlalchemy import func

from lms.models import (
Grouping,
LMSCourse,
LMSSegment,
LMSSegmentMembership,
LMSUser,
LTIRole,
)
from lms.services.group_set import GroupSetService
from lms.services.upsert import bulk_upsert


class SegmentService:
def __init__(self, db, group_set_service: GroupSetService):
self._db = db
self._group_set_service = group_set_service

def upsert_segments(
self,
course: LMSCourse,
type_: Grouping.Type,
groupings: list[Grouping],
lms_group_set_id: str | None = None,
) -> list[LMSSegment]:
group_set = None
if lms_group_set_id:
group_set = self._group_set_service.find_group_set(
course.course.application_instance,
lms_id=lms_group_set_id,
context_id=course.lti_context_id,
)

return bulk_upsert(
self._db,
LMSSegment,
[
{
"type": type_,
"lms_id": segment.lms_id,
"name": segment.lms_name,
"h_authority_provided_id": segment.authority_provided_id,
"lms_course_id": course.id,
"lms_group_set_id": group_set.id if group_set else None,
}
for segment in groupings
],
index_elements=["h_authority_provided_id"],
update_columns=["name", "updated"],
).all()

def upsert_segment_memberships(
self,
lms_user: LMSUser,
lti_roles: list[LTIRole],
segments: list[LMSSegment],
) -> list[LMSSegmentMembership]:
if not lms_user.id or any(s.id is None for s in segments):
# Ensure all ORM objects have their PK populated
self._db.flush()

return bulk_upsert(
self._db,
LMSSegmentMembership,
[
{
"lms_segment_id": s.id,
"lms_user_id": lms_user.id,
"lti_role_id": lti_role.id,
"updated": func.now(),
}
for s in segments
for lti_role in lti_roles
],
index_elements=["lms_segment_id", "lms_user_id", "lti_role_id"],
update_columns=["updated"],
)


def factory(_context, request):
return SegmentService(
db=request.db, group_set_service=request.find_service(GroupSetService)
)
1 change: 1 addition & 0 deletions tests/factories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
LMSCourseMembership,
)
from tests.factories.lms_group_set import LMSGroupSet
from tests.factories.lms_segment import LMSSegment
from tests.factories.lms_user import LMSUser
from tests.factories.lti_registration import LTIRegistration
from tests.factories.lti_role import LTIRole, LTIRoleOverride
Expand Down
9 changes: 7 additions & 2 deletions tests/factories/lms_group_set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from factory import make_factory
from factory import Faker, make_factory
from factory.alchemy import SQLAlchemyModelFactory

from lms import models

LMSGroupSet = make_factory(models.LMSGroupSet, FACTORY_CLASS=SQLAlchemyModelFactory)
LMSGroupSet = make_factory(
models.LMSGroupSet,
FACTORY_CLASS=SQLAlchemyModelFactory,
lms_id=Faker("hexify", text="^" * 40),
name=Faker("word"),
)
14 changes: 14 additions & 0 deletions tests/factories/lms_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import factory
from factory import Faker
from factory.alchemy import SQLAlchemyModelFactory

from lms import models

LMSSegment = factory.make_factory(
models.LMSSegment,
FACTORY_CLASS=SQLAlchemyModelFactory,
type=Faker("random_element", elements=models.Grouping.Type),
lms_id=Faker("hexify", text="^" * 40),
name=Faker("word"),
h_authority_provided_id=Faker("hexify", text="^" * 40),
)
5 changes: 4 additions & 1 deletion tests/unit/lms/services/grouping/factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

@pytest.mark.usefixtures("application_instance_service", "with_plugins")
class TestFactory:
def test_it(self, pyramid_request, application_instance, GroupingService):
def test_it(
self, pyramid_request, application_instance, GroupingService, segment_service
):
svc = service_factory(sentinel.context, pyramid_request)

GroupingService.assert_called_once_with(
db=pyramid_request.db,
application_instance=application_instance,
plugin=pyramid_request.product.plugin.grouping,
segment_service=segment_service,
)

assert svc == GroupingService.return_value
Expand Down
39 changes: 30 additions & 9 deletions tests/unit/lms/services/grouping/service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,15 @@ def test_get_groups_with_instructor(self, svc, lti_user, assert_groups_returned)
"group_set_key", ("groupSetId", "group_category_id", "group_set_id")
)
def test_to_groupings_with_dicts(
self, svc, upsert_groupings, upsert_grouping_memberships, group_set_key
self,
svc,
upsert_groupings,
upsert_grouping_memberships,
group_set_key,
lti_user,
):
user = factories.User()
course = factories.Course()
grouping_dicts = [
{
"id": sentinel.id,
Expand All @@ -413,7 +420,11 @@ def test_to_groupings_with_dicts(
]

groupings = svc._to_groupings( # noqa: SLF001
sentinel.user, grouping_dicts, sentinel.course, sentinel.grouping_type
user,
grouping_dicts,
course,
sentinel.grouping_type,
lti_roles=lti_user.lti_roles,
)

upsert_groupings.assert_called_once_with(
Expand All @@ -425,21 +436,25 @@ def test_to_groupings_with_dicts(
"settings": sentinel.settings,
}
],
parent=sentinel.course,
parent=course,
type_=sentinel.grouping_type,
)
upsert_grouping_memberships.assert_called_once_with(
sentinel.user, upsert_groupings.return_value
user, upsert_groupings.return_value
)
assert groupings == upsert_groupings.return_value

def test_to_groupings_when_already_groupings(
self, svc, upsert_groupings, upsert_grouping_memberships
self, svc, upsert_groupings, upsert_grouping_memberships, lti_user
):
groupings = factories.CanvasSection.create_batch(5)

svc._to_groupings( # noqa: SLF001
sentinel.user, groupings, sentinel.course, sentinel.grouping_type
sentinel.user,
groupings,
sentinel.course,
sentinel.grouping_type,
lti_roles=lti_user.lti_roles,
)

upsert_groupings.assert_not_called()
Expand Down Expand Up @@ -480,13 +495,14 @@ def assert_sections_returned(self, svc, assert_groupings_returned):
)

@pytest.fixture
def assert_groupings_returned(self, _to_groupings):
def assert_groupings_returned(self, _to_groupings, lti_user):
def assert_groupings_returned(groupings, plugin_method, grouping_type):
_to_groupings.assert_called_once_with(
sentinel.user,
plugin_method.return_value,
sentinel.course,
grouping_type,
lti_roles=lti_user.lti_roles,
)
assert groupings == _to_groupings.return_value

Expand Down Expand Up @@ -520,5 +536,10 @@ def user():


@pytest.fixture
def svc(db_session, application_instance, grouping_plugin):
return GroupingService(db_session, application_instance, plugin=grouping_plugin)
def svc(db_session, application_instance, grouping_plugin, segment_service):
return GroupingService(
db_session,
application_instance,
plugin=grouping_plugin,
segment_service=segment_service,
)
Loading

0 comments on commit 8c7a4af

Please sign in to comment.