Sticking the job cache into the Flask app's config, in an effort to improve/fix things.

Signed-off-by: Cliff Hill <Clifford.hill@gsa.gov>
xlorepdarkhelm committed Nov 5, 2024
1 parent 54222f3 commit 5a9b867
Showing 3 changed files with 35 additions and 26 deletions.
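
In outline: the commit moves the cross-process job cache out of a module-level Manager().dict() in app/aws/s3.py and into application.config["job_cache"], created once in create_app(), with small helpers (set_job_cache, get_job_cache, len_job_cache) that reach it through current_app. A minimal sketch of the resulting pattern, assuming a standard Flask app factory (only the names that appear in the diff below are taken from it):

from multiprocessing import Manager

from flask import Flask, current_app


def create_app():
    application = Flask(__name__)
    # The Manager dict lives in a separate server process, so writes from
    # one worker process are visible to the others.
    manager = Manager()
    application.config["job_cache"] = manager.dict()
    return application


def set_job_cache(key, value):
    # Goes through current_app, so callers need an active application context.
    current_app.config["job_cache"][key] = value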
4 changes: 4 additions & 0 deletions app/__init__.py
@@ -4,6 +4,7 @@
 import time
 import uuid
 from contextlib import contextmanager
+from multiprocessing import Manager
 from time import monotonic

 from celery import Celery, Task, current_task
@@ -119,6 +120,9 @@ def create_app(application):
     redis_store.init_app(application)
     document_download_client.init_app(application)

+    manager = Manager()
+    application.config["job_cache"] = manager.dict()
+
     register_blueprint(application)

     # avoid circular imports by importing this file later
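
Why a Manager dict rather than a plain one: each worker process gets its own copy of ordinary module state, while a Manager proxies one shared dict out of a server process. A small self-contained demonstration (not part of the diff):

from multiprocessing import Manager, Process


def worker(shared, plain):
    shared["seen"] = True   # written through the proxy; visible to the parent
    plain["seen"] = True    # mutates a per-process copy only


if __name__ == "__main__":
    manager = Manager()
    shared, plain = manager.dict(), {}
    p = Process(target=worker, args=(shared, plain))
    p.start()
    p.join()
    print(dict(shared))  # {'seen': True}
    print(plain)         # {}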
42 changes: 25 additions & 17 deletions app/aws/s3.py
@@ -2,7 +2,6 @@
 import re
 import time
 from concurrent.futures import ThreadPoolExecutor
-from multiprocessing import Manager

 import botocore
 from boto3 import Session
@@ -16,20 +15,30 @@
 # Temporarily extend cache to 7 days
 ttl = 60 * 60 * 24 * 7
-manager = Manager()
-job_cache = manager.dict()


 # Global variable
 s3_client = None
 s3_resource = None


-def set_job_cache(job_cache, key, value):
+def set_job_cache(key, value):
+    job_cache = current_app.config["job_cache"]
     job_cache[key] = (value, time.time() + 8 * 24 * 60 * 60)


+def get_job_cache(key):
+    job_cache = current_app.config["job_cache"]
+    return job_cache.get(key)
+
+
+def len_job_cache():
+    job_cache = current_app.config["job_cache"]
+    return len(job_cache)
+
+
 def clean_cache():
+    job_cache = current_app.config["job_cache"]
     current_time = time.time()
     keys_to_delete = []
     for key, (_, expiry_time) in job_cache.items():
@@ -162,17 +171,16 @@ def read_s3_file(bucket_name, object_key, s3res):
     """
     try:
         job_id = get_job_id_from_s3_object_key(object_key)
-        if job_cache.get(job_id) is None:
+        if get_job_cache(job_id) is None:
             object = (
                 s3res.Object(bucket_name, object_key)
                 .get()["Body"]
                 .read()
                 .decode("utf-8")
             )
-            set_job_cache(job_cache, job_id, object)
-            set_job_cache(job_cache, f"{job_id}_phones", extract_phones(object))
+            set_job_cache(job_id, object)
+            set_job_cache(f"{job_id}_phones", extract_phones(object))
             set_job_cache(
-                job_cache,
                 f"{job_id}_personalisation",
                 extract_personalisation(object),
             )
@@ -192,7 +200,7 @@ def get_s3_files():

     s3res = get_s3_resource()
     current_app.logger.info(
-        f"job_cache length before regen: {len(job_cache)} #notify-admin-1200"
+        f"job_cache length before regen: {len_job_cache()} #notify-admin-1200"
     )
     try:
         with ThreadPoolExecutor() as executor:
@@ -201,7 +209,7 @@ def get_s3_files():
         current_app.logger.exception("Connection pool issue")

     current_app.logger.info(
-        f"job_cache length after regen: {len(job_cache)} #notify-admin-1200"
+        f"job_cache length after regen: {len_job_cache()} #notify-admin-1200"
     )

@@ -424,12 +432,12 @@ def extract_personalisation(job):


 def get_phone_number_from_s3(service_id, job_id, job_row_number):
-    job = job_cache.get(job_id)
+    job = get_job_cache(job_id)
     if job is None:
         current_app.logger.info(f"job {job_id} was not in the cache")
         job = get_job_from_s3(service_id, job_id)
         # Even if it is None, put it here to avoid KeyErrors
-        set_job_cache(job_cache, job_id, job)
+        set_job_cache(job_id, job)
     else:
         # skip expiration date from cache, we don't need it here
         job = job[0]
@@ -441,7 +449,7 @@ def get_phone_number_from_s3(service_id, job_id, job_row_number):
         return "Unavailable"

     phones = extract_phones(job)
-    set_job_cache(job_cache, f"{job_id}_phones", phones)
+    set_job_cache(f"{job_id}_phones", phones)

     # If we can find the quick dictionary, use it
     phone_to_return = phones[job_row_number]
@@ -458,12 +466,12 @@ def get_personalisation_from_s3(service_id, job_id, job_row_number):
     # We don't want to constantly pull down a job from s3 every time we need the personalisation.
     # At the same time we don't want to store it in redis or the db
     # So this is a little recycling mechanism to reduce the number of downloads.
-    job = job_cache.get(job_id)
+    job = get_job_cache(job_id)
     if job is None:
         current_app.logger.info(f"job {job_id} was not in the cache")
         job = get_job_from_s3(service_id, job_id)
         # Even if it is None, put it here to avoid KeyErrors
-        set_job_cache(job_cache, job_id, job)
+        set_job_cache(job_id, job)
     else:
         # skip expiration date from cache, we don't need it here
         job = job[0]
@@ -478,9 +486,9 @@ def get_personalisation_from_s3(service_id, job_id, job_row_number):
         )
         return {}

-    set_job_cache(job_cache, f"{job_id}_personalisation", extract_personalisation(job))
+    set_job_cache(f"{job_id}_personalisation", extract_personalisation(job))

-    return job_cache.get(f"{job_id}_personalisation")[0].get(job_row_number)
+    return get_job_cache(f"{job_id}_personalisation")[0].get(job_row_number)


 def get_job_metadata_from_s3(service_id, job_id):
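
Note the shape of each cache entry: set_job_cache stores a (value, expiry) tuple with an eight-day horizon (a day past the seven-day ttl above), which is why readers unpack job[0] and clean_cache compares expiry times. A standalone sketch of that mechanism, assuming clean_cache simply deletes expired keys (its body is truncated in this diff); the helper names here are illustrative, not the repo's:

import time

EIGHT_DAYS = 8 * 24 * 60 * 60


def set_entry(cache, key, value):
    # Store the value together with its absolute expiry timestamp.
    cache[key] = (value, time.time() + EIGHT_DAYS)


def get_value(cache, key):
    entry = cache.get(key)
    return entry[0] if entry is not None else None  # drop the expiry, like job[0]


def clean_cache(cache):
    # Collect first, delete second: mutating a Manager dict while
    # iterating over it is not safe.
    now = time.time()
    for key in [k for k, (_, expiry) in cache.items() if expiry <= now]:
        del cache[key]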
15 changes: 6 additions & 9 deletions tests/app/aws/test_s3.py
@@ -70,7 +70,7 @@ def test_cleanup_old_s3_objects(mocker):
     mock_remove_csv_object.assert_called_once_with("A")


-def test_read_s3_file_success(mocker):
+def test_read_s3_file_success(client, mocker):
     mock_s3res = MagicMock()
     mock_extract_personalisation = mocker.patch("app.aws.s3.extract_personalisation")
     mock_extract_phones = mocker.patch("app.aws.s3.extract_phones")
@@ -89,16 +89,13 @@ def test_read_s3_file_success(mocker):
     mock_extract_phones.return_value = ["1234567890"]
     mock_extract_personalisation.return_value = {"name": "John Doe"}

-    global job_cache
-    job_cache = {}
-
     read_s3_file(bucket_name, object_key, mock_s3res)
     mock_get_job_id.assert_called_once_with(object_key)
     mock_s3res.Object.assert_called_once_with(bucket_name, object_key)
     expected_calls = [
-        call(ANY, job_id, file_content),
-        call(ANY, f"{job_id}_phones", ["1234567890"]),
-        call(ANY, f"{job_id}_personalisation", {"name": "John Doe"}),
+        call(job_id, file_content),
+        call(f"{job_id}_phones", ["1234567890"]),
+        call(f"{job_id}_personalisation", {"name": "John Doe"}),
     ]
     mock_set_job_cache.assert_has_calls(expected_calls, any_order=True)

@@ -380,9 +377,9 @@ def test_file_exists_false(notify_api, mocker):
     get_s3_mock.assert_called_once()


-def test_get_s3_files_success(notify_api, mocker):
+def test_get_s3_files_success(client, mocker):
     mock_current_app = mocker.patch("app.aws.s3.current_app")
-    mock_current_app.config = {"CSV_UPLOAD_BUCKET": {"bucket": "test-bucket"}}
+    mock_current_app.config = {"CSV_UPLOAD_BUCKET": {"bucket": "test-bucket"}, "job_cache": {}}
     mock_thread_pool_executor = mocker.patch("app.aws.s3.ThreadPoolExecutor")
     mock_read_s3_file = mocker.patch("app.aws.s3.read_s3_file")
     mock_list_s3_objects = mocker.patch("app.aws.s3.list_s3_objects")
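
Both test signatures swap mocker-only and notify_api for a client fixture, presumably because the code under test now reads current_app.config and therefore needs an application context pushed. The fixture itself is not in this diff; a typical shape for it in a conftest.py, assuming notify_api is the Flask app fixture, would be:

import pytest


@pytest.fixture
def client(notify_api):
    # Hypothetical fixture; the real one in this repo may differ. Using the
    # test client pushes a context, so current_app (and hence
    # current_app.config["job_cache"]) resolves inside the code under test.
    with notify_api.test_request_context(), notify_api.test_client() as client:
        yield client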
