Compare commits

...

3 Commits

Author SHA1 Message Date
Chenyu Li
18f846d745 update comment 2022-09-28 22:40:10 -07:00
Chenyu Li
0d7010fd71 end2end working version of jinja/python mix 2022-09-28 22:36:21 -07:00
Chenyu Li
44b67430cb mix jinja and python 2022-09-28 22:10:50 -07:00
4 changed files with 40 additions and 25 deletions

View File

@@ -384,8 +384,13 @@ class Compiler:
context,
node,
)
# we should NOT jinja render the python model's 'raw code'
compiled_node.compiled_code = f"{node.raw_code}\n\n{postfix}"
compiled_code = jinja.get_rendered(
node.raw_code,
{"ref": context["ref"], "source": context["source"]},
node,
)
compiled_node.compiled_code = f"{compiled_code}\n\n{postfix}"
# restore quoting settings in the end since context is lazy evaluated
self.config.quoting = original_quoting

View File

@@ -47,8 +47,6 @@ config_dict = {{ config_dict }}
# COMMAND ----------
# this part is dbt logic for get ref work, do not modify
{{ build_ref_function(model ) }}
{{ build_source_function(model ) }}
{{ build_config_dict(model) }}
class config:
@@ -70,8 +68,6 @@ class this:
class dbtObj:
def __init__(self, load_df_function) -> None:
self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function)
self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function)
self.config = config
self.this = this()
self.is_incremental = {{ is_incremental() }}

View File

@@ -32,8 +32,8 @@ from dbt.dataclass_schema import ValidationError
from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException
dbt_function_key_words = set(["ref", "source", "config", "get"])
dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
dbt_function_key_words = set(["config", "get"])
dbt_function_full_names = set(["dbt.config", "dbt.config.get"])
class PythonValidationVisitor(ast.NodeVisitor):
@@ -163,18 +163,24 @@ def merge_packages(original_packages_with_version, new_packages):
return original_packages_with_version + list(set(additional_packages))
def verify_python_model_code(node):
def render_python_model_code(node):
# TODO: add a test for this
try:
rendered_python = get_rendered(
node.raw_code,
{},
{
"ref": lambda *arg: None,
"source": lambda *arg: None,
},
node,
)
if rendered_python != node.raw_code:
raise ParsingException("")
return rendered_python
# if rendered_python != node.raw_code:
# raise ParsingException("")
except (UndefinedMacroException, ParsingException):
raise ParsingException("No jinja in python model code is allowed", node=node)
raise ParsingException(
"No jinja other than 'source' and 'ref' in python model code is allowed", node=node
)
class ModelParser(SimpleSQLParser[ParsedModelNode]):
@@ -222,9 +228,17 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
if node.language == ModelLanguage.python:
try:
verify_python_model_code(node)
# this part will take care of the get function
original_code = node.raw_code
python_code = render_python_model_code(node)
node.raw_code = python_code
context = self._context_for(node, config)
self.parse_python_model(node, config, context)
node.raw_code = original_code
# source, ref in jinja
get_rendered(node.raw_code, context, node, capture_macros=True)
self.update_parsed_node_config(node, config, context=context)
except ValidationError as exc:

View File

@@ -549,19 +549,19 @@ def model(dbt, session):
from torch import b
import textblob.text
import sklearn
df0 = pandas(dbt.ref("a_model"))
df1 = dbt.ref("my_sql_model").task.limit(2)
df2 = dbt.ref("my_sql_model_1")
df3 = dbt.ref("my_sql_model_2")
df4 = dbt.source("test", 'table1').limit(max = [max(dbt.ref('something'))])
df5 = [dbt.ref('test1')]
a_dict = {'test2' : dbt.ref('test2')}
df5 = anotherfunction({'test2' : dbt.ref('test3')})
df6 = [somethingelse.ref(dbt.ref("test4"))]
df0 = pandas("{{ref('a_model')}}")
df1 = session.table("{{ref('my_sql_model')}}").task.limit(2)
df2 = session.table('{{ref("my_sql_model_1")}}')
df3 = session.table('{{ref("my_sql_model_2")}}')
df4 = session.table('{{source("test", "table1")}}').limit(max = [max(session.table("{{ref('something')}}"))])
df5 = [session.table("{{ref('test1')}}")]
a_dict = {'test2' : session.table("{{ref('test2')}}")}
df5 = onefunction({'test2' : session.table("{{ref('test3')}}")})
df6 = [somethingelse.ref(session.table("{{ref('test4')}}"))]
df = df.limit(2)
return df
return df
"""
block = self.file_block_for(py_code, 'nested/py_model.py')
self.parser.manifest.files[block.file.file_id] = block.file