diff --git a/cmd/generate/tables.go b/cmd/generate/tables.go index 46a68b34..ad30cef6 100644 --- a/cmd/generate/tables.go +++ b/cmd/generate/tables.go @@ -67,8 +67,11 @@ var renameProp = map[prop]string{ {"ExecuteCommandParams", "arguments"}: "[]json.RawMessage", {"FoldingRange", "kind"}: "string", - {"Hover", "contents"}: "MarkupContent", - {"InlayHint", "label"}: "[]InlayHintLabelPart", + // Hover.contents intentionally has no override: the generator emits the + // spec union (Or_Hover_contents) so MarkedString / []MarkedString (both + // deprecated, still sent by e.g. jdtls) decode alongside MarkupContent. + // Forcing "MarkupContent" here would make jdtls hover decode to empty. + {"InlayHint", "label"}: "[]InlayHintLabelPart", {"RelatedFullDocumentDiagnosticReport", "relatedDocuments"}: "map[DocumentUri]interface{}", {"RelatedUnchangedDocumentDiagnosticReport", "relatedDocuments"}: "map[DocumentUri]interface{}", diff --git a/go.mod b/go.mod index b875a3fc..18accf4c 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,16 @@ module github.com/isaacphi/mcp-language-server go 1.24.0 require ( - github.com/davecgh/go-spew v1.1.1 github.com/fsnotify/fsnotify v1.9.0 github.com/mark3labs/mcp-go v0.25.0 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/stretchr/testify v1.10.0 - golang.org/x/text v0.25.0 + golang.org/x/text v0.24.0 ) require ( github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/kisielk/errcheck v1.9.0 // indirect @@ -22,7 +22,7 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect golang.org/x/mod v0.24.0 // indirect - golang.org/x/sync v0.14.0 // indirect + golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/telemetry v0.0.0-20240522233618-39ace7a40ae7 // indirect golang.org/x/tools v0.31.0 // indirect diff --git a/go.sum b/go.sum index 9ae54052..83c679d1 100644 --- a/go.sum +++ b/go.sum @@ -44,14 +44,14 @@ golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac h1:TSSpLIG4v+p0rP golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240522233618-39ace7a40ae7 h1:FemxDzfMUcK2f3YY4H+05K9CDzbSVr2+q/JKN45pey0= golang.org/x/telemetry v0.0.0-20240522233618-39ace7a40ae7/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= golang.org/x/vuln v1.1.4 h1:Ju8QsuyhX3Hk8ma3CesTbO8vfJD9EvUBgHvkxHBzj0I= diff --git a/internal/lsp/client.go b/internal/lsp/client.go index fc07059d..da95a188 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -151,6 +151,11 @@ func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) ( Completion: protocol.CompletionClientCapabilities{ CompletionItem: protocol.ClientCompletionItemOptions{}, }, + Hover: &protocol.HoverClientCapabilities{ + // Prefer MarkupContent over deprecated MarkedString + // responses (e.g. from jdtls). + ContentFormat: []protocol.MarkupKind{protocol.Markdown, protocol.PlainText}, + }, CodeLens: &protocol.CodeLensClientCapabilities{ DynamicRegistration: true, }, diff --git a/internal/protocol/tsprotocol.go b/internal/protocol/tsprotocol.go index 07436c18..619dd4e6 100644 --- a/internal/protocol/tsprotocol.go +++ b/internal/protocol/tsprotocol.go @@ -2562,8 +2562,10 @@ type GlobPattern = Or_GlobPattern // (alias) // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#hover type Hover struct { - // The hover's content - Contents MarkupContent `json:"contents"` + // The hover's content. Servers may respond with MarkedString, + // []MarkedString (both deprecated), or MarkupContent — e.g. jdtls + // sends []MarkedString unless markdown contentFormat is advertised. + Contents Or_Hover_contents `json:"contents"` // An optional range inside the text document that is used to // visualize the hover, e.g. by changing the background color. Range Range `json:"range,omitempty"` diff --git a/internal/tools/definition.go b/internal/tools/definition.go index 0ddd3fe0..ca27da8b 100644 --- a/internal/tools/definition.go +++ b/internal/tools/definition.go @@ -10,33 +10,17 @@ import ( ) func ReadDefinition(ctx context.Context, client *lsp.Client, symbolName string) (string, error) { - symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{ - Query: symbolName, - }) + results, err := workspaceSymbols(ctx, client, symbolName) if err != nil { - return "", fmt.Errorf("failed to fetch symbol: %v", err) - } - - results, err := symbolResult.Results() - if err != nil { - return "", fmt.Errorf("failed to parse results: %v", err) + return "", err } var definitions []string for _, symbol := range results { - kind := "" - container := "" - // Skip symbols that we are not looking for. workspace/symbol may return // a large number of fuzzy matches. switch v := symbol.(type) { case *protocol.SymbolInformation: - // SymbolInformation results have richer data. - kind = fmt.Sprintf("Kind: %s\n", protocol.TableKindMap[v.Kind]) - if v.ContainerName != "" { - container = fmt.Sprintf("Container Name: %s\n", v.ContainerName) - } - // Handle different matching strategies based on the search term if strings.Contains(symbolName, ".") { // For qualified names like "Type.Method", require exact match @@ -62,38 +46,53 @@ func ReadDefinition(ctx context.Context, client *lsp.Client, symbolName string) } toolsLogger.Debug("Found symbol: %s", symbol.GetName()) - loc := symbol.GetLocation() - err := client.OpenFile(ctx, loc.URI.Path()) + symPath, err := uriToPath(symbol.GetLocation().URI) if err != nil { + toolsLogger.Error("Error resolving symbol path: %v", err) + continue + } + if err := client.OpenFile(ctx, symPath); err != nil { toolsLogger.Error("Error opening file: %v", err) continue } - banner := "---\n\n" - definition, loc, err := GetFullDefinition(ctx, client, loc) - locationInfo := fmt.Sprintf( - "Symbol: %s\n"+ - "File: %s\n"+ - kind+ - container+ - "Range: L%d:C%d - L%d:C%d\n\n", - symbol.GetName(), - strings.TrimPrefix(string(loc.URI), "file://"), - loc.Range.Start.Line+1, - loc.Range.Start.Character+1, - loc.Range.End.Line+1, - loc.Range.End.Character+1, - ) - + definition, loc, err := GetFullDefinition(ctx, client, symbol.GetLocation()) if err != nil { toolsLogger.Error("Error getting definition: %v", err) continue } - definition = addLineNumbers(definition, int(loc.Range.Start.Line)+1) + // Extract kind/container for both SymbolInformation and WorkspaceSymbol + // results so the header is consistent regardless of which shape the + // server returns. + kind, container := symbolKindAndContainer(symbol) + definitions = append(definitions, formatDefinitionEntry(symbol.GetName(), kind, container, definition, loc)) + } - definitions = append(definitions, banner+locationInfo+definition+"\n") + // Fallback for qualified names that workspace/symbol could not match + // directly. Some servers (e.g. jdtls) only index types, so fully + // qualified names ("com.example.Foo") and member symbols + // ("Class.method") never appear in workspace/symbol results. + if len(definitions) == 0 && strings.Contains(symbolName, ".") { + entries, err := resolveQualifiedEntries(ctx, client, symbolName, func(sym resolvedSymbol) ([]string, error) { + symPath, err := uriToPath(sym.Location.URI) + if err != nil { + return nil, err + } + if err := client.OpenFile(ctx, symPath); err != nil { + return nil, fmt.Errorf("error opening file: %v", err) + } + definition, loc, err := GetFullDefinition(ctx, client, sym.Location) + if err != nil { + return nil, fmt.Errorf("error getting definition: %v", err) + } + return []string{formatDefinitionEntry(sym.Name, sym.Kind, sym.ContainerName, definition, loc)}, nil + }) + if err != nil { + return "", err + } + definitions = append(definitions, entries...) } if len(definitions) == 0 { @@ -102,3 +101,35 @@ func ReadDefinition(ctx context.Context, client *lsp.Client, symbolName string) return strings.Join(definitions, ""), nil } + +// formatDefinitionEntry renders a single definition block. It is shared by the +// workspace/symbol path and the qualified-name fallback so the header format +// stays in one place. kindName/containerName are raw names; their labels are +// emitted only when non-empty. +func formatDefinitionEntry(name, kindName, containerName, definition string, loc protocol.Location) string { + var b strings.Builder + fmt.Fprintf(&b, "Symbol: %s\n", name) + // Display the filesystem path for file:// URIs; for anything else (e.g. a + // jdt:// URI from jdtls) fall back to the raw URI rather than panicking. + displayPath, err := uriToPath(loc.URI) + if err != nil { + displayPath = string(loc.URI) + } + fmt.Fprintf(&b, "File: %s\n", displayPath) + // kindName/containerName are server-provided; pass them as args (never as + // part of the format string) so a '%' in a name is not read as a verb. + if kindName != "" { + fmt.Fprintf(&b, "Kind: %s\n", kindName) + } + if containerName != "" { + fmt.Fprintf(&b, "Container Name: %s\n", containerName) + } + fmt.Fprintf(&b, "Range: L%d:C%d - L%d:C%d\n\n", + loc.Range.Start.Line+1, + loc.Range.Start.Character+1, + loc.Range.End.Line+1, + loc.Range.End.Character+1, + ) + + return "---\n\n" + b.String() + addLineNumbers(definition, int(loc.Range.Start.Line)+1) + "\n" +} diff --git a/internal/tools/hover.go b/internal/tools/hover.go index 874d5278..51059ab8 100644 --- a/internal/tools/hover.go +++ b/internal/tools/hover.go @@ -38,29 +38,54 @@ func GetHoverInfo(ctx context.Context, client *lsp.Client, filePath string, line var result strings.Builder - // Process the hover contents based on Markup content - if hoverResult.Contents.Value == "" { - // Extract the line where the hover was requested - lineText, err := ExtractTextFromLocation(protocol.Location{ - URI: uri, - Range: protocol.Range{ - Start: protocol.Position{ - Line: position.Line, - Character: 0, - }, - End: protocol.Position{ - Line: position.Line + 1, - Character: 0, - }, - }, - }) + // Process the hover contents: MarkedString | []MarkedString | MarkupContent + contents := renderHoverContents(hoverResult.Contents) + if contents == "" { + // Extract the line where the hover was requested. Use a single-line + // range (start and end on the same line) so this also works when the + // symbol is on the file's last line; spanning to the start of the next + // line would push End.Line past the last index and fail the bounds + // check in ExtractTextFromLocation. + lineText, err := extractLineText(uri, position.Line) if err != nil { toolsLogger.Warn("failed to extract line at position: %v", err) } result.WriteString(fmt.Sprintf("No hover information available for this position on the following line:\n%s", lineText)) } else { - result.WriteString(hoverResult.Contents.Value) + result.WriteString(contents) } return result.String(), nil } + +// renderHoverContents converts any of the three hover content shapes +// (MarkupContent, MarkedString, []MarkedString) to a markdown string. +func renderHoverContents(contents protocol.Or_Hover_contents) string { + switch v := contents.Value.(type) { + case protocol.MarkupContent: + return v.Value + case protocol.MarkedString: + return renderMarkedString(v) + case []protocol.MarkedString: + parts := make([]string, 0, len(v)) + for _, ms := range v { + if s := renderMarkedString(ms); s != "" { + parts = append(parts, s) + } + } + return strings.Join(parts, "\n\n") + } + return "" +} + +// renderMarkedString converts a MarkedString (plain string or +// {language, value} code block) to markdown. +func renderMarkedString(ms protocol.MarkedString) string { + switch v := ms.Value.(type) { + case string: + return v + case protocol.MarkedStringWithLanguage: + return fmt.Sprintf("```%s\n%s\n```", v.Language, v.Value) + } + return "" +} diff --git a/internal/tools/hover_test.go b/internal/tools/hover_test.go new file mode 100644 index 00000000..ad5bab1d --- /dev/null +++ b/internal/tools/hover_test.go @@ -0,0 +1,92 @@ +package tools + +import ( + "testing" + + "github.com/isaacphi/mcp-language-server/internal/protocol" + "github.com/stretchr/testify/assert" +) + +// TestRenderHoverContents covers the three hover content shapes the renderer +// must handle: MarkupContent ({kind, value}), a plain-string MarkedString, and +// a []MarkedString containing a MarkedStringWithLanguage (which must be wrapped +// in a language-tagged code fence). These are the shapes jdtls and other +// servers emit, so the fix's formatting is asserted here without a live LSP. +func TestRenderHoverContents(t *testing.T) { + tests := []struct { + name string + contents protocol.Or_Hover_contents + want string + }{ + { + name: "MarkupContent returns its value verbatim", + contents: protocol.Or_Hover_contents{Value: protocol.MarkupContent{ + Kind: protocol.Markdown, + Value: "# Title\nbody", + }}, + want: "# Title\nbody", + }, + { + name: "plain-string MarkedString returns the string", + contents: protocol.Or_Hover_contents{Value: protocol.MarkedString{ + Value: "just a string", + }}, + want: "just a string", + }, + { + name: "[]MarkedString with MarkedStringWithLanguage is code-fenced", + contents: protocol.Or_Hover_contents{Value: []protocol.MarkedString{ + {Value: protocol.MarkedStringWithLanguage{Language: "java", Value: "int x"}}, + {Value: "documentation text"}, + }}, + want: "```java\nint x\n```\n\ndocumentation text", + }, + { + // A top-level MarkedStringWithLanguage (not inside a slice) goes + // through the distinct `case protocol.MarkedString` branch and must + // still be wrapped in a language-tagged code fence. + name: "top-level MarkedStringWithLanguage is code-fenced", + contents: protocol.Or_Hover_contents{Value: protocol.MarkedString{ + Value: protocol.MarkedStringWithLanguage{Language: "java", Value: "int x"}, + }}, + want: "```java\nint x\n```", + }, + { + // Empty []MarkedString entries are dropped so they do not produce + // stray blank separators in the joined output. + name: "[]MarkedString drops empty entries", + contents: protocol.Or_Hover_contents{Value: []protocol.MarkedString{ + {Value: "first"}, + {Value: ""}, + {Value: "second"}, + }}, + want: "first\n\nsecond", + }, + { + name: "unknown shape returns empty string", + contents: protocol.Or_Hover_contents{Value: 42}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := renderHoverContents(tt.contents) + assert.Equal(t, tt.want, got) + }) + } +} + +// TestRenderMarkedString covers the two MarkedString variants directly: a plain +// string passes through unchanged, and a MarkedStringWithLanguage becomes a +// language-tagged code fence. +func TestRenderMarkedString(t *testing.T) { + assert.Equal(t, "plain", renderMarkedString(protocol.MarkedString{Value: "plain"})) + assert.Equal(t, + "```go\nfunc f()\n```", + renderMarkedString(protocol.MarkedString{ + Value: protocol.MarkedStringWithLanguage{Language: "go", Value: "func f()"}, + }), + ) + assert.Equal(t, "", renderMarkedString(protocol.MarkedString{Value: 0})) +} diff --git a/internal/tools/lsp-utilities.go b/internal/tools/lsp-utilities.go index ae7d70bc..ec4a28d9 100644 --- a/internal/tools/lsp-utilities.go +++ b/internal/tools/lsp-utilities.go @@ -3,7 +3,6 @@ package tools import ( "context" "fmt" - "net/url" "os" "strings" @@ -59,10 +58,12 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr found = searchSymbols(symbols) if found { - // Convert URI to filesystem path - filePath, err := url.PathUnescape(strings.TrimPrefix(string(startLocation.URI), "file://")) + // Convert URI to filesystem path. Use the shared uriToPath helper so + // percent-decoding and non-file URI handling stay consistent with the + // definition/references tools (and never panics on a jdt:// URI). + filePath, err := uriToPath(startLocation.URI) if err != nil { - return "", protocol.Location{}, fmt.Errorf("failed to unescape URI: %w", err) + return "", protocol.Location{}, err } // Read the file to get the full lines of the definition diff --git a/internal/tools/references.go b/internal/tools/references.go index cb424e55..a0d28ee4 100644 --- a/internal/tools/references.go +++ b/internal/tools/references.go @@ -2,6 +2,7 @@ package tools import ( "context" + "errors" "fmt" "os" "sort" @@ -12,6 +13,15 @@ import ( "github.com/isaacphi/mcp-language-server/internal/protocol" ) +// openFileError marks a collectReferencesAt failure that originated from being +// unable to open the symbol's file (stale index, deleted/moved file, +// permissions). Such a failure is recoverable per match; a references RPC +// failure is not and is returned as a plain error. +type openFileError struct{ err error } + +func (e *openFileError) Error() string { return e.err.Error() } +func (e *openFileError) Unwrap() error { return e.err } + func FindReferences(ctx context.Context, client *lsp.Client, symbolName string) (string, error) { // Get context lines from environment variable contextLines := 5 @@ -22,19 +32,40 @@ func FindReferences(ctx context.Context, client *lsp.Client, symbolName string) } // First get the symbol location like ReadDefinition does - symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{ - Query: symbolName, - }) + results, err := workspaceSymbols(ctx, client, symbolName) if err != nil { - return "", fmt.Errorf("failed to fetch symbol: %v", err) + return "", err } - results, err := symbolResult.Results() - if err != nil { - return "", fmt.Errorf("failed to parse results: %v", err) + var allReferences []string + // hardErr records the most recent genuine references-RPC failure (server + // crash, timeout, protocol error). It is surfaced only if no references + // were collected at all: a single flaky match must neither discard + // references already resolved from other matches nor be silently reported + // as "no references found". The main loop and the qualified-name fallback + // share this policy via collect. + var hardErr error + collect := func(loc protocol.Location, name string) { + formatted, err := collectReferencesAt(ctx, client, loc, contextLines) + if err != nil { + var ofe *openFileError + if errors.As(err, &ofe) { + // A single un-openable file (stale index, deleted/moved file, + // permissions) must not discard references resolvable from the + // remaining matches. + toolsLogger.Error("Error opening file for %s: %v", name, err) + return + } + // A genuine references RPC failure is a real error, not an empty + // result. Remember it but keep going so references from other + // matches survive; it is surfaced below only if nothing resolved. + toolsLogger.Error("Error getting references for %s: %v", name, err) + hardErr = fmt.Errorf("failed to get references for %s: %w", name, err) + return + } + allReferences = append(allReferences, formatted...) } - var allReferences []string for _, symbol := range results { // Handle different matching strategies based on the search term if strings.Contains(symbolName, ".") { @@ -51,101 +82,140 @@ func FindReferences(ctx context.Context, client *lsp.Client, symbolName string) continue } - // Get the location of the symbol - loc := symbol.GetLocation() + collect(symbol.GetLocation(), symbol.GetName()) + } - // Use LSP references request with correct params structure - refsParams := protocol.ReferenceParams{ - TextDocumentPositionParams: protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: loc.URI, - }, - Position: loc.Range.Start, - }, - Context: protocol.ReferenceContext{ - IncludeDeclaration: false, - }, - } - // File is likely to be opened already, but may not be. - err := client.OpenFile(ctx, loc.URI.Path()) + // Fallback for qualified names that workspace/symbol could not match + // directly. Some servers (e.g. jdtls) only index types, so member + // symbols like "Class.method" never appear in workspace/symbol results. + if len(allReferences) == 0 && strings.Contains(symbolName, ".") { + resolved, err := resolveQualifiedSymbol(ctx, client, symbolName) if err != nil { - toolsLogger.Error("Error opening file: %v", err) - continue + return "", err } - refs, err := client.References(ctx, refsParams) - if err != nil { - return "", fmt.Errorf("failed to get references: %v", err) + for _, sym := range resolved { + collect(sym.Location, sym.Name) } + } - // Group references by file - refsByFile := make(map[protocol.DocumentUri][]protocol.Location) - for _, ref := range refs { - refsByFile[ref.URI] = append(refsByFile[ref.URI], ref) + if len(allReferences) == 0 { + // Surface a genuine RPC failure instead of masking it as an empty + // result; otherwise report that nothing was found. + if hardErr != nil { + return "", hardErr } + return fmt.Sprintf("No references found for symbol: %s", symbolName), nil + } + + return strings.Join(allReferences, "\n"), nil +} + +// collectReferencesAt runs textDocument/references for the symbol at loc and +// formats the results grouped by file. +func collectReferencesAt(ctx context.Context, client *lsp.Client, loc protocol.Location, contextLines int) ([]string, error) { + // Use LSP references request with correct params structure + refsParams := protocol.ReferenceParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: loc.URI, + }, + Position: loc.Range.Start, + }, + Context: protocol.ReferenceContext{ + IncludeDeclaration: false, + }, + } + // File is likely to be opened already, but may not be. A failure here is + // reported as an openFileError so callers can skip this match rather than + // treating it as a hard references failure. + openPath, err := uriToPath(loc.URI) + if err != nil { + return nil, &openFileError{err: err} + } + if err := client.OpenFile(ctx, openPath); err != nil { + return nil, &openFileError{err: fmt.Errorf("error opening file: %v", err)} + } + refs, err := client.References(ctx, refsParams) + if err != nil { + return nil, fmt.Errorf("failed to get references: %v", err) + } + + var allReferences []string - // Get sorted list of URIs - uris := make([]string, 0, len(refsByFile)) - for uri := range refsByFile { - uris = append(uris, string(uri)) + // Group references by file + refsByFile := make(map[protocol.DocumentUri][]protocol.Location) + for _, ref := range refs { + refsByFile[ref.URI] = append(refsByFile[ref.URI], ref) + } + + // Get sorted list of URIs + uris := make([]string, 0, len(refsByFile)) + for uri := range refsByFile { + uris = append(uris, string(uri)) + } + sort.Strings(uris) + + // Process each file's references in sorted order + for _, uriStr := range uris { + uri := protocol.DocumentUri(uriStr) + fileRefs := refsByFile[uri] + // Use uriToPath rather than trimming the "file://" prefix: it + // percent-decodes the path (e.g. "%20" -> space), so os.ReadFile + // finds files whose paths contain encoded characters, and the displayed + // path stays consistent with definition.go. A non-file URI (e.g. a + // jdt:// reference into a JAR) cannot be read from disk, so skip it + // rather than panicking. + filePath, err := uriToPath(uri) + if err != nil { + toolsLogger.Error("Skipping references in non-file URI %s: %v", uri, err) + continue } - sort.Strings(uris) - - // Process each file's references in sorted order - for _, uriStr := range uris { - uri := protocol.DocumentUri(uriStr) - fileRefs := refsByFile[uri] - filePath := strings.TrimPrefix(uriStr, "file://") - - // Format file header - fileInfo := fmt.Sprintf("---\n\n%s\nReferences in File: %d\n", - filePath, - len(fileRefs), - ) - - // Format locations with context - fileContent, err := os.ReadFile(filePath) - if err != nil { - // Log error but continue with other files - allReferences = append(allReferences, fileInfo+"\nError reading file: "+err.Error()) - continue - } - lines := strings.Split(string(fileContent), "\n") + // Format file header + fileInfo := fmt.Sprintf("---\n\n%s\nReferences in File: %d\n", + filePath, + len(fileRefs), + ) - // Track reference locations for header display - var locStrings []string - for _, ref := range fileRefs { - locStr := fmt.Sprintf("L%d:C%d", - ref.Range.Start.Line+1, - ref.Range.Start.Character+1) - locStrings = append(locStrings, locStr) - } + // Format locations with context + fileContent, err := os.ReadFile(filePath) + if err != nil { + // Log error but continue with other files + allReferences = append(allReferences, fileInfo+"\nError reading file: "+err.Error()) + continue + } - // Collect lines to display using the utility function - linesToShow, err := GetLineRangesToDisplay(ctx, client, fileRefs, len(lines), contextLines) - if err != nil { - // Log error but continue with other files - continue - } + lines := strings.Split(string(fileContent), "\n") - // Convert to line ranges using the utility function - lineRanges := ConvertLinesToRanges(linesToShow, len(lines)) + // Track reference locations for header display + var locStrings []string + for _, ref := range fileRefs { + locStr := fmt.Sprintf("L%d:C%d", + ref.Range.Start.Line+1, + ref.Range.Start.Character+1) + locStrings = append(locStrings, locStr) + } - // Format with locations in header - formattedOutput := fileInfo - if len(locStrings) > 0 { - formattedOutput += "At: " + strings.Join(locStrings, ", ") + "\n" - } + // Collect lines to display using the utility function + linesToShow, err := GetLineRangesToDisplay(ctx, client, fileRefs, len(lines), contextLines) + if err != nil { + // Log error but continue with other files + continue + } + + // Convert to line ranges using the utility function + lineRanges := ConvertLinesToRanges(linesToShow, len(lines)) - // Format the content with ranges - formattedOutput += "\n" + FormatLinesWithRanges(lines, lineRanges) - allReferences = append(allReferences, formattedOutput) + // Format with locations in header + formattedOutput := fileInfo + if len(locStrings) > 0 { + formattedOutput += "At: " + strings.Join(locStrings, ", ") + "\n" } - } - if len(allReferences) == 0 { - return fmt.Sprintf("No references found for symbol: %s", symbolName), nil + // Format the content with ranges + formattedOutput += "\n" + FormatLinesWithRanges(lines, lineRanges) + allReferences = append(allReferences, formattedOutput) } - return strings.Join(allReferences, "\n"), nil + return allReferences, nil } diff --git a/internal/tools/symbol-resolve.go b/internal/tools/symbol-resolve.go new file mode 100644 index 00000000..ea182fc2 --- /dev/null +++ b/internal/tools/symbol-resolve.go @@ -0,0 +1,329 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" +) + +// resolvedSymbol is a symbol location resolved via the qualified-name +// fallback (workspace/symbol for the container + documentSymbol for the +// member). Some servers (e.g. jdtls) only index types in workspace/symbol, +// so member symbols like "Class.method" are unreachable without this. +type resolvedSymbol struct { + Name string + Kind string + ContainerName string + Location protocol.Location +} + +// resolveQualifiedSymbol resolves qualified symbol names that +// workspace/symbol could not match directly. It handles two shapes: +// +// 1. Fully qualified top-level symbols, e.g. "com.example.Foo": +// matches workspace/symbol results for "Foo" whose container is +// "com.example". +// 2. Members of a type, e.g. "Foo.bar" or "com.example.Foo.bar": +// resolves the type via workspace/symbol, then locates "bar" with +// textDocument/documentSymbol in the type's file. +func resolveQualifiedSymbol(ctx context.Context, client *lsp.Client, symbolName string) ([]resolvedSymbol, error) { + parts := strings.Split(symbolName, ".") + if len(parts) < 2 { + return nil, nil + } + member := parts[len(parts)-1] + container := strings.Join(parts[:len(parts)-1], ".") + + // Case 1: symbolName is a fully qualified top-level symbol. + symbols, err := workspaceSymbols(ctx, client, member) + if err != nil { + return nil, err + } + var resolved []resolvedSymbol + for _, sym := range symbols { + if sym.GetName() != member { + continue + } + kind, symContainer := symbolKindAndContainer(sym) + if symContainer == container { + resolved = append(resolved, resolvedSymbol{ + Name: member, + Kind: kind, + ContainerName: symContainer, + Location: sym.GetLocation(), + }) + } + } + if len(resolved) > 0 { + return resolved, nil + } + + // Case 2: container is a type; locate the member via documentSymbol. + simpleContainer := parts[len(parts)-2] + containerPkg := strings.Join(parts[:len(parts)-2], ".") + symbols, err = workspaceSymbols(ctx, client, simpleContainer) + if err != nil { + return nil, err + } + for _, sym := range symbols { + if sym.GetName() != simpleContainer { + continue + } + if containerPkg != "" { + // Suffix-tolerant, consistent with the rest of this file: + // jdtls reports a nested type's containerName as its fully + // qualified enclosing scope (e.g. "com.example.Outer" for the + // query "Outer.Inner.method"), so an exact "== containerPkg" + // check would reject the legitimate match. flatContainerMatches + // also accepts the exact and "Outer$Inner" forms, so gopls/pyright + // flat results keep matching. + if _, symContainer := symbolKindAndContainer(sym); !flatContainerMatches(symContainer, containerPkg) { + continue + } + } + loc := sym.GetLocation() + members, err := findMembersInDocument(ctx, client, loc.URI, simpleContainer, member) + if err != nil { + toolsLogger.Error("documentSymbol fallback failed for %s: %v", loc.URI, err) + continue + } + for _, m := range members { + m.ContainerName = sym.GetName() + resolved = append(resolved, m) + } + } + // Dedup across matching workspace symbols: scopeMembersToContainer only + // dedups within a single document, so a type indexed under duplicate + // entries (same file) could otherwise yield the same member twice. + return dedupSymbols(resolved), nil +} + +// resolveQualifiedEntries runs the qualified-name fallback and renders each +// resolved symbol with render, accumulating the non-empty output. It is the +// shared scaffold for the definition and references fallbacks so their trigger +// and iteration semantics stay identical. A resolve error is propagated; a +// per-symbol render error is logged and skipped, since the fallback is a +// best-effort path over speculatively resolved symbols. +func resolveQualifiedEntries(ctx context.Context, client *lsp.Client, symbolName string, render func(sym resolvedSymbol) ([]string, error)) ([]string, error) { + resolved, err := resolveQualifiedSymbol(ctx, client, symbolName) + if err != nil { + return nil, err + } + var out []string + for _, sym := range resolved { + entries, err := render(sym) + if err != nil { + toolsLogger.Error("qualified-name fallback failed for %s: %v", sym.Name, err) + continue + } + out = append(out, entries...) + } + return out, nil +} + +// workspaceSymbols is a package-level variable so tests can stub the +// workspace/symbol query without a live LSP client. Production code always +// uses the default below. Because it is shared global state, tests that +// reassign it must restore it (via t.Cleanup) and must not run with +// t.Parallel(). +var workspaceSymbols = func(ctx context.Context, client *lsp.Client, query string) ([]protocol.WorkspaceSymbolResult, error) { + symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{Query: query}) + if err != nil { + return nil, fmt.Errorf("failed to fetch symbol: %v", err) + } + return symbolResult.Results() +} + +// symbolKindAndContainer extracts the symbol kind name and container name +// from either workspace/symbol result type. +func symbolKindAndContainer(sym protocol.WorkspaceSymbolResult) (string, string) { + switch v := sym.(type) { + case *protocol.SymbolInformation: + return protocol.TableKindMap[v.Kind], v.ContainerName + case *protocol.WorkspaceSymbol: + return protocol.TableKindMap[v.Kind], v.ContainerName + } + return "", "" +} + +// findMembersInDocument finds symbols named member that belong to the type +// named container, within the document symbol tree of uri. Scoping to the +// container prevents matching same-named members of unrelated types in the +// same file. Matches either the exact name or a signature-suffixed name like +// "member(String, int)" as jdtls reports methods. +// findMembersInDocument is a package-level variable so tests can stub the +// document-symbol fetch without a live LSP client. Production code always uses +// the default below. Because it is shared global state, tests that reassign it +// must restore it (via t.Cleanup) and must not run with t.Parallel(). +var findMembersInDocument = func(ctx context.Context, client *lsp.Client, uri protocol.DocumentUri, container, member string) ([]resolvedSymbol, error) { + path, err := uriToPath(uri) + if err != nil { + return nil, err + } + if err := client.OpenFile(ctx, path); err != nil { + return nil, err + } + symResult, err := client.DocumentSymbol(ctx, protocol.DocumentSymbolParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: uri}, + }) + if err != nil { + return nil, err + } + symbols, err := symResult.Results() + if err != nil { + return nil, err + } + return scopeMembersToContainer(symbols, uri, container, member), nil +} + +// memberMatches reports whether a document-symbol name matches member, either +// exactly or as a signature-suffixed method name like "member(String, int)". +func memberMatches(name, member string) bool { + return name == member || strings.HasPrefix(name, member+"(") +} + +// containerMatches reports whether a type symbol name matches container, +// tolerating a generic suffix like "Foo". +func containerMatches(name, container string) bool { + return name == container || strings.HasPrefix(name, container+"<") +} + +// resolvedMember builds a resolvedSymbol from a document symbol matched as a +// member. The location is anchored at a single position because downstream +// definition/references requests are position-based. +// +// For DocumentSymbol it uses SelectionRange (the identifier itself) and the +// document uri, since DocumentSymbol carries no URI of its own. For +// SymbolInformation it uses the symbol's own Location — both the URI (which may +// differ from the document we queried, e.g. generated/partial sources) and the +// range start. SymbolInformation has no selection range, so the range start is +// the best available anchor. +func resolvedMember(sym protocol.DocumentSymbolResult, uri protocol.DocumentUri) resolvedSymbol { + var loc protocol.Location + kind := "" + switch v := sym.(type) { + case *protocol.DocumentSymbol: + pos := v.SelectionRange.Start + loc = protocol.Location{URI: uri, Range: protocol.Range{Start: pos, End: pos}} + kind = protocol.TableKindMap[v.Kind] + case *protocol.SymbolInformation: + pos := v.Location.Range.Start + symURI := v.Location.URI + if symURI == "" { + symURI = uri + } + loc = protocol.Location{URI: symURI, Range: protocol.Range{Start: pos, End: pos}} + kind = protocol.TableKindMap[v.Kind] + } + return resolvedSymbol{ + Name: sym.GetName(), + Kind: kind, + Location: loc, + } +} + +// scopeMembersToContainer finds members named member that belong to the type +// named container within symbols. It handles both hierarchical DocumentSymbol +// trees (members are direct children of the type node) and flat +// SymbolInformation lists (members carry a ContainerName). If the container +// type cannot be located at all, it falls back to matching member anywhere in +// the document so resolution is not lost for servers with unusual symbol +// shapes. +func scopeMembersToContainer(symbols []protocol.DocumentSymbolResult, uri protocol.DocumentUri, container, member string) []resolvedSymbol { + var found []resolvedSymbol + var anywhere []resolvedSymbol + containerSeen := false + + // A single depth-first walk handles both shapes and, in the same pass, + // gathers the document-wide matches used by the safety fallback so the + // tree is never traversed twice: + // - Hierarchical DocumentSymbol: a type node named container holds its + // members as direct children. Recursion reaches nested types. + // - Flat SymbolInformation: members carry a ContainerName instead. + walkDocumentSymbols(symbols, func(sym protocol.DocumentSymbolResult) { + if memberMatches(sym.GetName(), member) { + anywhere = append(anywhere, resolvedMember(sym, uri)) + } + switch v := sym.(type) { + case *protocol.DocumentSymbol: + if containerMatches(v.Name, container) { + containerSeen = true + for i := range v.Children { + if child := &v.Children[i]; memberMatches(child.Name, member) { + found = append(found, resolvedMember(child, uri)) + } + } + } + case *protocol.SymbolInformation: + if flatContainerMatches(v.ContainerName, container) { + containerSeen = true + if memberMatches(v.Name, member) { + found = append(found, resolvedMember(v, uri)) + } + } + } + }) + + // containerSeen alone is sufficient: found is only appended to inside + // branches that have already set containerSeen, so len(found) > 0 implies it. + if containerSeen { + return dedupSymbols(found) + } + + // Safety fallback: container not found; match member anywhere so we don't + // regress resolution for servers with unusual document-symbol shapes. This + // is a degraded path that can return members of unrelated types, so log it + // at Warn rather than hiding it at Debug. + toolsLogger.Warn("container %q not found in %s; matching member %q document-wide", container, uri, member) + return dedupSymbols(anywhere) +} + +// walkDocumentSymbols invokes visit for every symbol in the tree, depth-first, +// descending into the children of any hierarchical DocumentSymbol node. It is +// the single traversal shared by the container-scoped and document-wide member +// searches. +func walkDocumentSymbols(symbols []protocol.DocumentSymbolResult, visit func(protocol.DocumentSymbolResult)) { + for _, sym := range symbols { + visit(sym) + if ds, ok := sym.(*protocol.DocumentSymbol); ok && len(ds.Children) > 0 { + children := make([]protocol.DocumentSymbolResult, len(ds.Children)) + for i := range ds.Children { + children[i] = &ds.Children[i] + } + walkDocumentSymbols(children, visit) + } + } +} + +// dedupSymbols removes entries that resolve to the same name at the same +// location, which can happen when overlapping symbols are reported more than +// once. Genuinely distinct symbols (different positions) are preserved. +func dedupSymbols(syms []resolvedSymbol) []resolvedSymbol { + if len(syms) <= 1 { + return syms + } + seen := make(map[string]bool, len(syms)) + out := make([]resolvedSymbol, 0, len(syms)) + for _, s := range syms { + key := fmt.Sprintf("%s\x00%d\x00%d\x00%s", + s.Location.URI, s.Location.Range.Start.Line, s.Location.Range.Start.Character, s.Name) + if seen[key] { + continue + } + seen[key] = true + out = append(out, s) + } + return out +} + +// flatContainerMatches reports whether a SymbolInformation ContainerName +// (which may be qualified, e.g. "com.example.Foo" or "Outer$Inner") refers to +// the type named container. +func flatContainerMatches(containerName, container string) bool { + return containerMatches(containerName, container) || + strings.HasSuffix(containerName, "."+container) || + strings.HasSuffix(containerName, "$"+container) +} diff --git a/internal/tools/symbol-resolve_test.go b/internal/tools/symbol-resolve_test.go new file mode 100644 index 00000000..32e3917b --- /dev/null +++ b/internal/tools/symbol-resolve_test.go @@ -0,0 +1,457 @@ +package tools + +import ( + "context" + "errors" + "testing" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" + "github.com/stretchr/testify/assert" +) + +const testURI = protocol.DocumentUri("file:///test.java") + +// docSym builds a hierarchical DocumentSymbol. Range.Start uses startLine while +// SelectionRange.Start (what resolvedMember reports) uses selLine so tests can +// assert which node matched. +func docSym(name string, kind protocol.SymbolKind, selLine uint32, children ...protocol.DocumentSymbol) protocol.DocumentSymbol { + return protocol.DocumentSymbol{ + Name: name, + Kind: kind, + Range: protocol.Range{Start: protocol.Position{Line: selLine}}, + SelectionRange: protocol.Range{Start: protocol.Position{Line: selLine, Character: 4}}, + Children: children, + } +} + +// tree wraps top-level DocumentSymbols as a []DocumentSymbolResult. +func tree(syms ...protocol.DocumentSymbol) []protocol.DocumentSymbolResult { + out := make([]protocol.DocumentSymbolResult, len(syms)) + for i := range syms { + out[i] = &syms[i] + } + return out +} + +// flat wraps SymbolInformation values as a []DocumentSymbolResult. +func flat(syms ...protocol.SymbolInformation) []protocol.DocumentSymbolResult { + out := make([]protocol.DocumentSymbolResult, len(syms)) + for i := range syms { + out[i] = &syms[i] + } + return out +} + +func symInfo(name string, kind protocol.SymbolKind, container string, line uint32) protocol.SymbolInformation { + return protocol.SymbolInformation{ + Name: name, + Kind: kind, + ContainerName: container, + Location: protocol.Location{URI: testURI, Range: protocol.Range{Start: protocol.Position{Line: line}}}, + } +} + +func TestResolveQualifiedSymbolCase2(t *testing.T) { + // Case 2 dispatch for a 3-part nested name "Outer.Inner.method": the + // candidate type "Inner" is reported by jdtls with a fully qualified + // containerName ("com.example.Outer"), while the query's containerPkg is + // just "Outer". The suffix-tolerant filter must accept it; an exact + // equality check would reject the legitimate match (the false-negative + // this branch fixes). + t.Run("matches a qualified containerName against a simple containerPkg", func(t *testing.T) { + const innerURI = protocol.DocumentUri("file:///Outer.java") + + origWS := workspaceSymbols + origMembers := findMembersInDocument + t.Cleanup(func() { + workspaceSymbols = origWS + findMembersInDocument = origMembers + }) + + var gotContainer, gotMember string + // Case 1 queries the member ("method") and finds nothing, so resolution + // falls through to Case 2, which queries the simple container ("Inner"). + workspaceSymbols = func(_ context.Context, _ *lsp.Client, query string) ([]protocol.WorkspaceSymbolResult, error) { + if query != "Inner" { + return nil, nil + } + return []protocol.WorkspaceSymbolResult{ + &protocol.WorkspaceSymbol{ + BaseSymbolInformation: protocol.BaseSymbolInformation{ + Name: "Inner", + Kind: protocol.Class, + ContainerName: "com.example.Outer", + }, + Location: protocol.Or_WorkspaceSymbol_location{Value: protocol.Location{URI: innerURI}}, + }, + }, nil + } + findMembersInDocument = func(_ context.Context, _ *lsp.Client, uri protocol.DocumentUri, container, member string) ([]resolvedSymbol, error) { + gotContainer, gotMember = container, member + assert.Equal(t, innerURI, uri) + return []resolvedSymbol{{ + Name: "method", + Kind: "Method", + Location: protocol.Location{URI: uri, Range: protocol.Range{Start: protocol.Position{Line: 15}}}, + }}, nil + } + + got, err := resolveQualifiedSymbol(context.Background(), nil, "Outer.Inner.method") + assert.NoError(t, err) + assert.Len(t, got, 1) + assert.Equal(t, "method", got[0].Name) + assert.Equal(t, uint32(15), got[0].Location.Range.Start.Line) + assert.Equal(t, "Inner", got[0].ContainerName, "container is set to the resolved type's name") + assert.Equal(t, "Inner", gotContainer, "member lookup is scoped to the simple container") + assert.Equal(t, "method", gotMember) + }) + + t.Run("rejects a candidate whose containerPkg does not match", func(t *testing.T) { + origWS := workspaceSymbols + origMembers := findMembersInDocument + t.Cleanup(func() { + workspaceSymbols = origWS + findMembersInDocument = origMembers + }) + + // First call: member lookup for Case 1 ("method") finds nothing. + // Second call: container lookup for Case 2 ("Inner") returns a type + // whose container is "com.example.Other" — not the requested "Outer". + workspaceSymbols = func(_ context.Context, _ *lsp.Client, query string) ([]protocol.WorkspaceSymbolResult, error) { + if query == "Inner" { + return []protocol.WorkspaceSymbolResult{ + &protocol.WorkspaceSymbol{ + BaseSymbolInformation: protocol.BaseSymbolInformation{ + Name: "Inner", + Kind: protocol.Class, + ContainerName: "com.example.Other", + }, + Location: protocol.Or_WorkspaceSymbol_location{Value: protocol.Location{URI: testURI}}, + }, + }, nil + } + return nil, nil + } + findMembersInDocument = func(context.Context, *lsp.Client, protocol.DocumentUri, string, string) ([]resolvedSymbol, error) { + t.Fatal("findMembersInDocument must not be called for a non-matching containerPkg") + return nil, nil + } + + got, err := resolveQualifiedSymbol(context.Background(), nil, "Outer.Inner.method") + assert.NoError(t, err) + assert.Empty(t, got) + }) +} + +// TestResolveQualifiedSymbolDedupesAcrossWorkspaceResults covers the +// dedupSymbols call after Case 2 collects members across every matching +// workspace/symbol result. scopeMembersToContainer only dedups within a single +// document, so a type indexed under two workspace-symbol entries that resolve +// to the same member location would otherwise be reported twice. This is the +// cross-result dedup the per-document tests do not exercise. +func TestResolveQualifiedSymbolDedupesAcrossWorkspaceResults(t *testing.T) { + origWS := workspaceSymbols + origMembers := findMembersInDocument + t.Cleanup(func() { + workspaceSymbols = origWS + findMembersInDocument = origMembers + }) + + // Case 1 ("baz") finds nothing; Case 2 ("Foo") returns two duplicate type + // entries (as a stale/duplicated index can), each resolving to the same + // member location. + const fooURI = protocol.DocumentUri("file:///Foo.java") + workspaceSymbols = func(_ context.Context, _ *lsp.Client, query string) ([]protocol.WorkspaceSymbolResult, error) { + if query != "Foo" { + return nil, nil + } + entry := &protocol.WorkspaceSymbol{ + BaseSymbolInformation: protocol.BaseSymbolInformation{Name: "Foo", Kind: protocol.Class}, + Location: protocol.Or_WorkspaceSymbol_location{Value: protocol.Location{URI: fooURI}}, + } + return []protocol.WorkspaceSymbolResult{entry, entry}, nil + } + findMembersInDocument = func(_ context.Context, _ *lsp.Client, uri protocol.DocumentUri, _, _ string) ([]resolvedSymbol, error) { + return []resolvedSymbol{{ + Name: "baz", + Kind: "Method", + Location: protocol.Location{URI: uri, Range: protocol.Range{Start: protocol.Position{Line: 10}}}, + }}, nil + } + + got, err := resolveQualifiedSymbol(context.Background(), nil, "Foo.baz") + assert.NoError(t, err) + assert.Len(t, got, 1, "the same member resolved from two workspace results must be collapsed") + assert.Equal(t, "baz", got[0].Name) +} + +// TestResolveQualifiedEntries covers the resolve-and-render scaffold used by +// the definition qualified-name fallback (references.go calls +// resolveQualifiedSymbol directly with its own collect closure because its +// error policy differs): it resolves the symbol, renders each resolved entry, +// and accumulates the output. A resolve error is +// propagated; a per-entry render error is logged and skipped rather than +// aborting the whole fallback. The render func here is deliberately +// client-free so the wiring is exercised without a live LSP. +func TestResolveQualifiedEntries(t *testing.T) { + t.Run("renders every resolved entry", func(t *testing.T) { + origWS := workspaceSymbols + origMembers := findMembersInDocument + t.Cleanup(func() { + workspaceSymbols = origWS + findMembersInDocument = origMembers + }) + + const fooURI = protocol.DocumentUri("file:///Foo.java") + workspaceSymbols = func(_ context.Context, _ *lsp.Client, query string) ([]protocol.WorkspaceSymbolResult, error) { + if query != "Foo" { + return nil, nil + } + return []protocol.WorkspaceSymbolResult{ + &protocol.WorkspaceSymbol{ + BaseSymbolInformation: protocol.BaseSymbolInformation{Name: "Foo", Kind: protocol.Class}, + Location: protocol.Or_WorkspaceSymbol_location{Value: protocol.Location{URI: fooURI}}, + }, + }, nil + } + findMembersInDocument = func(_ context.Context, _ *lsp.Client, uri protocol.DocumentUri, _, _ string) ([]resolvedSymbol, error) { + return []resolvedSymbol{ + {Name: "baz", Kind: "Method", ContainerName: "Foo", Location: protocol.Location{URI: uri, Range: protocol.Range{Start: protocol.Position{Line: 10}}}}, + }, nil + } + + var rendered []resolvedSymbol + out, err := resolveQualifiedEntries(context.Background(), nil, "Foo.baz", func(sym resolvedSymbol) ([]string, error) { + rendered = append(rendered, sym) + return []string{"entry:" + sym.Name}, nil + }) + assert.NoError(t, err) + assert.Equal(t, []string{"entry:baz"}, out) + assert.Len(t, rendered, 1, "render is invoked once per resolved symbol") + assert.Equal(t, "Foo", rendered[0].ContainerName) + }) + + t.Run("propagates a resolve error", func(t *testing.T) { + origWS := workspaceSymbols + t.Cleanup(func() { workspaceSymbols = origWS }) + + workspaceSymbols = func(context.Context, *lsp.Client, string) ([]protocol.WorkspaceSymbolResult, error) { + return nil, errors.New("symbol query failed") + } + + out, err := resolveQualifiedEntries(context.Background(), nil, "Foo.baz", func(resolvedSymbol) ([]string, error) { + t.Fatal("render must not run when resolution fails") + return nil, nil + }) + assert.Error(t, err) + assert.Nil(t, out) + }) + + t.Run("skips entries whose render fails but keeps the rest", func(t *testing.T) { + origWS := workspaceSymbols + origMembers := findMembersInDocument + t.Cleanup(func() { + workspaceSymbols = origWS + findMembersInDocument = origMembers + }) + + const fooURI = protocol.DocumentUri("file:///Foo.java") + workspaceSymbols = func(_ context.Context, _ *lsp.Client, query string) ([]protocol.WorkspaceSymbolResult, error) { + if query != "Foo" { + return nil, nil + } + return []protocol.WorkspaceSymbolResult{ + &protocol.WorkspaceSymbol{ + BaseSymbolInformation: protocol.BaseSymbolInformation{Name: "Foo", Kind: protocol.Class}, + Location: protocol.Or_WorkspaceSymbol_location{Value: protocol.Location{URI: fooURI}}, + }, + }, nil + } + findMembersInDocument = func(_ context.Context, _ *lsp.Client, uri protocol.DocumentUri, _, _ string) ([]resolvedSymbol, error) { + return []resolvedSymbol{ + {Name: "bad", Location: protocol.Location{URI: uri, Range: protocol.Range{Start: protocol.Position{Line: 1}}}}, + {Name: "good", Location: protocol.Location{URI: uri, Range: protocol.Range{Start: protocol.Position{Line: 2}}}}, + }, nil + } + + out, err := resolveQualifiedEntries(context.Background(), nil, "Foo.baz", func(sym resolvedSymbol) ([]string, error) { + if sym.Name == "bad" { + return nil, errors.New("render failed") + } + return []string{"entry:" + sym.Name}, nil + }) + assert.NoError(t, err, "a per-entry render error must not abort the fallback") + assert.Equal(t, []string{"entry:good"}, out) + }) +} + +func TestScopeMembersToContainer(t *testing.T) { + t.Run("scopes to the requested container in a multi-type file", func(t *testing.T) { + // Foo.baz at line 10, Bar.baz at line 20 — resolving Foo.baz must not + // return Bar.baz. + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("baz", protocol.Method, 10), + docSym("other", protocol.Method, 11), + ), + docSym("Bar", protocol.Class, 2, + docSym("baz", protocol.Method, 20), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + assert.Equal(t, "baz", got[0].Name) + assert.Equal(t, "Method", got[0].Kind) + assert.Equal(t, uint32(10), got[0].Location.Range.Start.Line) + }) + + t.Run("matches signature-suffixed method names", func(t *testing.T) { + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("baz(String, int)", protocol.Method, 10), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + assert.Equal(t, "baz(String, int)", got[0].Name) + }) + + t.Run("resolves members of a nested type", func(t *testing.T) { + symbols := tree( + docSym("Outer", protocol.Class, 1, + docSym("method", protocol.Method, 5), // Outer.method, must be excluded + docSym("Inner", protocol.Class, 2, + docSym("method", protocol.Method, 15), + ), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Inner", "method") + + assert.Len(t, got, 1) + assert.Equal(t, uint32(15), got[0].Location.Range.Start.Line) + }) + + t.Run("tolerates a generic container suffix", func(t *testing.T) { + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("baz", protocol.Method, 10), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + assert.Equal(t, uint32(10), got[0].Location.Range.Start.Line) + }) + + t.Run("scopes flat SymbolInformation by ContainerName", func(t *testing.T) { + symbols := flat( + symInfo("Foo", protocol.Class, "", 1), + symInfo("baz", protocol.Method, "Foo", 10), + symInfo("baz", protocol.Method, "Bar", 20), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + assert.Equal(t, uint32(10), got[0].Location.Range.Start.Line) + }) + + t.Run("scopes flat SymbolInformation with a qualified ContainerName", func(t *testing.T) { + symbols := flat( + symInfo("baz", protocol.Method, "com.example.Foo", 10), + symInfo("baz", protocol.Method, "com.example.Bar", 20), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + assert.Equal(t, uint32(10), got[0].Location.Range.Start.Line) + }) + + t.Run("falls back to document-wide match when container is absent", func(t *testing.T) { + // Container "Zed" does not exist: preserve prior behaviour rather than + // returning nothing. + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("baz", protocol.Method, 10), + ), + docSym("Bar", protocol.Class, 2, + docSym("baz", protocol.Method, 20), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Zed", "baz") + + assert.Len(t, got, 2) + }) + + t.Run("returns nothing when the container has no matching member", func(t *testing.T) { + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("other", protocol.Method, 10), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Empty(t, got) + }) + + t.Run("flat SymbolInformation keeps its own location URI", func(t *testing.T) { + // A SymbolInformation may point at a different file than the document we + // queried (generated/partial sources); resolvedMember must not overwrite + // its URI with the queried document URI. + const otherURI = protocol.DocumentUri("file:///other.java") + symbols := flat( + protocol.SymbolInformation{ + Name: "baz", + Kind: protocol.Method, + ContainerName: "Foo", + Location: protocol.Location{URI: otherURI, Range: protocol.Range{Start: protocol.Position{Line: 7}}}, + }, + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + assert.Equal(t, otherURI, got[0].Location.URI) + assert.Equal(t, uint32(7), got[0].Location.Range.Start.Line) + }) + + t.Run("collapses duplicate members at the same location", func(t *testing.T) { + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("baz", protocol.Method, 10), + docSym("baz", protocol.Method, 10), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 1) + }) + + t.Run("keeps distinct same-named members of nested same-named types", func(t *testing.T) { + // Outer Foo.baz (line 10) and a nested type also named Foo with its own + // baz (line 20) are genuinely distinct symbols and must both survive. + symbols := tree( + docSym("Foo", protocol.Class, 1, + docSym("baz", protocol.Method, 10), + docSym("Foo", protocol.Class, 2, + docSym("baz", protocol.Method, 20), + ), + ), + ) + + got := scopeMembersToContainer(symbols, testURI, "Foo", "baz") + + assert.Len(t, got, 2) + }) +} diff --git a/internal/tools/utilities.go b/internal/tools/utilities.go index e5beb285..da2d571f 100644 --- a/internal/tools/utilities.go +++ b/internal/tools/utilities.go @@ -10,8 +10,32 @@ import ( "github.com/isaacphi/mcp-language-server/internal/protocol" ) +// uriToPath converts a DocumentUri to a filesystem path. It is a non-panicking +// wrapper around DocumentUri.Path(), which panics on non-"file://" URIs (e.g. +// the "jdt://" URIs jdtls returns for symbols inside JARs). Returning an error +// lets callers degrade gracefully instead of crashing the tool. Behavior for +// ordinary file:// URIs is identical to calling Path() directly, including the +// percent-decoding that makes os.ReadFile find paths with encoded characters. +func uriToPath(uri protocol.DocumentUri) (path string, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("unsupported document URI %q: %v", uri, r) + } + }() + // An empty URI yields ("", nil): Path() maps "" to "" without panicking, + // and an absent URI is a benign no-op for callers, not an error. + return uri.Path(), nil +} + func ExtractTextFromLocation(loc protocol.Location) (string, error) { - path := strings.TrimPrefix(string(loc.URI), "file://") + // Use uriToPath rather than trimming the "file://" prefix: it + // percent-decodes the path (e.g. "%20" -> space) so os.ReadFile finds + // files whose paths contain encoded characters, consistent with the + // definition and references tools, and degrades gracefully on non-file URIs. + path, err := uriToPath(loc.URI) + if err != nil { + return "", err + } content, err := os.ReadFile(path) if err != nil { @@ -68,6 +92,30 @@ func ExtractTextFromLocation(loc protocol.Location) (string, error) { return result.String(), nil } +// extractLineText returns the full text of the given 0-indexed line in the +// file referenced by uri. It builds a single-line Location range (start and end +// on the same line) covering the whole line, which works for every line +// including the last one in the file. Returns an error if the line is out of +// range or the file cannot be read. +func extractLineText(uri protocol.DocumentUri, line uint32) (string, error) { + path, err := uriToPath(uri) + if err != nil { + return "", err + } + + content, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read file: %w", err) + } + + lines := strings.Split(string(content), "\n") + idx := int(line) + if idx < 0 || idx >= len(lines) { + return "", fmt.Errorf("line %d out of range (%d lines)", line, len(lines)) + } + return lines[idx], nil +} + func containsPosition(r protocol.Range, p protocol.Position) bool { if r.Start.Line > p.Line || r.End.Line < p.Line { return false @@ -84,8 +132,11 @@ func containsPosition(r protocol.Range, p protocol.Position) bool { // addLineNumbers adds line numbers to each line of text with proper padding, starting from startLine func addLineNumbers(text string, startLine int) string { lines := strings.Split(text, "\n") - // Calculate padding width based on the number of digits in the last line number - lastLineNum := startLine + len(lines) + // Calculate padding width based on the number of digits in the last line + // number. The last line printed is startLine+len(lines)-1 (the loop below + // prints startLine+i for i in [0, len(lines))), so subtract one here to + // avoid over-padding by a digit at power-of-ten boundaries. + lastLineNum := startLine + len(lines) - 1 padding := len(strconv.Itoa(lastLineNum)) var result strings.Builder diff --git a/internal/tools/utilities_test.go b/internal/tools/utilities_test.go index 5b4959f6..c0a2cff2 100644 --- a/internal/tools/utilities_test.go +++ b/internal/tools/utilities_test.go @@ -191,6 +191,31 @@ func TestExtractTextFromLocation_FileError(t *testing.T) { assert.Error(t, err) } +func TestExtractLineText(t *testing.T) { + dir := t.TempDir() + // No trailing newline: the final line is the last index after splitting. + path := dir + "/sample.txt" + if err := os.WriteFile(path, []byte("first\nsecond\nlast"), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + uri := protocol.DocumentUri("file://" + path) + + // Middle line. + got, err := extractLineText(uri, 1) + assert.NoError(t, err) + assert.Equal(t, "second", got) + + // Last line: this is the case that previously failed because the hover + // fallback spanned to Line+1 and tripped the bounds check. + got, err = extractLineText(uri, 2) + assert.NoError(t, err) + assert.Equal(t, "last", got) + + // Out of range. + _, err = extractLineText(uri, 3) + assert.Error(t, err) +} + func TestContainsPosition(t *testing.T) { testCases := []struct { name string @@ -330,6 +355,15 @@ func TestAddLineNumbers(t *testing.T) { startLine: 1, expected: "1|\n", }, + { + // Last printed line number is 99 (startLine 97 + 3 lines - 1), + // so padding must be 2 wide, not 3. Guards against an off-by-one + // where lastLineNum was computed as startLine+len(lines). + name: "Padding does not over-pad at power-of-ten boundary", + text: "line1\nline2\nline3", + startLine: 97, + expected: "97|line1\n98|line2\n99|line3\n", + }, } for _, tc := range testCases {