Skip to content

Commit

Permalink
feat: index type hierarchy in java files
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Mar 28, 2024
1 parent c0e4516 commit 0f9d773
Show file tree
Hide file tree
Showing 11 changed files with 855 additions and 198 deletions.
9 changes: 9 additions & 0 deletions metals-bench/src/main/scala/bench/MetalsBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import scala.meta.internal.metals.JdkSources
import scala.meta.internal.metals.ReportContext
import scala.meta.internal.metals.logging.MetalsLogger
import scala.meta.internal.mtags.JavaMtags
import scala.meta.internal.mtags.JavaToplevelMtags
import scala.meta.internal.mtags.Mtags
import scala.meta.internal.mtags.OnDemandSymbolIndex
import scala.meta.internal.mtags.ScalaMtags
Expand Down Expand Up @@ -179,6 +180,14 @@ class MetalsBench {
}
}

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def toplevelJavaMtags(): Unit = {
javaDependencySources.inputs.foreach { input =>
new JavaToplevelMtags(input, includeInnerClasses = true).index()
}
}

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def indexSources(): Unit = {
Expand Down
2 changes: 2 additions & 0 deletions metals/src/main/resources/db/migration/V6__Delete_indices.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- indexing type hierarchy has changed, so we want to reindex
delete from indexed_jar;
147 changes: 121 additions & 26 deletions mtags/src/main/scala/scala/meta/internal/mtags/JavaToplevelMtags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,31 @@ import scala.meta.internal.semanticdb.SymbolInformation
import scala.meta.internal.tokenizers.Chars._
import scala.meta.internal.tokenizers.Reporter

class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
class JavaToplevelMtags(
val input: Input.VirtualFile,
includeInnerClasses: Boolean
) extends MtagsIndexer {

import JavaToplevelMtags._

val reporter: Reporter = Reporter(input)
val reader: CharArrayReader =
new CharArrayReader(input, dialects.Scala213, reporter)

override def overrides(): List[(String, List[OverriddenSymbol])] =
overridden.result

private val overridden = List.newBuilder[(String, List[OverriddenSymbol])]

private def addOverridden(symbols: List[OverriddenSymbol]) =
overridden += ((currentOwner, symbols))

override def language: Language = Language.JAVA

override def indexRoot(): Unit = {
if (!input.path.endsWith("module-info.java")) {
reader.nextRawChar()
loop
loop(None)
}
}

Expand All @@ -35,29 +46,90 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
}
}

