Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
schema_from_bigquery_table,
)
from bigframes_vendored.ibis.backends.bigquery.datatypes import BigQuerySchema
from bigframes_vendored.ibis.backends.bigquery.udf.core import (
PythonToJavaScriptTranslator,
)
from bigframes_vendored.ibis.backends.sql import SQLBackend
from bigframes_vendored.ibis.backends.sql.compilers import BigQueryCompiler
from bigframes_vendored.ibis.backends.sql.datatypes import BigQueryType
Expand Down Expand Up @@ -731,15 +728,7 @@ def compile(
):
"""Compile an Ibis expression to a SQL string."""
query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
udf_sources = []
for udf_node in expr.op().find(ops.ScalarUDF):
compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
if sql := compile_func(udf_node):
udf_sources.append(sql.sql(self.name, pretty=True))

sql = ";\n".join([*udf_sources, query.sql(dialect=self.name, pretty=True)])
sql = query.sql(dialect=self.name, pretty=True)
self._log(sql)
return sql

Expand Down Expand Up @@ -1186,68 +1175,6 @@ def _clean_up_cached_table(self, name):
force=True,
)

def _get_udf_source(self, udf_node: ops.ScalarUDF):
name = type(udf_node).__name__
type_mapper = self.compiler.udf_type_mapper

body = PythonToJavaScriptTranslator(udf_node.__func__).compile()
config = udf_node.__config__
libraries = config.get("libraries", [])

signature = [
sge.ColumnDef(
this=sg.to_identifier(name, quoted=self.compiler.quoted),
kind=type_mapper.from_ibis(param.annotation.pattern.dtype),
)
for name, param in udf_node.__signature__.parameters.items()
]

lines = ['"""']

if config.get("strict", True):
lines.append('"use strict";')

lines += [
body,
"",
f"return {udf_node.__func_name__}({', '.join(udf_node.argnames)});",
'"""',
]

func = sge.Create(
kind="FUNCTION",
this=sge.UserDefinedFunction(
this=sg.to_identifier(name), expressions=signature, wrapped=True
),
# not exactly what I had in mind, but it works
#
# quoting is too simplistic to handle multiline strings
expression=sge.Var(this="\n".join(lines)),
exists=False,
properties=sge.Properties(
expressions=[
sge.TemporaryProperty(),
sge.ReturnsProperty(this=type_mapper.from_ibis(udf_node.dtype)),
sge.StabilityProperty(
this="IMMUTABLE" if config.get("determinism") else "VOLATILE"
),
sge.LanguageProperty(this=sg.to_identifier("js")),
]
+ [
sge.Property(
this=sg.to_identifier("library"),
value=self.compiler.f.array(*libraries),
)
]
* bool(libraries)
),
)

return func

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None:
return self._get_udf_source(udf_node)

def _register_udfs(self, expr: ir.Expr) -> None:
"""No op because UDFs made with CREATE TEMPORARY FUNCTION must be followed by a query."""

Expand Down
Empty file.
Loading
Loading