Skip to content

Commit

Permalink
improvement: do not cache presentation compilers for find references
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed May 28, 2024
1 parent 7cced15 commit 99831af
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 44 deletions.
101 changes: 64 additions & 37 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -747,26 +747,38 @@ class Compilers(
}.getOrElse(Future.successful(Nil))

def references(
searchFile: AbsolutePath,
id: BuildTargetIdentifier,
searchFiles: List[AbsolutePath],
includeDefinition: Boolean,
symbol: String,
): Future[List[ReferencesResult]] =
loadCompiler(searchFile)
.map { compiler =>
val uri = searchFile.toURI
val (input, _, adjust) =
sourceAdjustments(uri.toString(), compiler.scalaVersion())
val requestParams = new internal.pc.PcReferencesRequest(
CompilerVirtualFileParams(uri, input.text),
includeDefinition,
JEither.forRight(symbol),
)
compiler
.references(requestParams)
.asScala
.map(_.asScala.map(adjust.adjustReferencesResult).toList)
}
.getOrElse(Future.successful(Nil))
): Future[List[ReferencesResult]] = {
// we filter only Scala files, since `references` for Java are not implemented
val filteredFiles = searchFiles.filter(_.isScala)
val results =
if (filteredFiles.isEmpty) Nil
else
withUncachedCompiler(id) { compiler =>
for {
searchFile <- filteredFiles
} yield {
val uri = searchFile.toURI
val (input, _, adjust) =
sourceAdjustments(uri.toString(), compiler.scalaVersion())
val requestParams = new internal.pc.PcReferencesRequest(
CompilerVirtualFileParams(uri, input.text),
includeDefinition,
JEither.forRight(symbol),
)
compiler
.references(requestParams)
.asScala
.map(_.asScala.map(adjust.adjustReferencesResult).toList)
}
}
.getOrElse(Nil)

Future.sequence(results).map(_.flatten)
}

def extractMethod(
doc: TextDocumentIdentifier,
Expand Down Expand Up @@ -1125,34 +1137,49 @@ class Compilers(

private def loadCompiler(
targetId: BuildTargetIdentifier
): Option[PresentationCompiler] = {
val target = buildTargets.scalaTarget(targetId)
target.flatMap(loadCompilerForTarget)
}
): Option[PresentationCompiler] =
withKeyAndDefault(targetId) { case (key, getCompiler) =>
Option(jcache.computeIfAbsent(key, { _ => getCompiler() }).await)
}

private def loadCompilerForTarget(
scalaTarget: ScalaTarget
): Option[PresentationCompiler] = {
val scalaVersion = scalaTarget.scalaVersion
mtagsResolver.resolve(scalaVersion) match {
case Some(mtags) =>
val out = jcache.computeIfAbsent(
PresentationCompilerKey.ScalaBuildTarget(scalaTarget.info.getId),
{ _ =>
private def withKeyAndDefault[T](
targetId: BuildTargetIdentifier
)(
f: (PresentationCompilerKey, () => MtagsPresentationCompiler) => Option[T]
): Option[T] = {
buildTargets.scalaTarget(targetId).flatMap { scalaTarget =>
val scalaVersion = scalaTarget.scalaVersion
mtagsResolver.resolve(scalaVersion) match {
case Some(mtags) =>
def default() =
workDoneProgress.trackBlocking(
s"${config.icons.sync}Loading presentation compiler"
) {
ScalaLazyCompiler(scalaTarget, mtags, search)
}
},
)
Option(out.await)
case None =>
scribe.warn(s"unsupported Scala ${scalaTarget.scalaVersion}")
None
val key =
PresentationCompilerKey.ScalaBuildTarget(scalaTarget.info.getId)
f(key, default)
case None =>
scribe.warn(s"unsupported Scala ${scalaTarget.scalaVersion}")
None
}
}
}

private def withUncachedCompiler[T](
targetId: BuildTargetIdentifier
)(f: PresentationCompiler => T): Option[T] =
withKeyAndDefault(targetId) { case (key, getCompiler) =>
val (out, shouldShutdown) = Option(jcache.get(key))
.map((_, false))
.getOrElse((getCompiler(), true))
val compiler = Option(out.await)
val result = compiler.map(f)
if (shouldShutdown) compiler.foreach(_.shutdown())
result
}

private def withPCAndAdjustLsp[T](
params: SelectionRangeParams
)(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,19 @@ final class ReferenceProvider(
_ = visited.clear()
symbol <- symbols
name = nameFromSymbol(symbol)
searchFile <- pathsForName(buildTarget, name)
if (filterTargetFiles(searchFile) && !visited(searchFile))
pathsMap = pathsForName(buildTarget, name)
id <- pathsMap.keySet
searchFiles = pathsMap(id)
.filter(searchFile =>
filterTargetFiles(searchFile) && !visited(searchFile)
)
.distinct
if searchFiles.nonEmpty
} yield {
visited += searchFile
visited ++= searchFiles
compilers.references(
searchFile,
id,
searchFiles,
includeDeclaration,
symbol,
)
Expand All @@ -384,18 +391,19 @@ final class ReferenceProvider(
private def pathsForName(
buildTarget: BuildTargetIdentifier,
name: String,
): Iterator[AbsolutePath] = {
): Map[BuildTargetIdentifier, List[AbsolutePath]] = {
val allowedBuildTargets = buildTargets.allInverseDependencies(buildTarget)
val visited = scala.collection.mutable.Set.empty[AbsolutePath]
for {
val foundPaths = for {
(path, entry) <- identifierIndex.index.iterator
if allowedBuildTargets.contains(entry.id) &&
entry.bloom.mightContain(name)
sourcePath = AbsolutePath(path)
if !visited(sourcePath)
_ = visited.add(sourcePath)
if sourcePath.exists
} yield sourcePath
} yield (entry.id, sourcePath)
foundPaths.toList.groupMap(_._1)(_._2)
}

/**
Expand Down

0 comments on commit 99831af

Please sign in to comment.