diff --git a/server/.gitignore b/server/.gitignore index cc909312..aab79938 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -1,9 +1,8 @@ public/ static/ +conf/ digiscript.sqlite digiscript.json -conf/digiscript.sqlite -conf/digiscript.json # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py b/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py new file mode 100644 index 00000000..19ed2c28 --- /dev/null +++ b/server/alembic_config/versions/d4f66f58158b_initial_alembic_revision.py @@ -0,0 +1,26 @@ +"""Initial Alembic Revision + +Revision ID: d4f66f58158b +Revises: +Create Date: 2024-06-02 15:50:23.550851 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd4f66f58158b' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index fdeb8931..2fb699bb 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -1,8 +1,12 @@ import os +import shutil +import time from typing import List, Optional -from alembic import command +import sqlalchemy +from alembic import command, script from alembic.config import Config +from alembic.runtime import migration from tornado.ioloop import IOLoop from tornado.web import Application, StaticFileHandler from tornado_prometheus import PrometheusMixIn @@ -20,6 +24,7 @@ from rbac.rbac import RBACController from utils.database import DigiSQLAlchemy from utils.env_parser import EnvParser +from utils.exceptions import DatabaseUpgradeRequired from utils.web.route import Route @@ -45,11 +50,11 @@ def __init__(self, debug=False, settings_path=None, skip_migrations=False, skip_ self._run_migrations() else: get_logger().warning('Skipping performing database migrations') - # And then check the database is up-to-date - if not skip_migrations_check: - self._check_migrations() - else: - get_logger().warning('Skipping database migrations check') + # And then check the database is up-to-date + if not skip_migrations_check: + self._check_migrations() + else: + get_logger().warning('Skipping database migrations check') # Finally, configure the database db_path = self.digi_settings.settings.get('db_path').get_value() get_logger().info(f'Using {db_path} as DB path') @@ -132,14 +137,34 @@ def _alembic_config(self): return alembic_cfg def _run_migrations(self): - get_logger().info('Running database migrations via Alembic') - # Run the upgrade on the database - command.upgrade(self._alembic_config, 'head') + try: + self._check_migrations() + except DatabaseUpgradeRequired: + get_logger().info('Running database migrations via Alembic') + # Create a copy of the database file as a backup before performing migrations + db_path: str = self.digi_settings.settings.get('db_path').get_value() + if db_path.startswith('sqlite:///'): + db_path = db_path.replace('sqlite:///', '') + if os.path.exists(db_path) and os.path.isfile(db_path): + get_logger().info('Creating copy of database file as backup') + new_file_name = f'{db_path}.{int(time.time())}' + shutil.copyfile(db_path, new_file_name) + get_logger().info(f'Created copy of database file as backup, saved to {new_file_name}') + else:gi + get_logger().warning('Database connection does not appear to be a file, cannot create backup!') + # Run the upgrade on the database + command.upgrade(self._alembic_config, 'head') + else: + get_logger().info('No database migrations to perform') def _check_migrations(self): get_logger().info('Checking database migrations via Alembic') - # Run the upgrade on the database - command.check(self._alembic_config) + engine = sqlalchemy.create_engine(self.digi_settings.settings.get('db_path').get_value()) + script_ = script.ScriptDirectory.from_config(self._alembic_config) + with engine.begin() as conn: + context = migration.MigrationContext.configure(conn) + if context.get_current_revision() != script_.get_current_head(): + raise DatabaseUpgradeRequired('Migrations required on the database') async def configure(self): await self._configure_logging() diff --git a/server/utils/exceptions.py b/server/utils/exceptions.py new file mode 100644 index 00000000..3c8af989 --- /dev/null +++ b/server/utils/exceptions.py @@ -0,0 +1,2 @@ +class DatabaseUpgradeRequired(Exception): + pass