private def loop: Unit = {
@tailrec
private def loop(region: Option[Region]): Unit = {
val token = fetchToken
token match {
case Token.EOF =>
case Token.Package =>
val paths = readPaths
paths.foreach { path => pkg(path.value, path.pos) }
loop
loop(region)
case Token.Class | Token.Interface | _: Token.Enum | _: Token.Record =>
fetchToken match {
case Token.Word(v, pos) =>
val kind = token match {
case Token.Interface => SymbolInformation.Kind.INTERFACE
case _ => SymbolInformation.Kind.CLASS
}
withOwner(currentOwner)(tpe(v, pos, kind, 0))
skipBody
loop
val previousOwner = currentOwner
tpe(v, pos, kind, 0)
if (includeInnerClasses) {
collectTypeHierarchyInformation
loop(Some(Region(region, currentOwner, lBraceCount = 1)))
} else {
skipBody
currentOwner = previousOwner
loop(region)
}
case Token.LBrace =>
loop(region.map(_.lBrace()))
case Token.RBrace =>
val newRegion = region.flatMap(_.rBrace())
newRegion.foreach(reg => currentOwner = reg.owner)
loop(newRegion)
case _ =>
loop
loop(region)
}
case Token.LBrace =>
loop(region.map(_.lBrace()))
case Token.RBrace =>
val newRegion = region.flatMap(_.rBrace())
newRegion.foreach(reg => currentOwner = reg.owner)
loop(newRegion)
case _ =>
loop(region)
}
}

private def collectTypeHierarchyInformation: Unit = {
val implementsOrExtends = List.newBuilder[String]
@tailrec
def skipUntilOptImplementsOrExtends: Token = {
fetchToken match {
case t @ (Token.Implements | Token.Extends) => t
case Token.EOF => Token.EOF
case Token.LBrace => Token.LBrace
case _ => skipUntilOptImplementsOrExtends
}
}

@tailrec
def collectHierarchy: Unit = {
fetchToken match {
case Token.Word(v, _) =>
// emit here
implementsOrExtends += v
collectHierarchy
case Token.LBrace =>
case Token.LParen =>
skipBalanced(Token.LParen, Token.RParen)
collectHierarchy
case Token.LessThan =>
skipBalanced(Token.LessThan, Token.GreaterThan)
collectHierarchy
case Token.EOF =>
case _ => collectHierarchy
}
}

skipUntilOptImplementsOrExtends match {
case Token.Implements | Token.Extends =>
collectHierarchy
addOverridden(
implementsOrExtends.result.distinct.map(UnresolvedOverriddenSymbol(_))
)
case _ =>
loop
}
}

Expand Down Expand Up @@ -103,6 +175,8 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
case "interface" => Token.Interface
case "record" => Token.Record(pos)
case "enum" => Token.Enum(pos)
case "extends" => Token.Extends
case "implements" => Token.Implements
case ident =>
Token.Word(ident, pos)
}
Expand All @@ -113,8 +187,8 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
def parseToken: (Token, Boolean) = {
val first = reader.ch
first match {
case ',' | '<' | '>' | '&' | '|' | '!' | '=' | '+' | '-' | '*' | '@' |
':' | '?' | '%' | '^' | '~' =>
case ',' | '&' | '|' | '!' | '=' | '+' | '-' | '*' | '@' | ':' | '?' |
'%' | '^' | '~' =>
(Token.SpecialSym, false)
case SU => (Token.EOF, false)
case '.' => (Token.Dot, false)
Expand All @@ -125,6 +199,8 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
case ')' => (Token.RParen, false)
case '[' => (Token.LBracket, false)
case ']' => (Token.RBracket, false)
case '<' => (Token.LessThan, false)
case '>' => (Token.GreaterThan, false)
case '"' => (quotedLiteral('"'), false)
case '\'' => (quotedLiteral('\''), false)
case '/' =>
Expand Down Expand Up @@ -190,22 +266,26 @@ class JavaToplevelMtags(val input: Input.VirtualFile) extends MtagsIndexer {
skipToFirstBrace
}

@tailrec
def skipToRbrace(open: Int): Unit = {
fetchToken match {
case Token.RBrace if open == 1 => ()
case Token.RBrace =>
skipToRbrace(open - 1)
case Token.LBrace =>
skipToRbrace(open + 1)
case Token.EOF => ()
case _ =>
skipToRbrace(open)
}
}

skipToFirstBrace
skipToRbrace(1)
skipBalanced(Token.LBrace, Token.RBrace)
}

@tailrec
private def skipBalanced(
openingToken: Token,
closingToken: Token,
open: Int = 1
): Unit = {
fetchToken match {
case t if t == closingToken && open == 1 => ()
case t if t == closingToken =>
skipBalanced(openingToken, closingToken, open - 1)
case t if t == openingToken =>
skipBalanced(openingToken, closingToken, open + 1)
case Token.EOF => ()
case _ =>
skipBalanced(openingToken, closingToken, open)
}
}

private def skipLine: Unit =
Expand Down Expand Up @@ -260,12 +340,16 @@ object JavaToplevelMtags {
case class Record(pos: Position) extends WithPos {
val value: String = "record"
}
case object Implements extends Token
case object Extends extends Token
case object RBrace extends Token
case object LBrace extends Token
case object RParen extends Token
case object LParen extends Token
case object RBracket extends Token
case object LBracket extends Token
case object LessThan extends Token
case object GreaterThan extends Token
case object Semicolon extends Token
// any allowed symbol like `=` , `-` and others
case object SpecialSym extends Token
Expand All @@ -277,4 +361,15 @@ object JavaToplevelMtags {
}

}

case class Region(
previousRegion: Option[Region],
owner: String,
lBraceCount: Int
) {
def lBrace(): Region = Region(previousRegion, owner, lBraceCount + 1)
def rBrace(): Option[Region] =
if (lBraceCount == 1) previousRegion
else Some(Region(previousRegion, owner, lBraceCount - 1))
}
}
4 changes: 2 additions & 2 deletions mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ final class Mtags(implicit rc: ReportContext) {
if (language.isJava || language.isScala) {
val mtags =
if (language.isJava)
new JavaToplevelMtags(input)
new JavaToplevelMtags(input, includeInnerClasses = false)
else
new ScalaToplevelMtags(
input,
Expand Down Expand Up @@ -59,7 +59,7 @@ final class Mtags(implicit rc: ReportContext) {
if (language.isJava || language.isScala) {
val mtags =
if (language.isJava)
new JavaToplevelMtags(input)
new JavaToplevelMtags(input, includeInnerClasses = true)
else
new ScalaToplevelMtags(
input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ final class OnDemandSymbolIndex(
source,
None, {
indexedSources += 1
getOrCreateBucket(dialect).addSourceFile(source, sourceDirectory)
getOrCreateBucket(dialect)
.addSourceFile(source, sourceDirectory, isJava = false)
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SymbolIndexBucket(
if (sourceJars.addEntry(dir.toNIO)) {
dir.listRecursive.toList.flatMap {
case source if source.isScala =>
addSourceFile(source, Some(dir))
addSourceFile(source, Some(dir), isJava = false)
case _ =>
None
}
Expand All @@ -67,13 +67,9 @@ class SymbolIndexBucket(
try {
root.listRecursive.toList.flatMap {
case source if source.isScala =>
addSourceFile(source, None)
addSourceFile(source, None, isJava = false)
case source if source.isJava =>
addJavaSourceFile(source) match {
case Nil => None
case topLevels =>
Some(IndexingResult(source, topLevels, overrides = Nil))
}
addSourceFile(source, None, isJava = true)
case _ =>
None
}
Expand All @@ -100,39 +96,13 @@ class SymbolIndexBucket(
}
}

/* Sometimes source jars have additional nested directories,
* in that case java toplevel is not "trivial".
* See: https://github.com/scalameta/metals/issues/3815
*/
def addJavaSourceFile(source: AbsolutePath): List[String] = {
new JavaToplevelMtags(source.toInput).readPackage match {
case Nil => Nil
case packageParts =>
val className = source.filename.stripSuffix(".java")
val symbol = packageParts.mkString("", "/", s"/$className#")
if (
isTrivialToplevelSymbol(
source.toURI.toString,
symbol,
extension = "java"
)
) Nil
else {
toplevels.updateWith(symbol) {
case Some(acc) => Some(acc + source)
case None => Some(Set(source))
}
List(symbol)
}
}
}

def addSourceFile(
source: AbsolutePath,
sourceDirectory: Option[AbsolutePath]
sourceDirectory: Option[AbsolutePath],
isJava: Boolean
): Option[IndexingResult] = {
val IndexingResult(path, topLevels, overrides) =
indexSource(source, dialect, sourceDirectory)
indexSource(source, dialect, sourceDirectory, isJava)
topLevels.foreach { symbol =>
toplevels.updateWith(symbol) {
case Some(acc) => Some(acc + source)
Expand All @@ -145,7 +115,8 @@ class SymbolIndexBucket(
private def indexSource(
source: AbsolutePath,
dialect: Dialect,
sourceDirectory: Option[AbsolutePath]
sourceDirectory: Option[AbsolutePath],
isJava: Boolean
): IndexingResult = {
val uri = source.toIdeallyRelativeURI(sourceDirectory)
val (doc, overrides) = mtags.indexWithOverrides(source, dialect)
Expand All @@ -155,8 +126,15 @@ class SymbolIndexBucket(
.map(_.symbol)
val topLevels =
if (source.isAmmoniteScript) sourceTopLevels.toList
else
sourceTopLevels.filter(sym => !isTrivialToplevelSymbol(uri, sym)).toList
else if (isJava) {
sourceTopLevels.toList.headOption
.filter(sym => !isTrivialToplevelSymbol(uri, sym, "java"))
.toList
} else {
sourceTopLevels
.filter(sym => !isTrivialToplevelSymbol(uri, sym, "scala"))
.toList
}
IndexingResult(source, topLevels, overrides)
}

Expand Down
Loading

0 comments on commit 0f9d773

Please sign in to comment.