diff --git a/src/contracts/Gateway/AbstractDoctrineDatabase.php b/src/contracts/Gateway/AbstractDoctrineDatabase.php index 35d1ac5..8a9a94d 100644 --- a/src/contracts/Gateway/AbstractDoctrineDatabase.php +++ b/src/contracts/Gateway/AbstractDoctrineDatabase.php @@ -125,18 +125,15 @@ public function countAll(): int public function countBy($criteria): int { $metadata = $this->getMetadata(); - $qb = $this->connection->createQueryBuilder(); + $qb = $this->createBaseQueryBuilder(); $this->applyInheritance($qb); $identifierColumn = $metadata->getIdentifierColumn(); $tableAlias = $this->getTableAlias(); - $qb->select($this->connection->getDatabasePlatform()->getCountExpression($tableAlias . '.' . $identifierColumn)); - $qb->from($metadata->getTableName(), $tableAlias); + $platform = $this->connection->getDatabasePlatform(); + $qb->select($platform->getCountExpression(sprintf('DISTINCT %s.%s', $tableAlias, $identifierColumn))); - $expr = $this->convertCriteriaToExpression($qb, $criteria); - if ($expr !== null) { - $qb->andWhere($expr); - } + $this->applyCriteria($qb, $criteria); return (int)$qb->execute()->fetchOne(); } @@ -162,39 +159,10 @@ public function findBy($criteria, ?array $orderBy = null, ?int $limit = null, in $qb = $this->createBaseQueryBuilder(); $this->applyInheritance($qb); - foreach ($orderBy ?? [] as $column => $order) { - if (!$metadata->hasColumn($column) && !$metadata->isInheritedColumn($column)) { - $columns = $metadata->getColumns(); - foreach ($metadata->getSubclasses() as $subMetadata) { - $columns = array_merge($columns, $subMetadata->getColumns()); - } - - throw new InvalidArgumentException(sprintf( - '"%s" does not exist in "%s", or is not available for ordering. Available columns are: "%s"', - $column, - $this->getTableName(), - implode('", "', $columns), - )); - } + $this->applyOrderBy($qb, $orderBy); + $this->applyCriteria($qb, $criteria); + $this->applyLimits($qb, $limit, $offset); - if ($metadata->isInheritedColumn($column)) { - $subMetadata = $metadata->getInheritanceMetadataWithColumn($column); - assert($subMetadata !== null); - $qb->addOrderBy($subMetadata->getTableName() . '.' . $column, $order); - } else { - $qb->addOrderBy($this->getTableAlias() . '.' . $column, $order); - } - } - - $expr = $this->convertCriteriaToExpression($qb, $criteria); - if ($expr !== null) { - $qb->andWhere($expr); - } - - if ($limit !== null) { - $qb->setMaxResults($limit); - } - $qb->setFirstResult($offset); $results = $qb->execute()->fetchAllAssociative(); return array_map([$metadata, 'convertToPHPValues'], $results); @@ -427,4 +395,55 @@ private function buildCondition(QueryBuilder $qb, string $column, $value): strin return $qb->expr()->eq($fullColumnName, $parameter); } + + /** + * @param array|null $orderBy Map of column names to "ASC" or "DESC", that will be used in SORT query + */ + final protected function applyOrderBy(QueryBuilder $qb, ?array $orderBy = []): void + { + $metadata = $this->getMetadata(); + + foreach ($orderBy ?? [] as $column => $order) { + if (!$metadata->hasColumn($column) && !$metadata->isInheritedColumn($column)) { + $columns = $metadata->getColumns(); + foreach ($metadata->getSubclasses() as $subMetadata) { + $columns = array_merge($columns, $subMetadata->getColumns()); + } + + throw new InvalidArgumentException(sprintf( + '"%s" does not exist in "%s", or is not available for ordering. Available columns are: "%s"', + $column, + $this->getTableName(), + implode('", "', $columns), + )); + } + + if ($metadata->isInheritedColumn($column)) { + $subMetadata = $metadata->getInheritanceMetadataWithColumn($column); + assert($subMetadata !== null); + $qb->addOrderBy($subMetadata->getTableName() . '.' . $column, $order); + } else { + $qb->addOrderBy($this->getTableAlias() . '.' . $column, $order); + } + } + } + + /** + * @param \Doctrine\Common\Collections\Expr\Expression|array|null> $criteria + */ + final protected function applyCriteria(QueryBuilder $qb, $criteria): void + { + $expr = $this->convertCriteriaToExpression($qb, $criteria); + if ($expr !== null) { + $qb->andWhere($expr); + } + } + + final protected function applyLimits(QueryBuilder $qb, ?int $limit, int $offset): void + { + if ($limit !== null) { + $qb->setMaxResults($limit); + } + $qb->setFirstResult($offset); + } } diff --git a/src/lib/Gateway/ExpressionVisitor.php b/src/lib/Gateway/ExpressionVisitor.php index 11b8dcd..9aabf6b 100644 --- a/src/lib/Gateway/ExpressionVisitor.php +++ b/src/lib/Gateway/ExpressionVisitor.php @@ -252,13 +252,13 @@ private function handleRelationshipComparison(string $column, Comparison $compar $metadata, $relationship->getForeignKeyColumn(), $column, - $comparison->getValue(), + $comparison, ); case $relationship->getJoinType() === DoctrineRelationship::JOIN_TYPE_JOINED: return $this->handleJoinQuery( $metadata, $column, - $comparison->getValue(), + $comparison, ); default: throw new RuntimeMappingException(sprintf( @@ -287,13 +287,13 @@ private function expr(): ExpressionBuilder private function handleJoinQuery( DoctrineSchemaMetadataInterface $relationshipMetadata, string $field, - Value $value + Comparison $comparison ): string { $tableName = $relationshipMetadata->getTableName(); $parameterName = $field . '_' . count($this->parameters); $placeholder = ':' . $parameterName; - $value = $this->walkValue($value); + $value = $this->walkValue($comparison->getValue()); $type = $relationshipMetadata->getBindingTypeForColumn($field); if (is_array($value)) { $type += Connection::ARRAY_PARAM_OFFSET; @@ -306,14 +306,14 @@ private function handleJoinQuery( return $this->expr()->in($tableName . '.' . $field, $placeholder); } - return $this->expr()->eq($tableName . '.' . $field, $placeholder); + return $this->expr()->comparison($tableName . '.' . $field, $comparison->getOperator(), $placeholder); } private function handleSubSelectQuery( DoctrineSchemaMetadataInterface $relationshipMetadata, string $foreignField, string $field, - Value $value + Comparison $comparison ): string { $tableName = $relationshipMetadata->getTableName(); $parameterName = $field . '_' . count($this->parameters); @@ -329,7 +329,7 @@ private function handleSubSelectQuery( ), ); - $value = $this->walkValue($value); + $value = $this->walkValue($comparison->getValue()); $type = $relationshipMetadata->getBindingTypeForColumn($field); if (is_array($value)) { $type += Connection::ARRAY_PARAM_OFFSET; diff --git a/tests/integration/Gateway/ExpressionVisitorTest.php b/tests/integration/Gateway/ExpressionVisitorTest.php index 9dcbb94..6e6dd28 100644 --- a/tests/integration/Gateway/ExpressionVisitorTest.php +++ b/tests/integration/Gateway/ExpressionVisitorTest.php @@ -48,16 +48,38 @@ public function testInvalidField(): void $this->expressionVisitor->dispatch(new Comparison('non_existent_field', '=', 'bar')); } - public function testTraversingRelationships(): void + /** + * @dataProvider provideForTraversingRelationships + */ + public function testTraversingRelationships(Comparison $expr, string $expectedResult): void { // Note: This assumes relationship tables are joined before being used. - $expr = new Comparison( - 'relationship_1.relationship_2.relationship_2_foo', - 'IN', - 'bar', - ); $result = $this->expressionVisitor->dispatch($expr); - self::assertSame('relationship_2_table_name.relationship_2_foo = :relationship_2_foo_0', $result); + self::assertSame($expectedResult, $result); + } + + /** + * @return iterable + */ + public static function provideForTraversingRelationships(): iterable + { + yield [ + new Comparison( + 'relationship_1.relationship_2.relationship_2_foo', + 'IN', + 'bar', + ), + 'relationship_2_table_name.relationship_2_foo IN :relationship_2_foo_0', + ]; + + yield [ + new Comparison( + 'relationship_1.relationship_2.relationship_2_foo', + '<', + '2023-11-22 10:00:00', + ), + 'relationship_2_table_name.relationship_2_foo < :relationship_2_foo_0', + ]; } public function testInvalidRelationshipTraversal(): void