diff --git a/core/ast/function_decl.go b/core/ast/function_decl.go index de51186..61bd1e3 100644 --- a/core/ast/function_decl.go +++ b/core/ast/function_decl.go @@ -2,19 +2,218 @@ package ast import sitter "github.com/smacker/go-tree-sitter" -// TODO: Does this makes sense? +// FunctionType represents the type/category of function +type FunctionType string + +const ( + // FunctionTypeFunction represents a regular function + FunctionTypeFunction FunctionType = "function" + + // FunctionTypeMethod represents a class method + FunctionTypeMethod FunctionType = "method" + + // FunctionTypeConstructor represents a constructor function + FunctionTypeConstructor FunctionType = "constructor" + + // FunctionTypeStaticMethod represents a static class method + FunctionTypeStaticMethod FunctionType = "static_method" + + // FunctionTypeAsync represents an async function + FunctionTypeAsync FunctionType = "async" + + // FunctionTypeArrow represents an arrow function (JavaScript) + FunctionTypeArrow FunctionType = "arrow" + + // FunctionTypeUnknown represents unknown function type + FunctionTypeUnknown FunctionType = "unknown" +) + +// FunctionDeclarationNode represents a function declaration with comprehensive metadata +// This is a language agnostic representation of a function declaration. +// Not all attributes may be present in all languages type FunctionDeclarationNode struct { Node - // Name of the function - FunctionNameNode *sitter.Node + // Core function information + functionNameNode *sitter.Node + + // Function signature + functionParameterNodes []*sitter.Node + functionReturnTypeNode *sitter.Node - // Function parameters - FunctionParameterNodes []*sitter.Node + // Function body and implementation + functionBodyNode *sitter.Node + + // Language-specific decorators/annotations (Python @decorator, Java @Annotation) + decoratorNodes []*sitter.Node + + // Function metadata + functionType FunctionType + accessModifier AccessModifier + isAbstract bool + isStatic bool + isAsync bool + + // Parent class context (for methods) + parentClassName string +} + +// NewFunctionDeclarationNode creates a new FunctionDeclarationNode instance +func NewFunctionDeclarationNode(content Content) *FunctionDeclarationNode { + return &FunctionDeclarationNode{ + Node: Node{content}, + functionParameterNodes: []*sitter.Node{}, + decoratorNodes: []*sitter.Node{}, + functionType: FunctionTypeFunction, // Default to regular function + accessModifier: AccessModifierPublic, // Default to public access + isAbstract: false, + isStatic: false, + isAsync: false, + } +} + +// FunctionName returns the name of the function +func (f *FunctionDeclarationNode) FunctionName() string { + return f.contentForNode(f.functionNameNode) +} + +// Parameters returns the content of all parameter nodes +func (f *FunctionDeclarationNode) Parameters() []string { + var parameters []string + for _, paramNode := range f.functionParameterNodes { + if param := f.contentForNode(paramNode); param != "" { + parameters = append(parameters, param) + } + } + return parameters +} + +// ReturnType returns the return type of the function if available +func (f *FunctionDeclarationNode) ReturnType() string { + return f.contentForNode(f.functionReturnTypeNode) +} + +// Body returns the function body content +func (f *FunctionDeclarationNode) Body() string { + return f.contentForNode(f.functionBodyNode) +} + +// Decorators returns the content of all decorator/annotation nodes +func (f *FunctionDeclarationNode) Decorators() []string { + var decorators []string + for _, decoratorNode := range f.decoratorNodes { + if decorator := f.contentForNode(decoratorNode); decorator != "" { + decorators = append(decorators, decorator) + } + } + return decorators +} + +// HasDecorators returns true if the function has decorators/annotations +func (f *FunctionDeclarationNode) HasDecorators() bool { + return len(f.decoratorNodes) > 0 +} + +// IsMethod returns true if this function is a class method +func (f *FunctionDeclarationNode) IsMethod() bool { + return f.functionType == FunctionTypeMethod || + f.functionType == FunctionTypeConstructor || + f.functionType == FunctionTypeStaticMethod +} + +// IsConstructor returns true if this function is a constructor +func (f *FunctionDeclarationNode) IsConstructor() bool { + return f.functionType == FunctionTypeConstructor +} + +// GetFunctionType returns the type/category of the function +func (f *FunctionDeclarationNode) GetFunctionType() FunctionType { + return f.functionType +} - // Function return type - FunctionReturnTypeNode *sitter.Node +// GetAccessModifier returns the access modifier of the function +func (f *FunctionDeclarationNode) GetAccessModifier() AccessModifier { + return f.accessModifier +} + +// GetParentClassName returns the name of the parent class (for methods) +func (f *FunctionDeclarationNode) GetParentClassName() string { + return f.parentClassName +} + +// IsAbstract returns true if the function is abstract +func (f *FunctionDeclarationNode) IsAbstract() bool { + return f.isAbstract +} + +// IsStatic returns true if the function is static +func (f *FunctionDeclarationNode) IsStatic() bool { + return f.isStatic +} + +// IsAsync returns true if the function is async +func (f *FunctionDeclarationNode) IsAsync() bool { + return f.isAsync +} + +// Setter methods + +// SetFunctionNameNode sets the function name node +func (f *FunctionDeclarationNode) SetFunctionNameNode(node *sitter.Node) { + f.functionNameNode = node +} + +// SetFunctionParameterNodes sets all parameter nodes +func (f *FunctionDeclarationNode) SetFunctionParameterNodes(nodes []*sitter.Node) { + f.functionParameterNodes = nodes +} + +// AddFunctionParameterNode adds a parameter node +func (f *FunctionDeclarationNode) AddFunctionParameterNode(node *sitter.Node) { + f.functionParameterNodes = append(f.functionParameterNodes, node) +} + +// SetFunctionReturnTypeNode sets the return type node +func (f *FunctionDeclarationNode) SetFunctionReturnTypeNode(node *sitter.Node) { + f.functionReturnTypeNode = node +} + +// SetFunctionBodyNode sets the function body node +func (f *FunctionDeclarationNode) SetFunctionBodyNode(node *sitter.Node) { + f.functionBodyNode = node +} + +// AddDecoratorNode adds a decorator/annotation node +func (f *FunctionDeclarationNode) AddDecoratorNode(node *sitter.Node) { + f.decoratorNodes = append(f.decoratorNodes, node) +} + +// SetFunctionType sets the function type/category +func (f *FunctionDeclarationNode) SetFunctionType(funcType FunctionType) { + f.functionType = funcType +} + +// SetAccessModifier sets the access modifier +func (f *FunctionDeclarationNode) SetAccessModifier(modifier AccessModifier) { + f.accessModifier = modifier +} + +// SetIsAbstract sets the abstract flag +func (f *FunctionDeclarationNode) SetIsAbstract(abstract bool) { + f.isAbstract = abstract +} + +// SetIsStatic sets the static flag +func (f *FunctionDeclarationNode) SetIsStatic(static bool) { + f.isStatic = static +} + +// SetIsAsync sets the async flag +func (f *FunctionDeclarationNode) SetIsAsync(async bool) { + f.isAsync = async +} - // Function body - FunctionBodyNode *sitter.Node +// SetParentClassName sets the parent class name (for methods) +func (f *FunctionDeclarationNode) SetParentClassName(className string) { + f.parentClassName = className } diff --git a/core/ast/node.go b/core/ast/node.go index dcfad00..2ec9b0a 100644 --- a/core/ast/node.go +++ b/core/ast/node.go @@ -23,3 +23,41 @@ func (n *Node) contentForNode(node *sitter.Node) string { return node.Content(*n.content) } + +// NodePosition represents position information for a Tree-Sitter node +type NodePosition struct { + StartByte uint32 + EndByte uint32 + StartLine uint32 + EndLine uint32 + StartColumn uint32 + EndColumn uint32 +} + +// GetNodePosition extracts position information from a Tree-Sitter node +func GetNodePosition(node *sitter.Node) NodePosition { + if node == nil { + return NodePosition{} + } + + startPoint := node.StartPoint() + endPoint := node.EndPoint() + + return NodePosition{ + StartByte: node.StartByte(), + EndByte: node.EndByte(), + StartLine: startPoint.Row + 1, // Tree-sitter uses 0-based rows + EndLine: endPoint.Row + 1, // Tree-sitter uses 0-based rows + StartColumn: startPoint.Column + 1, // Tree-sitter uses 0-based columns + EndColumn: endPoint.Column + 1, // Tree-sitter uses 0-based columns + } +} + +// GetNodeType returns the type of the Tree-Sitter node +func GetNodeType(node *sitter.Node) string { + if node == nil { + return "" + } + + return node.Type() +} diff --git a/core/language.go b/core/language.go index 470689b..6c920db 100644 --- a/core/language.go +++ b/core/language.go @@ -12,6 +12,10 @@ type LanguageResolvers interface { // ResolveImports returns a list of import statements // identified from the parse tree ResolveImports(tree ParseTree) ([]*ast.ImportNode, error) + + // ResolveFunctions returns a list of function declarations + // identified from the parse tree + ResolveFunctions(tree ParseTree) ([]*ast.FunctionDeclarationNode, error) } // ObjectOrientedLanguageResolvers define the additional contract diff --git a/docs/grammar.md b/docs/grammar.md new file mode 100644 index 0000000..9a99f9d --- /dev/null +++ b/docs/grammar.md @@ -0,0 +1,9 @@ +# Grammar + +Tree-Sitter is used as the parser library. Specifically, + +- [Java](https://raw.githubusercontent.com/tree-sitter/tree-sitter-java/refs/heads/master/grammar.js) +- [JavaScript](https://raw.githubusercontent.com/tree-sitter/tree-sitter-javascript/refs/heads/master/grammar.js) +- [TypeScript](https://raw.githubusercontent.com/tree-sitter/tree-sitter-typescript/refs/heads/master/common/define-grammar.js) +- [Python](https://raw.githubusercontent.com/tree-sitter/tree-sitter-python/refs/heads/master/grammar.js) +- [Go](https://raw.githubusercontent.com/tree-sitter/tree-sitter-go/refs/heads/master/grammar.js) diff --git a/lang/fixtures/Functions.java b/lang/fixtures/Functions.java new file mode 100644 index 0000000..e2d9ee3 --- /dev/null +++ b/lang/fixtures/Functions.java @@ -0,0 +1,44 @@ +package lang.fixtures; + +import java.util.List; + +public class MyClassWithFunctions { + + private String field; + + // Constructor + public MyClassWithFunctions(String field) { + this.field = field; + } + + // Public method + public String publicMethod(int i) { + return "public"; + } + + // Protected method + protected void protectedMethod() { + } + + // Private method + private boolean privateMethod(String s) { + return s.isEmpty(); + } + + // Static method + public static void staticMethod() { + } + + // Method with annotation + @Override + public String toString() { + return field; + } +} + +// A simple function (static method in a container class in Java) +class TestFunctions { + public static int add(int a, int b) { + return a + b; + } +} diff --git a/lang/fixtures/functions.go b/lang/fixtures/functions.go new file mode 100644 index 0000000..af61fca --- /dev/null +++ b/lang/fixtures/functions.go @@ -0,0 +1,77 @@ +//go:build exclude + +package fixtures + +import "fmt" + +// A simple function +func simpleFunction() { + fmt.Println("hello") +} + +// A function with parameters and return value +func functionWithArgs(a int, b string) string { + return fmt.Sprintf("%s: %d", b, a) +} + +// A struct for methods +type MyStruct struct { + val int +} + +// A method on MyStruct +func (s *MyStruct) MyMethod(p int) int { + return s.val + p +} + +// A private function (package-level) +func privateFunction() { + // not exported +} + +// Additional test cases for more comprehensive testing + +// An exported function starting with uppercase +func ExportedFunction() { + // exported (public) +} + +// A private function starting with lowercase +func unexportedFunction() { + // unexported (package-private) +} + +// A function with underscore (should be package-private) +func _underscoreFunction() { + // unexported (package-private) +} + +// Another struct for more method testing +type myPrivateStruct struct { + data string +} + +// An exported method on private struct +func (m *myPrivateStruct) ExportedMethod() string { + return m.data +} + +// An unexported method on private struct +func (m *myPrivateStruct) unexportedMethod() string { + return m.data +} + +// A public struct +type PublicStruct struct { + Value int +} + +// An exported method on public struct +func (p *PublicStruct) PublicMethod() int { + return p.Value +} + +// An unexported method on public struct +func (p *PublicStruct) privateMethod() int { + return p.Value * 2 +} diff --git a/lang/fixtures/functions.js b/lang/fixtures/functions.js new file mode 100644 index 0000000..89e9c5d --- /dev/null +++ b/lang/fixtures/functions.js @@ -0,0 +1,50 @@ +// A simple function declaration +function declaredFunction(a, b) { + return a + b; +} + +// A function expression +const expressionFunction = function(x) { + return x * x; +}; + +// An arrow function +const arrowFunction = (y) => { + return y / 2; +}; + +// An async function +async function asyncFunction() { + return Promise.resolve('done'); +} + +// A class with methods +class MyClass { + constructor(name) { + this.name = name; + } + + myMethod(value) { + return `${this.name}: ${value}`; + } + + static staticMethod() { + return "static"; + } + + get myProperty() { + return this.name; + } +} + +// A decorated method (assuming decorators are enabled) +function myDecorator(target, key, descriptor) { + // no-op +} + +class ClassWithDecorator { + @myDecorator + decoratedMethod() { + return 'decorated'; + } +} diff --git a/lang/fixtures/functions.py b/lang/fixtures/functions.py new file mode 100644 index 0000000..12030d6 --- /dev/null +++ b/lang/fixtures/functions.py @@ -0,0 +1,46 @@ +import asyncio + +# A simple function +def simple_function(): + pass + +# A function with arguments and type hints +def function_with_args(a: int, b: str) -> str: + return f"{b}: {a}" + +# An async function +async def my_async_function(): + await asyncio.sleep(1) + return "done" + +# A class with methods +class MyClass: + def __init__(self, name): + self.name = name + + def instance_method(self, value): + return f"{self.name}: {value}" + + @staticmethod + def static_method(): + return "static" + + @classmethod + def class_method(cls): + return "class_method" + +# A decorated function +def my_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + +@my_decorator +def decorated_function(): + return "decorated" + +# A nested function +def outer_function(): + def inner_function(): + pass + return inner_function diff --git a/lang/go_resolvers.go b/lang/go_resolvers.go index 535639a..f50331d 100644 --- a/lang/go_resolvers.go +++ b/lang/go_resolvers.go @@ -16,16 +16,16 @@ type goResolvers struct { var _ core.LanguageResolvers = (*goResolvers)(nil) const goWholeModuleImportQuery = ` - (import_declaration - (import_spec + (import_declaration + (import_spec name: (package_identifier)? @module_alias name: (blank_identifier)? @blank_identifier name: (dot)? @dot_identifier path: (interpreted_string_literal) @module_name)) - (import_declaration - (import_spec_list - (import_spec + (import_declaration + (import_spec_list + (import_spec name: (package_identifier)? @module_alias name: (blank_identifier)? @blank_identifier name: (dot)? @dot_identifier @@ -77,3 +77,302 @@ func (r *goResolvers) ResolveImports(tree core.ParseTree) ([]*ast.ImportNode, er return imports, err } + +// Tree-Sitter queries for Go function definitions based on actual grammar +const goFunctionDefinitionQuery = ` + (function_declaration + name: (identifier) @function_name + parameters: (parameter_list) @function_params + body: (block) @function_body) +` + +// Based on Go grammar: method_declaration has receiver field with parameter_list and name field with field_identifier +const goMethodDefinitionQuery = ` + (method_declaration + receiver: (parameter_list) @receiver + name: (field_identifier) @method_name + parameters: (parameter_list) @method_params + body: (block) @method_body) +` + +// ResolveFunctions extracts function declarations from Go parse tree +func (r *goResolvers) ResolveFunctions(tree core.ParseTree) ([]*ast.FunctionDeclarationNode, error) { + data, err := tree.Data() + if err != nil { + return nil, fmt.Errorf("failed to get data from parse tree: %w", err) + } + + var functions []*ast.FunctionDeclarationNode + functionMap := make(map[string]*ast.FunctionDeclarationNode) // To avoid duplicates + + // Extract regular function declarations + err = r.extractGoFunctions(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract Go functions: %w", err) + } + + // Extract method declarations + err = r.extractGoMethods(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract Go methods: %w", err) + } + + // Convert map to slice + for _, function := range functionMap { + functions = append(functions, function) + } + + return functions, nil +} + +// Helper methods for Go function extraction + +func (r *goResolvers) extractGoFunctions(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(goFunctionDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil + } + + var functionNameNode, paramsNode, bodyNode *sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "identifier": + functionNameNode = capture.Node + case "parameter_list": + paramsNode = capture.Node + case "block": + bodyNode = capture.Node + } + } + + if functionNameNode == nil { + return nil + } + + // Validate function name + functionName := functionNameNode.Content(*data) + if !r.isValidGoIdentifier(functionName) { + return nil // Skip invalid identifiers + } + + functionKey := r.generateGoFunctionKey(functionNameNode, "", *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(functionNameNode) + functionNode.SetFunctionType(ast.FunctionTypeFunction) + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractGoParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set function body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + // Go access modifiers: Public if starts with uppercase, Package if lowercase + if r.isExportedGoIdentifier(functionName) { + functionNode.SetAccessModifier(ast.AccessModifierPublic) + } else { + functionNode.SetAccessModifier(ast.AccessModifierPackage) + } + + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *goResolvers) extractGoMethods(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(goMethodDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 4 { + return nil + } + + var receiverNode, methodNameNode, paramsNode, bodyNode *sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "parameter_list": + if receiverNode == nil { + receiverNode = capture.Node + } else { + paramsNode = capture.Node + } + case "field_identifier": + methodNameNode = capture.Node + case "block": + bodyNode = capture.Node + } + } + + if methodNameNode == nil || receiverNode == nil { + return nil + } + + // Validate method name + methodName := methodNameNode.Content(*data) + if !r.isValidGoIdentifier(methodName) { + return nil // Skip invalid identifiers + } + + // Extract receiver type name from the receiver parameter list + receiverTypeName := r.extractReceiverTypeName(receiverNode, *data) + if receiverTypeName == "" { + return nil + } + + // Validate receiver type name + if !r.isValidGoIdentifier(receiverTypeName) { + return nil // Skip invalid receiver types + } + + functionKey := r.generateGoFunctionKey(methodNameNode, receiverTypeName, *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(methodNameNode) + functionNode.SetFunctionType(ast.FunctionTypeMethod) + functionNode.SetParentClassName(receiverTypeName) + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractGoParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set method body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + // Go methods are public if they start with uppercase + if r.isExportedGoIdentifier(methodName) { + functionNode.SetAccessModifier(ast.AccessModifierPublic) + } else { + functionNode.SetAccessModifier(ast.AccessModifierPackage) + } + + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +// Helper methods for Go function processing + +func (r *goResolvers) extractGoParameterNodes(parametersNode *sitter.Node) []*sitter.Node { + var paramNodes []*sitter.Node + + if parametersNode == nil { + return paramNodes + } + + for i := 0; i < int(parametersNode.ChildCount()); i++ { + child := parametersNode.Child(i) + if child.Type() == "parameter_declaration" || child.Type() == "variadic_parameter_declaration" { + paramNodes = append(paramNodes, child) + } + } + + return paramNodes +} + +func (r *goResolvers) extractReceiverTypeName(receiverNode *sitter.Node, data []byte) string { + if receiverNode == nil { + return "" + } + + // Look for type identifier in the receiver parameter list + for i := 0; i < int(receiverNode.ChildCount()); i++ { + child := receiverNode.Child(i) + if child.Type() == "parameter_declaration" { + // Find the type part of the parameter declaration + for j := 0; j < int(child.ChildCount()); j++ { + grandchild := child.Child(j) + if grandchild.Type() == "type_identifier" || + grandchild.Type() == "pointer_type" { + return r.extractTypeNameFromNode(grandchild, data) + } + } + } + } + + return "" +} + +func (r *goResolvers) extractTypeNameFromNode(typeNode *sitter.Node, data []byte) string { + if typeNode == nil { + return "" + } + + if typeNode.Type() == "type_identifier" { + return typeNode.Content(data) + } else if typeNode.Type() == "pointer_type" { + // For pointer types, get the underlying type + for i := 0; i < int(typeNode.ChildCount()); i++ { + child := typeNode.Child(i) + if child.Type() == "type_identifier" { + return child.Content(data) + } + } + } + + return "" +} + +func (r *goResolvers) generateGoFunctionKey(functionNameNode *sitter.Node, receiverType string, data []byte) string { + functionName := functionNameNode.Content(data) + + if receiverType != "" { + return receiverType + "." + functionName + } + + // Add line number to distinguish functions with same name in different scopes + lineNumber := functionNameNode.StartPoint().Row + return fmt.Sprintf("%s:%d", functionName, lineNumber) +} + +// isExportedGoIdentifier checks if a Go identifier is exported (public) based on Go naming conventions. +// In Go, an identifier is exported if it starts with an uppercase letter. +func (r *goResolvers) isExportedGoIdentifier(identifier string) bool { + if len(identifier) == 0 { + return false + } + + // Check if the first character is an uppercase letter + firstChar := identifier[0] + return firstChar >= 'A' && firstChar <= 'Z' +} + +// isValidGoIdentifier validates if a string is a valid Go identifier. +// A valid Go identifier starts with a letter or underscore, followed by letters, digits, or underscores. +func (r *goResolvers) isValidGoIdentifier(identifier string) bool { + if len(identifier) == 0 { + return false + } + + // First character must be a letter or underscore + firstChar := identifier[0] + if (firstChar < 'a' || firstChar > 'z') && (firstChar < 'A' || firstChar > 'Z') && firstChar != '_' { + return false + } + + // Remaining characters must be letters, digits, or underscores + for i := 1; i < len(identifier); i++ { + char := identifier[i] + if (char < 'a' || char > 'z') && (char < 'A' || char > 'Z') && (char < '0' || char > '9') && char != '_' { + return false + } + } + + return true +} diff --git a/lang/go_resolvers_test.go b/lang/go_resolvers_test.go index e11d2e8..4c545d6 100644 --- a/lang/go_resolvers_test.go +++ b/lang/go_resolvers_test.go @@ -2,6 +2,7 @@ package lang_test import ( "context" + "fmt" "testing" "github.com/safedep/code/core" @@ -30,6 +31,22 @@ var goImportExpectations = []ImportExpectations{ }, } +var goFunctionExpectations = map[string][]string{ + "fixtures/functions.go": { + "FunctionDeclarationNode{Name: simpleFunction, Type: function, Access: package, ParentClass: }", + "FunctionDeclarationNode{Name: functionWithArgs, Type: function, Access: package, ParentClass: }", + "FunctionDeclarationNode{Name: MyMethod, Type: method, Access: public, ParentClass: MyStruct}", + "FunctionDeclarationNode{Name: privateFunction, Type: function, Access: package, ParentClass: }", + "FunctionDeclarationNode{Name: ExportedFunction, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: unexportedFunction, Type: function, Access: package, ParentClass: }", + "FunctionDeclarationNode{Name: _underscoreFunction, Type: function, Access: package, ParentClass: }", + "FunctionDeclarationNode{Name: ExportedMethod, Type: method, Access: public, ParentClass: myPrivateStruct}", + "FunctionDeclarationNode{Name: unexportedMethod, Type: method, Access: package, ParentClass: myPrivateStruct}", + "FunctionDeclarationNode{Name: PublicMethod, Type: method, Access: public, ParentClass: PublicStruct}", + "FunctionDeclarationNode{Name: privateMethod, Type: method, Access: package, ParentClass: PublicStruct}", + }, +} + func TestGoLanguageResolvers(t *testing.T) { t.Run("ResolversExists", func(t *testing.T) { l, err := lang.NewGoLanguage() @@ -39,9 +56,6 @@ func TestGoLanguageResolvers(t *testing.T) { }) t.Run("ResolveImports", func(t *testing.T) { - l, err := lang.NewGoLanguage() - assert.NoError(t, err) - importExpectationsMapper := make(map[string][]string) importFilePaths := []string{} for _, ie := range goImportExpectations { @@ -64,7 +78,7 @@ func TestGoLanguageResolvers(t *testing.T) { parseTree, err := fileParser.Parse(context.Background(), f) assert.NoError(t, err) - imports, err := l.Resolvers().ResolveImports(parseTree) + imports, err := goLanguage.Resolvers().ResolveImports(parseTree) assert.NoError(t, err) expectedImports, ok := importExpectationsMapper[f.Name()] @@ -79,4 +93,45 @@ func TestGoLanguageResolvers(t *testing.T) { }) assert.NoError(t, err) }) + + t.Run("ResolveFunctions", func(t *testing.T) { + var filePaths []string + for path := range goFunctionExpectations { + filePaths = append(filePaths, path) + } + + goLanguage, err := lang.NewGoLanguage() + assert.NoError(t, err) + + fileParser, err := parser.NewParser([]core.Language{goLanguage}) + assert.NoError(t, err) + + fileSystem, err := fs.NewLocalFileSystem(fs.LocalFileSystemConfig{ + AppDirectories: filePaths, + }) + assert.NoError(t, err) + + err = fileSystem.EnumerateApp(context.Background(), func(f core.File) error { + parseTree, err := fileParser.Parse(context.Background(), f) + assert.NoError(t, err) + + functions, err := goLanguage.Resolvers().ResolveFunctions(parseTree) + assert.NoError(t, err) + + expectedFunctions, ok := goFunctionExpectations[f.Name()] + assert.True(t, ok) + + var foundFunctions []string + for _, fun := range functions { + foundFunctions = append(foundFunctions, + fmt.Sprintf("FunctionDeclarationNode{Name: %s, Type: %s, Access: %s, ParentClass: %s}", + fun.FunctionName(), fun.GetFunctionType(), fun.GetAccessModifier(), fun.GetParentClassName())) + } + + assert.ElementsMatch(t, expectedFunctions, foundFunctions) + + return nil + }) + assert.NoError(t, err) + }) } diff --git a/lang/java_resolvers.go b/lang/java_resolvers.go index 00b340c..62c5808 100644 --- a/lang/java_resolvers.go +++ b/lang/java_resolvers.go @@ -79,7 +79,6 @@ const javaClassFieldQuery = ` name: (identifier) @field_name)) @field_def)) ` - const javaInheritanceQuery = ` (class_declaration name: (identifier) @class_name @@ -232,14 +231,15 @@ func (r *javaResolvers) ResolveInheritance(tree core.ParseTree) (*ast.Inheritanc // Helper methods for Java class extraction -func (r *javaResolvers) extractClassDefinitions(data *[]byte, tree core.ParseTree, classes *[]*ast.ClassDeclarationNode, classMap map[string]*ast.ClassDeclarationNode) error { +func (r *javaResolvers) extractClassDefinitions(data *[]byte, tree core.ParseTree, + classes *[]*ast.ClassDeclarationNode, classMap map[string]*ast.ClassDeclarationNode) error { queryRequestItems := []ts.QueryItem{ ts.NewQueryItem(javaClassDefinitionQuery, func(m *sitter.QueryMatch) error { classNode := ast.NewClassDeclarationNode(ast.ToContent(*data)) - + var className string var modifiersNode *sitter.Node - + for _, capture := range m.Captures { switch capture.Node.Type() { case "identifier": @@ -255,16 +255,17 @@ func (r *javaResolvers) extractClassDefinitions(data *[]byte, tree core.ParseTre modifiersNode = capture.Node } } - + if className != "" { // Check for abstract modifier and extract annotations if modifiersNode != nil { if r.hasAbstractModifier(modifiersNode, *data) { classNode.SetIsAbstract(true) } + r.extractAnnotationsFromModifiers(modifiersNode, classNode) } - + classNode.SetAccessModifier(r.extractAccessModifier(m)) *classes = append(*classes, classNode) classMap[className] = classNode @@ -272,14 +273,14 @@ func (r *javaResolvers) extractClassDefinitions(data *[]byte, tree core.ParseTre return nil }), - + // Also handle interfaces as classes (Java interfaces are class-like) ts.NewQueryItem(javaInterfaceDefinitionQuery, func(m *sitter.QueryMatch) error { classNode := ast.NewClassDeclarationNode(ast.ToContent(*data)) classNode.SetIsAbstract(true) // Interfaces are abstract by nature - + var className string - + for _, capture := range m.Captures { switch capture.Node.Type() { case "identifier": @@ -292,7 +293,7 @@ func (r *javaResolvers) extractClassDefinitions(data *[]byte, tree core.ParseTre classNode.AddBaseClassNode(capture.Node) } } - + if className != "" { classNode.SetAccessModifier(ast.AccessModifierPublic) // Interfaces are public by default *classes = append(*classes, classNode) @@ -310,7 +311,7 @@ func (r *javaResolvers) extractClassMethods(data *[]byte, tree core.ParseTree, c queryRequestItems := []ts.QueryItem{ ts.NewQueryItem(javaClassMethodQuery, func(m *sitter.QueryMatch) error { var methodNameNode, methodDefNode *sitter.Node - + for _, capture := range m.Captures { if capture.Node.Type() == "identifier" { methodNameNode = capture.Node @@ -318,7 +319,7 @@ func (r *javaResolvers) extractClassMethods(data *[]byte, tree core.ParseTree, c methodDefNode = capture.Node } } - + if methodNameNode == nil || methodDefNode == nil { return nil } @@ -335,12 +336,12 @@ func (r *javaResolvers) extractClassMethods(data *[]byte, tree core.ParseTree, c return nil }), - + // Handle constructors separately ts.NewQueryItem(javaClassConstructorQuery, func(m *sitter.QueryMatch) error { var constructorDefNode *sitter.Node var className string - + for _, capture := range m.Captures { if capture.Node.Type() == "identifier" { className = capture.Node.Content(*data) @@ -348,7 +349,7 @@ func (r *javaResolvers) extractClassMethods(data *[]byte, tree core.ParseTree, c constructorDefNode = capture.Node } } - + if className != "" && constructorDefNode != nil { if classNode, exists := classMap[className]; exists { classNode.SetConstructorNode(constructorDefNode) @@ -366,14 +367,14 @@ func (r *javaResolvers) extractClassFields(data *[]byte, tree core.ParseTree, cl queryRequestItems := []ts.QueryItem{ ts.NewQueryItem(javaClassFieldQuery, func(m *sitter.QueryMatch) error { var fieldDefNode *sitter.Node - + for _, capture := range m.Captures { if capture.Node.Type() == "field_declaration" { fieldDefNode = capture.Node break } } - + if fieldDefNode == nil { return nil } @@ -400,7 +401,7 @@ func (r *javaResolvers) extractClassAnnotations(data *[]byte, tree core.ParseTre ts.NewQueryItem(javaClassAnnotationQuery, func(m *sitter.QueryMatch) error { var annotationNode *sitter.Node var className string - + for _, capture := range m.Captures { if capture.Node.Type() == "annotation" || capture.Node.Type() == "marker_annotation" { annotationNode = capture.Node @@ -408,7 +409,7 @@ func (r *javaResolvers) extractClassAnnotations(data *[]byte, tree core.ParseTre className = capture.Node.Content(*data) } } - + if className != "" && annotationNode != nil { if classNode, exists := classMap[className]; exists { classNode.AddDecoratorNode(annotationNode) @@ -438,18 +439,17 @@ func (r *javaResolvers) findParentClassName(node *sitter.Node, data []byte) stri return "" } -func (r *javaResolvers) extractAccessModifier(m *sitter.QueryMatch) ast.AccessModifier { +func (r *javaResolvers) extractAccessModifier(_ *sitter.QueryMatch) ast.AccessModifier { // In Java, we need to look for modifiers in the parent nodes // For now, default to public (can be enhanced later) return ast.AccessModifierPublic } - func (r *javaResolvers) hasAbstractModifier(modifiersNode *sitter.Node, data []byte) bool { if modifiersNode == nil { return false } - + // Traverse child nodes looking for "abstract" keyword for i := 0; i < int(modifiersNode.ChildCount()); i++ { child := modifiersNode.Child(i) @@ -464,7 +464,7 @@ func (r *javaResolvers) extractAnnotationsFromModifiers(modifiersNode *sitter.No if modifiersNode == nil { return } - + // Traverse child nodes looking for annotations for i := 0; i < int(modifiersNode.ChildCount()); i++ { child := modifiersNode.Child(i) @@ -475,3 +475,308 @@ func (r *javaResolvers) extractAnnotationsFromModifiers(modifiersNode *sitter.No } } } + +// Tree-Sitter queries for Java function definitions +const javaFunctionDefinitionQuery = ` + (method_declaration + (modifiers)? @modifiers + type: (_) @return_type + name: (identifier) @function_name + parameters: (formal_parameters) @function_params + body: (block) @function_body) +` + +const javaConstructorDefinitionQuery = ` + (constructor_declaration + (modifiers)? @modifiers + name: (identifier) @constructor_name + parameters: (formal_parameters) @constructor_params + body: (constructor_body) @constructor_body) +` + +const javaFunctionAnnotationQuery = ` + (method_declaration + (modifiers + (annotation) @annotation) + name: (identifier) @function_name) +` + +// ResolveFunctions extracts function declarations from Java parse tree +func (r *javaResolvers) ResolveFunctions(tree core.ParseTree) ([]*ast.FunctionDeclarationNode, error) { + data, err := tree.Data() + if err != nil { + return nil, fmt.Errorf("failed to get data from parse tree: %w", err) + } + + var functions []*ast.FunctionDeclarationNode + functionMap := make(map[string]*ast.FunctionDeclarationNode) // To avoid duplicates and allow enhancement + + // Extract method declarations + err = r.extractJavaFunctions(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract Java functions: %w", err) + } + + // Extract constructor declarations + err = r.extractJavaConstructors(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract Java constructors: %w", err) + } + + // Extract function annotations + err = r.extractJavaFunctionAnnotations(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract Java function annotations: %w", err) + } + + // Convert map to slice + for _, function := range functionMap { + functions = append(functions, function) + } + + return functions, nil +} + +// Helper methods for Java function extraction + +func (r *javaResolvers) extractJavaFunctions(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(javaFunctionDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 4 { + return nil // Skip incomplete matches + } + + var functionNameNode, returnTypeNode, paramsNode, bodyNode, modifiersNode *sitter.Node + + // Parse captures + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "identifier": + functionNameNode = capture.Node + case "formal_parameters": + paramsNode = capture.Node + case "block": + bodyNode = capture.Node + case "modifiers": + modifiersNode = capture.Node + default: + // Return type can be various types (type_identifier, primitive_type, etc.) + if returnTypeNode == nil { + returnTypeNode = capture.Node + } + } + } + + if functionNameNode == nil { + return nil + } + + functionKey := r.generateJavaFunctionKey(functionNameNode, *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(functionNameNode) + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractJavaParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set return type + if returnTypeNode != nil { + functionNode.SetFunctionReturnTypeNode(returnTypeNode) + } + + // Set function body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + // Determine function type and access modifier + parentClassName := r.findParentClassName(functionNameNode, *data) + if parentClassName != "" { + functionNode.SetParentClassName(parentClassName) + functionNode.SetFunctionType(ast.FunctionTypeMethod) + } else { + functionNode.SetFunctionType(ast.FunctionTypeFunction) + } + + // Process modifiers + if modifiersNode != nil { + r.processJavaModifiers(modifiersNode, functionNode, *data) + } else { + functionNode.SetAccessModifier(ast.AccessModifierPackage) // Default in Java + } + + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *javaResolvers) extractJavaConstructors(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(javaConstructorDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil + } + + var constructorNameNode, paramsNode, bodyNode, modifiersNode *sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "identifier": + constructorNameNode = capture.Node + case "formal_parameters": + paramsNode = capture.Node + case "constructor_body": + bodyNode = capture.Node + case "modifiers": + modifiersNode = capture.Node + } + } + + if constructorNameNode == nil { + return nil + } + + functionKey := r.generateJavaFunctionKey(constructorNameNode, *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(constructorNameNode) + functionNode.SetFunctionType(ast.FunctionTypeConstructor) + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractJavaParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set constructor body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + // Set parent class + parentClassName := r.findParentClassName(constructorNameNode, *data) + if parentClassName != "" { + functionNode.SetParentClassName(parentClassName) + } + + // Process modifiers + if modifiersNode != nil { + r.processJavaModifiers(modifiersNode, functionNode, *data) + } else { + functionNode.SetAccessModifier(ast.AccessModifierPackage) // Default in Java + } + + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *javaResolvers) extractJavaFunctionAnnotations(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(javaFunctionAnnotationQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 2 { + return nil + } + + var annotationNode, functionNameNode *sitter.Node + for _, capture := range m.Captures { + if capture.Node.Type() == "annotation" { + annotationNode = capture.Node + } else if capture.Node.Type() == "identifier" { + functionNameNode = capture.Node + } + } + + if annotationNode == nil || functionNameNode == nil { + return nil + } + + functionKey := r.generateJavaFunctionKey(functionNameNode, *data) + if functionNode, exists := functionMap[functionKey]; exists { + functionNode.AddDecoratorNode(annotationNode) + } + + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +// Helper methods for Java function processing + +func (r *javaResolvers) extractJavaParameterNodes(parametersNode *sitter.Node) []*sitter.Node { + var paramNodes []*sitter.Node + + if parametersNode == nil { + return paramNodes + } + + for i := 0; i < int(parametersNode.ChildCount()); i++ { + child := parametersNode.Child(i) + if child.Type() == "formal_parameter" || child.Type() == "spread_parameter" { + paramNodes = append(paramNodes, child) + } + } + + return paramNodes +} + +// Note: We are not handling the modifier based on content currently. +func (r *javaResolvers) processJavaModifiers(modifiersNode *sitter.Node, + functionNode *ast.FunctionDeclarationNode, _ []byte) { + if modifiersNode == nil { + return + } + + // Default access modifier + accessModifier := ast.AccessModifierPackage + + for i := 0; i < int(modifiersNode.ChildCount()); i++ { + child := modifiersNode.Child(i) + if child != nil { + switch child.Type() { + case "public": + accessModifier = ast.AccessModifierPublic + case "private": + accessModifier = ast.AccessModifierPrivate + case "protected": + accessModifier = ast.AccessModifierProtected + case "static": + functionNode.SetIsStatic(true) + if functionNode.GetFunctionType() == ast.FunctionTypeMethod { + functionNode.SetFunctionType(ast.FunctionTypeStaticMethod) + } + case "abstract": + functionNode.SetIsAbstract(true) + case "annotation": + functionNode.AddDecoratorNode(child) + } + } + } + + functionNode.SetAccessModifier(accessModifier) +} + +func (r *javaResolvers) generateJavaFunctionKey(functionNameNode *sitter.Node, data []byte) string { + functionName := functionNameNode.Content(data) + parentClassName := r.findParentClassName(functionNameNode, data) + + if parentClassName != "" { + return parentClassName + "." + functionName + } + + // Add line number to distinguish functions with same name in different scopes + lineNumber := functionNameNode.StartPoint().Row + return fmt.Sprintf("%s:%d", functionName, lineNumber) +} diff --git a/lang/java_resolvers_test.go b/lang/java_resolvers_test.go index 5c16fe4..c65cd43 100644 --- a/lang/java_resolvers_test.go +++ b/lang/java_resolvers_test.go @@ -2,6 +2,7 @@ package lang_test import ( "context" + "fmt" "testing" "github.com/safedep/code/core" @@ -25,6 +26,18 @@ var javaImportExpectations = []ImportExpectations{ }, } +var javaFunctionExpectations = map[string][]string{ + "fixtures/Functions.java": { + "FunctionDeclarationNode{Name: MyClassWithFunctions, Type: constructor, Access: public, ParentClass: MyClassWithFunctions}", + "FunctionDeclarationNode{Name: publicMethod, Type: method, Access: public, ParentClass: MyClassWithFunctions}", + "FunctionDeclarationNode{Name: protectedMethod, Type: method, Access: protected, ParentClass: MyClassWithFunctions}", + "FunctionDeclarationNode{Name: privateMethod, Type: method, Access: private, ParentClass: MyClassWithFunctions}", + "FunctionDeclarationNode{Name: staticMethod, Type: static_method, Access: public, ParentClass: MyClassWithFunctions}", + "FunctionDeclarationNode{Name: toString, Type: method, Access: public, ParentClass: MyClassWithFunctions}", + "FunctionDeclarationNode{Name: add, Type: static_method, Access: public, ParentClass: TestFunctions}", + }, +} + func TestJavaLanguageResolvers(t *testing.T) { t.Run("ResolversExists", func(t *testing.T) { l, err := lang.NewJavaLanguage() @@ -34,9 +47,6 @@ func TestJavaLanguageResolvers(t *testing.T) { }) t.Run("ResolveImports", func(t *testing.T) { - l, err := lang.NewJavaLanguage() - assert.NoError(t, err) - importExpectationsMapper := make(map[string][]string) importFilePaths := []string{} for _, ie := range javaImportExpectations { @@ -59,7 +69,7 @@ func TestJavaLanguageResolvers(t *testing.T) { parseTree, err := fileParser.Parse(context.Background(), f) assert.NoError(t, err) - imports, err := l.Resolvers().ResolveImports(parseTree) + imports, err := javaLanguage.Resolvers().ResolveImports(parseTree) assert.NoError(t, err) expectedImports, ok := importExpectationsMapper[f.Name()] @@ -74,4 +84,45 @@ func TestJavaLanguageResolvers(t *testing.T) { }) assert.NoError(t, err) }) + + t.Run("ResolveFunctions", func(t *testing.T) { + var filePaths []string + for path := range javaFunctionExpectations { + filePaths = append(filePaths, path) + } + + javaLanguage, err := lang.NewJavaLanguage() + assert.NoError(t, err) + + fileParser, err := parser.NewParser([]core.Language{javaLanguage}) + assert.NoError(t, err) + + fileSystem, err := fs.NewLocalFileSystem(fs.LocalFileSystemConfig{ + AppDirectories: filePaths, + }) + assert.NoError(t, err) + + err = fileSystem.EnumerateApp(context.Background(), func(f core.File) error { + parseTree, err := fileParser.Parse(context.Background(), f) + assert.NoError(t, err) + + functions, err := javaLanguage.Resolvers().ResolveFunctions(parseTree) + assert.NoError(t, err) + + expectedFunctions, ok := javaFunctionExpectations[f.Name()] + assert.True(t, ok) + + var foundFunctions []string + for _, fun := range functions { + foundFunctions = append(foundFunctions, + fmt.Sprintf("FunctionDeclarationNode{Name: %s, Type: %s, Access: %s, ParentClass: %s}", + fun.FunctionName(), fun.GetFunctionType(), fun.GetAccessModifier(), fun.GetParentClassName())) + } + + assert.ElementsMatch(t, expectedFunctions, foundFunctions) + + return nil + }) + assert.NoError(t, err) + }) } diff --git a/lang/javascript_resolvers.go b/lang/javascript_resolvers.go index 1c360f6..8053a63 100644 --- a/lang/javascript_resolvers.go +++ b/lang/javascript_resolvers.go @@ -26,7 +26,7 @@ const jsWholeModuleImportQuery = ` (import_clause (namespace_import (identifier) @module_alias)) source: (string (string_fragment) @module_name)) - + ; const xyz = await import('xyz) (lexical_declaration (variable_declarator @@ -44,7 +44,7 @@ const jsRequireModuleQuery = ` value: (call_expression function: (identifier) @require_function arguments: (arguments (string (string_fragment) @module_name))))) - + (lexical_declaration (variable_declarator name: (object_pattern @@ -67,9 +67,9 @@ const jsRequireModuleQuery = ` const jsSpecifiedItemImportQuery = ` (import_statement (import_clause - (named_imports - (import_specifier - name: (identifier) @module_item + (named_imports + (import_specifier + name: (identifier) @module_item alias: (identifier)? @module_alias))) source: (string (string_fragment) @module_name)) ` @@ -153,3 +153,467 @@ func (r *javascriptResolvers) ResolveImports(tree core.ParseTree) ([]*ast.Import return imports, err } + +// Tree-Sitter queries for JavaScript function definitions based on actual grammar +const jsFunctionDefinitionQuery = ` + (function_declaration + name: (identifier) @function_name + parameters: (formal_parameters) @function_params + body: (statement_block) @function_body) +` + +const jsArrowFunctionQuery = ` + (variable_declarator + name: (identifier) @function_name + value: (arrow_function + parameters: (_) @function_params + body: (_) @function_body)) + + (assignment_expression + left: (identifier) @function_name + right: (arrow_function + parameters: (_) @function_params + body: (_) @function_body)) +` + +const jsMethodDefinitionQuery = ` + (method_definition + name: (property_identifier) @method_name + parameters: (formal_parameters) @method_params + body: (statement_block) @method_body) + + (method_definition + (decorator)* @decorator + name: (property_identifier) @method_name + parameters: (formal_parameters) @method_params + body: (statement_block) @method_body) +` + +const jsFunctionExpressionQuery = ` + (function_expression + name: (identifier)? @function_name + parameters: (formal_parameters) @function_params + body: (statement_block) @function_body) +` + +// ResolveFunctions extracts function declarations from JavaScript parse tree +func (r *javascriptResolvers) ResolveFunctions(tree core.ParseTree) ([]*ast.FunctionDeclarationNode, error) { + data, err := tree.Data() + if err != nil { + return nil, fmt.Errorf("failed to get data from parse tree: %w", err) + } + + var functions []*ast.FunctionDeclarationNode + functionMap := make(map[string]*ast.FunctionDeclarationNode) // To avoid duplicates + + // Extract regular function declarations + err = r.extractJSFunctions(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract JavaScript functions: %w", err) + } + + // Extract arrow functions + err = r.extractJSArrowFunctions(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract JavaScript arrow functions: %w", err) + } + + // Extract method definitions (class methods) + err = r.extractJSMethods(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract JavaScript methods: %w", err) + } + + // Extract function expressions + err = r.extractJSFunctionExpressions(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract JavaScript function expressions: %w", err) + } + + // Convert map to slice + for _, function := range functionMap { + functions = append(functions, function) + } + + return functions, nil +} + +// Helper methods for JavaScript function extraction + +func (r *javascriptResolvers) extractJSFunctions(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(jsFunctionDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil + } + + var functionNameNode, paramsNode, bodyNode *sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "identifier": + functionNameNode = capture.Node + case "formal_parameters": + paramsNode = capture.Node + case "statement_block": + bodyNode = capture.Node + } + } + + if functionNameNode == nil { + return nil + } + + // Check for async function by looking at the function_declaration parent for async keyword + isAsync := false + current := functionNameNode.Parent() + if current != nil && current.Type() == "function_declaration" { + // Check if any child node is "async" + for i := 0; i < int(current.ChildCount()); i++ { + child := current.Child(i) + if child.Type() == "async" { + isAsync = true + break + } + } + } + + functionKey := r.generateJSFunctionKey(functionNameNode, "", *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(functionNameNode) + + if isAsync { + functionNode.SetFunctionType(ast.FunctionTypeAsync) + functionNode.SetIsAsync(true) + } else { + functionNode.SetFunctionType(ast.FunctionTypeFunction) + } + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractJSParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set function body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + // JavaScript functions are typically public + functionNode.SetAccessModifier(ast.AccessModifierPublic) + + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *javascriptResolvers) extractJSArrowFunctions(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(jsArrowFunctionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil + } + + var functionNameNode, paramsNode, bodyNode *sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "identifier": + if functionNameNode == nil { + functionNameNode = capture.Node + } else if paramsNode == nil && capture.Node != functionNameNode { + paramsNode = capture.Node + } + case "formal_parameters": + if paramsNode == nil { + paramsNode = capture.Node + } + default: + if bodyNode == nil && capture.Node != functionNameNode && capture.Node != paramsNode { + bodyNode = capture.Node + } + } + } + + if functionNameNode == nil { + return nil + } + + // Check for async arrow function by looking at the arrow_function node for async keyword + isAsync := false + current := functionNameNode.Parent() + for current != nil { + if current.Type() == "arrow_function" { + // Check if the arrow function has async keyword + parent := current.Parent() + if parent != nil { + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "async" { + isAsync = true + break + } + } + } + break + } + current = current.Parent() + } + + functionKey := r.generateJSFunctionKey(functionNameNode, "", *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(functionNameNode) + + if isAsync { + functionNode.SetFunctionType(ast.FunctionTypeAsync) + functionNode.SetIsAsync(true) + } else { + functionNode.SetFunctionType(ast.FunctionTypeArrow) + } + + // Set parameters + if paramsNode != nil { + if paramsNode.Type() == "formal_parameters" { + paramNodes := r.extractJSParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } else { + // Single parameter without parentheses + functionNode.AddFunctionParameterNode(paramsNode) + } + } + + // Set function body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + functionNode.SetAccessModifier(ast.AccessModifierPublic) + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *javascriptResolvers) extractJSMethods(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(jsMethodDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil + } + + var methodNameNode, paramsNode, bodyNode *sitter.Node + var decoratorNodes []*sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "property_identifier": + methodNameNode = capture.Node + case "formal_parameters": + paramsNode = capture.Node + case "statement_block": + bodyNode = capture.Node + case "decorator": + decoratorNodes = append(decoratorNodes, capture.Node) + } + } + + if methodNameNode == nil { + return nil + } + + // Check for async method by looking at the method_definition parent for async keyword + isAsync := false + current := methodNameNode.Parent() + if current != nil && current.Type() == "method_definition" { + // Check if any child node is "async" + for i := 0; i < int(current.ChildCount()); i++ { + child := current.Child(i) + if child.Type() == "async" { + isAsync = true + break + } + } + } + + // Find parent class name + parentClassName := r.findJSParentClassName(methodNameNode, *data) + + functionKey := r.generateJSFunctionKey(methodNameNode, parentClassName, *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(methodNameNode) + + // Check for constructor first, then async, then default to method + methodName := methodNameNode.Content(*data) + if methodName == "constructor" { + functionNode.SetFunctionType(ast.FunctionTypeConstructor) + } else if isAsync { + functionNode.SetFunctionType(ast.FunctionTypeAsync) + functionNode.SetIsAsync(true) + } else { + functionNode.SetFunctionType(ast.FunctionTypeMethod) + } + + if parentClassName != "" { + functionNode.SetParentClassName(parentClassName) + } + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractJSParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set method body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + // Add decorators + for _, decoratorNode := range decoratorNodes { + functionNode.AddDecoratorNode(decoratorNode) + } + + functionNode.SetAccessModifier(ast.AccessModifierPublic) + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *javascriptResolvers) extractJSFunctionExpressions(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(jsFunctionExpressionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 2 { + return nil + } + + var functionNameNode, paramsNode, bodyNode *sitter.Node + + for _, capture := range m.Captures { + switch capture.Node.Type() { + case "identifier": + if functionNameNode == nil { + functionNameNode = capture.Node + } + case "formal_parameters": + paramsNode = capture.Node + case "statement_block": + bodyNode = capture.Node + } + } + + // Anonymous functions don't have names, skip them + if functionNameNode == nil { + return nil + } + + // Check for async function expression by looking for async keyword + isAsync := false + current := functionNameNode.Parent() + if current != nil && current.Type() == "function_expression" { + parent := current.Parent() + if parent != nil { + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "async" { + isAsync = true + break + } + } + } + } + + functionKey := r.generateJSFunctionKey(functionNameNode, "", *data) + functionNode := ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(functionNameNode) + + if isAsync { + functionNode.SetFunctionType(ast.FunctionTypeAsync) + functionNode.SetIsAsync(true) + } else { + functionNode.SetFunctionType(ast.FunctionTypeFunction) + } + + // Set parameters + if paramsNode != nil { + paramNodes := r.extractJSParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set function body + if bodyNode != nil { + functionNode.SetFunctionBodyNode(bodyNode) + } + + functionNode.SetAccessModifier(ast.AccessModifierPublic) + functionMap[functionKey] = functionNode + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +// Helper methods for JavaScript function processing + +func (r *javascriptResolvers) extractJSParameterNodes(parametersNode *sitter.Node) []*sitter.Node { + var paramNodes []*sitter.Node + + if parametersNode == nil { + return paramNodes + } + + for i := 0; i < int(parametersNode.ChildCount()); i++ { + child := parametersNode.Child(i) + if child.Type() == "identifier" || child.Type() == "assignment_pattern" || + child.Type() == "rest_pattern" || child.Type() == "array_pattern" || + child.Type() == "object_pattern" { + paramNodes = append(paramNodes, child) + } + } + + return paramNodes +} + +func (r *javascriptResolvers) findJSParentClassName(node *sitter.Node, data []byte) string { + if node == nil { + return "" + } + + current := node.Parent() + for current != nil { + if current.Type() == "class_declaration" { + nameNode := current.ChildByFieldName("name") + if nameNode != nil { + return nameNode.Content(data) + } + } + current = current.Parent() + } + + return "" +} + +func (r *javascriptResolvers) generateJSFunctionKey(functionNameNode *sitter.Node, parentClassName string, data []byte) string { + functionName := functionNameNode.Content(data) + + if parentClassName != "" { + return parentClassName + "." + functionName + } + + // Add line number to distinguish functions with same name in different scopes + lineNumber := functionNameNode.StartPoint().Row + return fmt.Sprintf("%s:%d", functionName, lineNumber) +} diff --git a/lang/javascript_resolvers_test.go b/lang/javascript_resolvers_test.go index bbee043..3e16b04 100644 --- a/lang/javascript_resolvers_test.go +++ b/lang/javascript_resolvers_test.go @@ -2,6 +2,7 @@ package lang_test import ( "context" + "fmt" "testing" "github.com/safedep/code/core" @@ -52,6 +53,20 @@ var javascriptImportExpectations = []ImportExpectations{ }, } +var javascriptFunctionExpectations = map[string][]string{ + "fixtures/functions.js": { + "FunctionDeclarationNode{Name: declaredFunction, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: arrowFunction, Type: arrow, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: asyncFunction, Type: async, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: constructor, Type: constructor, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: myMethod, Type: method, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: staticMethod, Type: method, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: myProperty, Type: method, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: myDecorator, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: decoratedMethod, Type: method, Access: public, ParentClass: ClassWithDecorator}", + }, +} + func TestJavascriptLanguageResolvers(t *testing.T) { t.Run("ResolversExists", func(t *testing.T) { l, err := lang.NewJavascriptLanguage() @@ -61,9 +76,6 @@ func TestJavascriptLanguageResolvers(t *testing.T) { }) t.Run("ResolveImports", func(t *testing.T) { - l, err := lang.NewJavascriptLanguage() - assert.NoError(t, err) - importExpectationsMapper := make(map[string][]string) importFilePaths := []string{} for _, ie := range javascriptImportExpectations { @@ -86,7 +98,7 @@ func TestJavascriptLanguageResolvers(t *testing.T) { parseTree, err := fileParser.Parse(context.Background(), f) assert.NoError(t, err) - imports, err := l.Resolvers().ResolveImports(parseTree) + imports, err := javascriptLanguage.Resolvers().ResolveImports(parseTree) assert.NoError(t, err) expectedImports, ok := importExpectationsMapper[f.Name()] @@ -101,4 +113,45 @@ func TestJavascriptLanguageResolvers(t *testing.T) { }) assert.NoError(t, err) }) + + t.Run("ResolveFunctions", func(t *testing.T) { + var filePaths []string + for path := range javascriptFunctionExpectations { + filePaths = append(filePaths, path) + } + + javascriptLanguage, err := lang.NewJavascriptLanguage() + assert.NoError(t, err) + + fileParser, err := parser.NewParser([]core.Language{javascriptLanguage}) + assert.NoError(t, err) + + fileSystem, err := fs.NewLocalFileSystem(fs.LocalFileSystemConfig{ + AppDirectories: filePaths, + }) + assert.NoError(t, err) + + err = fileSystem.EnumerateApp(context.Background(), func(f core.File) error { + parseTree, err := fileParser.Parse(context.Background(), f) + assert.NoError(t, err) + + functions, err := javascriptLanguage.Resolvers().ResolveFunctions(parseTree) + assert.NoError(t, err) + + expectedFunctions, ok := javascriptFunctionExpectations[f.Name()] + assert.True(t, ok) + + var foundFunctions []string + for _, fun := range functions { + foundFunctions = append(foundFunctions, + fmt.Sprintf("FunctionDeclarationNode{Name: %s, Type: %s, Access: %s, ParentClass: %s}", + fun.FunctionName(), fun.GetFunctionType(), fun.GetAccessModifier(), fun.GetParentClassName())) + } + + assert.ElementsMatch(t, expectedFunctions, foundFunctions) + + return nil + }) + assert.NoError(t, err) + }) } diff --git a/lang/python_resolvers.go b/lang/python_resolvers.go index 1d1762c..2bfe46f 100644 --- a/lang/python_resolvers.go +++ b/lang/python_resolvers.go @@ -407,6 +407,233 @@ func (r *pythonResolvers) extractBaseClassNodes(superclassesNode *sitter.Node) [ return baseClassNodes } +// Tree-Sitter queries for Python function definitions +const pyFunctionDefinitionQuery = ` + (function_definition + name: (identifier) @function_name + parameters: (parameters) @function_params + body: (block) @function_body + return_type: (type)? @return_type) + + (function_definition + name: (identifier) @function_name + parameters: (parameters) @function_params + body: (block) @function_body) +` + +const pyFunctionDecoratorQuery = ` + (decorated_definition + (decorator + (identifier) @decorator_name) @decorator + definition: (function_definition + name: (identifier) @function_name)) +` + +// ResolveFunctions extracts function declarations from the parse tree +func (r *pythonResolvers) ResolveFunctions(tree core.ParseTree) ([]*ast.FunctionDeclarationNode, error) { + data, err := tree.Data() + if err != nil { + return nil, fmt.Errorf("failed to get data from parse tree: %w", err) + } + + var functions []*ast.FunctionDeclarationNode + functionMap := make(map[string]*ast.FunctionDeclarationNode) // To avoid duplicates and allow enhancement + + // Extract basic function definitions + err = r.extractFunctionDefinitions(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract function definitions: %w", err) + } + + // Async functions are now handled in the main function extraction + + // Extract function decorators + err = r.extractFunctionDecorators(data, tree, functionMap) + if err != nil { + return nil, fmt.Errorf("failed to extract function decorators: %w", err) + } + + // Convert map to slice and ensure all functions have proper access modifiers + for _, function := range functionMap { + // Ensure every function has a proper access modifier (not Unknown) + currentModifier := function.GetAccessModifier() + if currentModifier == ast.AccessModifierUnknown || currentModifier == "" { + function.SetAccessModifier(ast.AccessModifierPublic) + } + functions = append(functions, function) + } + + return functions, nil +} + +// Helper methods for function extraction + +func (r *pythonResolvers) extractFunctionDefinitions(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(pyFunctionDefinitionQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil // Skip incomplete matches + } + + functionNameNode := m.Captures[0].Node + functionName := functionNameNode.Content(*data) + + // Create or get existing function + var functionNode *ast.FunctionDeclarationNode + functionKey := r.generateFunctionKey(functionNameNode, *data) + if existing, exists := functionMap[functionKey]; exists { + functionNode = existing + } else { + functionNode = ast.NewFunctionDeclarationNode(data) + functionNode.SetFunctionNameNode(functionNameNode) + functionMap[functionKey] = functionNode + } + + // Set function parameters + if len(m.Captures) >= 2 && m.Captures[1].Node.Type() == "parameters" { + paramsNode := m.Captures[1].Node + paramNodes := r.extractParameterNodes(paramsNode) + functionNode.SetFunctionParameterNodes(paramNodes) + } + + // Set function body + if len(m.Captures) >= 3 && m.Captures[2].Node.Type() == "block" { + functionNode.SetFunctionBodyNode(m.Captures[2].Node) + } + + // Set return type if present + if len(m.Captures) >= 4 && m.Captures[3].Node.Type() == "type" { + functionNode.SetFunctionReturnTypeNode(m.Captures[3].Node) + } + + // Check for async function by looking at the function_definition parent for async keyword + isAsync := false + current := functionNameNode.Parent() + if current != nil && current.Type() == "function_definition" { + // Check if any child node is "async" + for i := 0; i < int(current.ChildCount()); i++ { + child := current.Child(i) + if child.Type() == "async" { + isAsync = true + break + } + } + } + + // Determine function type based on context + parentClassName := r.findParentClassName(functionNameNode, *data) + if parentClassName != "" { + functionNode.SetParentClassName(parentClassName) + if functionName == "__init__" { + functionNode.SetFunctionType(ast.FunctionTypeConstructor) + } else { + functionNode.SetFunctionType(ast.FunctionTypeMethod) + } + } else { + if isAsync { + functionNode.SetFunctionType(ast.FunctionTypeAsync) + functionNode.SetIsAsync(true) + } else { + functionNode.SetFunctionType(ast.FunctionTypeFunction) + } + } + + // Python functions are typically public by default + functionNode.SetAccessModifier(ast.AccessModifierPublic) + + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +func (r *pythonResolvers) extractFunctionDecorators(data *[]byte, tree core.ParseTree, + functionMap map[string]*ast.FunctionDeclarationNode) error { + queryRequestItems := []ts.QueryItem{ + ts.NewQueryItem(pyFunctionDecoratorQuery, func(m *sitter.QueryMatch) error { + if len(m.Captures) < 3 { + return nil + } + + var decoratorNode, functionNameNode *sitter.Node + for _, capture := range m.Captures { + if capture.Node.Type() == "decorator" { + decoratorNode = capture.Node + } else if capture.Node.Type() == "identifier" { + functionNameNode = capture.Node + } + } + + if decoratorNode == nil || functionNameNode == nil { + return nil + } + + functionKey := r.generateFunctionKey(functionNameNode, *data) + if functionNode, exists := functionMap[functionKey]; exists { + functionNode.AddDecoratorNode(decoratorNode) + + // Check for special decorators + decoratorName := "" + for _, capture := range m.Captures { + if capture.Node.Type() == "identifier" && capture.Node.Parent().Type() == "decorator" { + decoratorName = capture.Node.Content(*data) + break + } + } + + switch decoratorName { + case "staticmethod": + functionNode.SetIsStatic(true) + functionNode.SetFunctionType(ast.FunctionTypeStaticMethod) + case "abstractmethod": + functionNode.SetIsAbstract(true) + } + + // Ensure access modifier is set to public for decorated functions + functionNode.SetAccessModifier(ast.AccessModifierPublic) + } + + return nil + }), + } + + return ts.ExecuteQueries(ts.NewQueriesRequest(r.language, queryRequestItems), data, tree) +} + +// Helper methods + +func (r *pythonResolvers) extractParameterNodes(parametersNode *sitter.Node) []*sitter.Node { + var paramNodes []*sitter.Node + + if parametersNode == nil { + return paramNodes + } + + for i := 0; i < int(parametersNode.ChildCount()); i++ { + child := parametersNode.Child(i) + if child.Type() == "identifier" || child.Type() == "typed_parameter" || child.Type() == "default_parameter" { + paramNodes = append(paramNodes, child) + } + } + + return paramNodes +} + +func (r *pythonResolvers) generateFunctionKey(functionNameNode *sitter.Node, data []byte) string { + functionName := functionNameNode.Content(data) + parentClassName := r.findParentClassName(functionNameNode, data) + + if parentClassName != "" { + return parentClassName + "." + functionName + } + + // Add line number to distinguish functions with same name in different scopes + lineNumber := functionNameNode.StartPoint().Row + return fmt.Sprintf("%s:%d", functionName, lineNumber) +} + func (r *pythonResolvers) findParentClassName(node *sitter.Node, data []byte) string { if node == nil { return "" diff --git a/lang/python_resolvers_test.go b/lang/python_resolvers_test.go index 2ad748c..1a0e8d3 100644 --- a/lang/python_resolvers_test.go +++ b/lang/python_resolvers_test.go @@ -2,6 +2,7 @@ package lang_test import ( "context" + "fmt" "testing" "github.com/safedep/code/core" @@ -39,6 +40,23 @@ var pythonImportExpectations = []ImportExpectations{ }, } +var pythonFunctionExpectations = map[string][]string{ + "fixtures/functions.py": { + "FunctionDeclarationNode{Name: simple_function, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: function_with_args, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: my_async_function, Type: async, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: __init__, Type: constructor, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: instance_method, Type: method, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: static_method, Type: static_method, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: class_method, Type: method, Access: public, ParentClass: MyClass}", + "FunctionDeclarationNode{Name: my_decorator, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: wrapper, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: decorated_function, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: outer_function, Type: function, Access: public, ParentClass: }", + "FunctionDeclarationNode{Name: inner_function, Type: function, Access: public, ParentClass: }", + }, +} + func TestPythonLanguageResolvers(t *testing.T) { t.Run("ResolversExists", func(t *testing.T) { l, err := lang.NewPythonLanguage() @@ -48,9 +66,6 @@ func TestPythonLanguageResolvers(t *testing.T) { }) t.Run("ResolveImports", func(t *testing.T) { - l, err := lang.NewPythonLanguage() - assert.NoError(t, err) - importExpectationsMapper := make(map[string][]string) importFilePaths := []string{} for _, ie := range pythonImportExpectations { @@ -73,7 +88,7 @@ func TestPythonLanguageResolvers(t *testing.T) { parseTree, err := fileParser.Parse(context.Background(), f) assert.NoError(t, err) - imports, err := l.Resolvers().ResolveImports(parseTree) + imports, err := pythonLanguage.Resolvers().ResolveImports(parseTree) assert.NoError(t, err) expectedImports, ok := importExpectationsMapper[f.Name()] @@ -86,6 +101,48 @@ func TestPythonLanguageResolvers(t *testing.T) { return err }) + + assert.NoError(t, err) + }) + + t.Run("ResolveFunctions", func(t *testing.T) { + pythonLanguage, err := lang.NewPythonLanguage() + assert.NoError(t, err) + + var filePaths []string + for path := range pythonFunctionExpectations { + filePaths = append(filePaths, path) + } + + fileParser, err := parser.NewParser([]core.Language{pythonLanguage}) + assert.NoError(t, err) + + fileSystem, err := fs.NewLocalFileSystem(fs.LocalFileSystemConfig{ + AppDirectories: filePaths, + }) + assert.NoError(t, err) + + err = fileSystem.EnumerateApp(context.Background(), func(f core.File) error { + parseTree, err := fileParser.Parse(context.Background(), f) + assert.NoError(t, err) + + functions, err := pythonLanguage.Resolvers().ResolveFunctions(parseTree) + assert.NoError(t, err) + + expectedFunctions, ok := pythonFunctionExpectations[f.Name()] + assert.True(t, ok) + + var foundFunctions []string + for _, fun := range functions { + foundFunctions = append(foundFunctions, + fmt.Sprintf("FunctionDeclarationNode{Name: %s, Type: %s, Access: %s, ParentClass: %s}", + fun.FunctionName(), fun.GetFunctionType(), fun.GetAccessModifier(), fun.GetParentClassName())) + } + + assert.ElementsMatch(t, expectedFunctions, foundFunctions) + + return nil + }) assert.NoError(t, err) }) }