Compare commits

..

1 Commits

Author SHA1 Message Date
Jeremy Cohen
93572b9291 Support 'scripts' for python models 2022-09-28 20:25:53 +02:00
3 changed files with 26 additions and 36 deletions

View File

@@ -370,11 +370,6 @@ class Compiler:
compiled_node = _compiled_type_for(node).from_dict(data)
if compiled_node.language == ModelLanguage.python:
# TODO could we also 'minify' this code at all? just aesthetic, not functional
# quoating seems like something very specific to sql so far
# for all python implementations we are seeing there's no quating.
# TODO try to find better way to do this, given that
original_quoting = self.config.quoting
self.config.quoting = {key: False for key in original_quoting.keys()}
context = self._create_node_context(compiled_node, manifest, extra_context)
@@ -385,7 +380,19 @@ class Compiler:
node,
)
# we should NOT jinja render the python model's 'raw code'
compiled_node.compiled_code = f"{node.raw_code}\n\n{postfix}"
# if the user didn't specify an explicit `model(dbt, session)` function,
# we're going to treat the user code as a "script" and wrap it in that function now.
# TODO: this is the jankiest way of doing it, with zero AST magic
if node.meta.get("missing_model_function") is True:
raw_code_lines = node.raw_code.strip().split("\n")
raw_code_lines[-1] = f"return {raw_code_lines[-1]}"
raw_code_indented = "\n ".join(raw_code_lines)
model_code = f"def model(dbt, session):\n {raw_code_indented}"
else:
model_code = node.raw_code
compiled_node.compiled_code = f"{model_code}\n\n{postfix}"
# restore quoting settings in the end since context is lazy evaluated
self.config.quoting = original_quoting

View File

@@ -1,26 +1,3 @@
{% macro build_dbt_relation_obj(model) %}
class dbtRelation:
"""
dbt.ref('model_a').rel -> 'database.schema.model_a'
str(dbt.ref('model_a')) -> same
dbt.ref('model_a').df -> DataFrame pointing to 'database.schema.model_a'
dbt.ref('model_a')() -> same
Could we make this return .df for just dbt.ref('model_a'),
with no add'l func call, or is that impossible with Python classes ???
"""
def __init__(self, relation_name, dbt_load_df_function):
self.rel = relation_name
self.df = dbt_load_df_function(relation_name)
def __str__(self):
return self.relation_name
def __call__(self):
return self.df
{% endmacro %}
{% macro build_ref_function(model) %}
{%- set ref_dict = {} -%}
@@ -29,10 +6,10 @@ class dbtRelation:
{%- do ref_dict.update({_ref | join("."): resolved.quote(database=False, schema=False, identifier=False) | string}) -%}
{%- endfor -%}
def ref(*args, dbt_load_df_function):
def ref(*args,dbt_load_df_function):
refs = {{ ref_dict | tojson }}
key = ".".join(args)
return dbtRelation(refs[key], dbt_load_df_function)
return dbt_load_df_function(refs[key])
{% endmacro %}
@@ -47,7 +24,7 @@ def ref(*args, dbt_load_df_function):
def source(*args, dbt_load_df_function):
sources = {{ source_dict | tojson }}
key = ".".join(args)
return dbtRelation(sources[key], dbt_load_df_function)
return dbt_load_df_function(sources[key])
{% endmacro %}
@@ -70,7 +47,6 @@ config_dict = {{ config_dict }}
# COMMAND ----------
# this part is dbt logic for get ref work, do not modify
{{ build_dbt_relation_obj(model ) }}
{{ build_ref_function(model ) }}
{{ build_source_function(model ) }}
{{ build_config_dict(model) }}

View File

@@ -60,8 +60,8 @@ class PythonValidationVisitor(ast.NodeVisitor):
)
def check_error(self, node):
if self.num_model_def != 1:
raise ParsingException("dbt only allow one model defined per python file", node=node)
if self.num_model_def > 1:
raise ParsingException("dbt only allows one model defined per python file", node=node)
if len(self.dbt_errors) != 0:
raise ParsingException("\n".join(self.dbt_errors), node=node)
@@ -113,7 +113,7 @@ class PythonParseVisitor(ast.NodeVisitor):
return arg_literals, kwarg_literals
def visit_Call(self, node: ast.Call) -> None:
# check weather the current call could be a dbt function call
# check whether the current call could be a dbt function call
if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
func_name = self._flatten_attr(node.func)
# check weather the current call really is a dbt function call
@@ -204,6 +204,13 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
dbtValidator.visit(tree)
dbtValidator.check_error(node)
# if the user didn't specify an explicit `model(dbt, session)` function,
# we're going to treat the user code as a "script" to be wrapped in that function at compile time.
# for now, we just need to recognize that fact, and save it to the node.
if dbtValidator.num_model_def == 0:
# TODO: this is silly, put this somewhere better (outside of user space)
node.meta["missing_model_function"] = True
dbtParser = PythonParseVisitor(node)
dbtParser.visit(tree)
config_keys_used = []