mirror of https://github.com/dbt-labs/dbt-core (synced 2025-12-19 06:31:27 +00:00)

Compare commits: 1 commit, enable-pos...jerco/mode

| Author | SHA1 | Date |
|---|---|---|
|  | ae73ce575e |  |
@@ -30,11 +30,13 @@ from dbt.graph import Graph
 from dbt.events.functions import fire_event
 from dbt.events.types import FoundStats, WritingInjectedSQLForNode
 from dbt.events.contextvars import get_node_info
-from dbt.node_types import NodeType, ModelLanguage
+from dbt.node_types import NodeType
 from dbt.events.format import pluralize
 import dbt.tracking
 import dbt.task.list as list_task

+from dbt.parser.languages import get_language_provider_by_name
+
 graph_file_name = "graph.gpickle"
@@ -350,25 +352,22 @@ class Compiler:
         if extra_context is None:
             extra_context = {}

-        if node.language == ModelLanguage.python:
-            context = self._create_node_context(node, manifest, extra_context)
         data = node.to_dict(omit_none=True)
         data.update(
             {
                 "compiled": False,
                 "compiled_code": None,
+                "compiled_language": None,
                 "extra_ctes_injected": False,
                 "extra_ctes": [],
             }
         )

-            postfix = jinja.get_rendered(
-                "{{ py_script_postfix(model) }}",
-                context,
-                node,
-            )
-            # we should NOT jinja render the python model's 'raw code'
-            node.compiled_code = f"{node.raw_code}\n\n{postfix}"
-
-        else:
-            context = self._create_node_context(node, manifest, extra_context)
-            node.compiled_code = jinja.get_rendered(
-                node.raw_code,
-                context,
-                node,
-            )
+        context = self._create_node_context(node, manifest, extra_context)
+        provider = get_language_provider_by_name(node.language)
+
+        node.compiled_code = provider.get_compiled_code(node, context)
+        node.compiled_language = provider.compiled_language()
         node.compiled = True

         # relation_name is set at parse time, except for tests without store_failures,
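Reviewer's note: the heart of this hunk is that compilation stops special-casing Python vs. Jinja-SQL and instead dispatches on the node's language. A minimal standalone sketch of that dispatch pattern — the two stub providers below are hypothetical stand-ins, not dbt's real classes:

    # Minimal sketch of provider dispatch; SqlProvider/PrqlProvider are
    # hypothetical stand-ins for dbt's real LanguageProvider subclasses.
    class SqlProvider:
        @classmethod
        def name(cls):
            return "sql"

        @classmethod
        def get_compiled_code(cls, raw_code, context):
            # stand-in for Jinja rendering
            return raw_code.format(**context)

    class PrqlProvider:
        @classmethod
        def name(cls):
            return "prql"

        @classmethod
        def get_compiled_code(cls, raw_code, context):
            return f"-- compiled from PRQL\n{raw_code}"

    PROVIDERS = {p.name(): p for p in (SqlProvider, PrqlProvider)}

    def compile_node(language, raw_code, context):
        provider = PROVIDERS[language]  # replaces the old if/else on language
        return provider.get_compiled_code(raw_code, context)

    print(compile_node("sql", "select * from {table}", {"table": "events"}))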
@@ -506,6 +505,8 @@ class Compiler:
         fire_event(WritingInjectedSQLForNode(node_info=get_node_info()))

         if node.compiled_code:
+            # TODO: should compiled_path extension depend on the compiled_language?
+            # e.g. "model.prql" (source) -> "model.sql" (compiled)
             node.compiled_path = node.write_node(
                 self.config.target_path, "compiled", node.compiled_code
             )
@@ -1331,7 +1331,7 @@ class ModelContext(ProviderContext):
         # only doing this in sql model for backward compatible
         if (
             getattr(self.model, "extra_ctes_injected", None)
-            and self.model.language == ModelLanguage.sql  # type: ignore[union-attr]
+            and self.model.compiled_language == ModelLanguage.sql  # type: ignore[union-attr]
         ):
             # TODO CT-211
             return self.model.compiled_code  # type: ignore[union-attr]
@@ -22,6 +22,7 @@ class ParseFileType(StrEnum):
     Documentation = "docs"
     Schema = "schema"
     Hook = "hook"  # not a real filetype, from dbt_project.yml
+    language: str = "sql"


 parse_file_type_to_parser = {
@@ -192,6 +193,7 @@ class SourceFile(BaseSourceFile):
     docs: List[str] = field(default_factory=list)
     macros: List[str] = field(default_factory=list)
     env_vars: List[str] = field(default_factory=list)
+    language: str = "sql"

     @classmethod
     def big_seed(cls, path: FilePath) -> "SourceFile":
@@ -517,6 +517,7 @@ class CompiledNode(ParsedNode):
     so all ManifestNodes except SeedNode."""

     language: str = "sql"
+    compiled_language: str = "sql"
     refs: List[RefArgs] = field(default_factory=list)
     sources: List[List[str]] = field(default_factory=list)
     metrics: List[List[str]] = field(default_factory=list)
@@ -80,9 +80,12 @@ class SelectionCriteria:

     @classmethod
     def default_method(cls, value: str) -> MethodName:
+        from dbt.parser.languages import get_file_extensions
+
+        extensions = tuple(get_file_extensions() + [".csv"])
         if _probably_path(value):
             return MethodName.Path
-        elif value.lower().endswith((".sql", ".py", ".csv")):
+        elif value.lower().endswith(extensions):
             return MethodName.File
         else:
             return MethodName.FQN
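Reviewer's note: node selection now derives its file-extension list from the registered providers instead of a hard-coded tuple. A rough sketch of the resulting behavior; the extension list below is illustrative, standing in for what get_file_extensions() would report plus ".csv" for seeds:

    # Sketch of the extension-based method selection above; the tuple is
    # illustrative (providers' extensions plus ".csv" for seeds).
    extensions = (".sql", ".py", ".prql", ".csv")

    def default_method(value: str) -> str:
        if "/" in value or "\\" in value:  # crude stand-in for _probably_path
            return "path"
        elif value.lower().endswith(extensions):
            return "file"
        else:
            return "fqn"

    assert default_method("models/staging") == "path"
    assert default_method("my_model.prql") == "file"  # now matches, not just .sql/.py
    assert default_method("my_model") == "fqn"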
@@ -49,7 +49,7 @@ def source(*args, dbt_load_df_function):
 {% set config_dbt_used = zip(model.config.config_keys_used, model.config.config_keys_defaults) | list %}
 {%- for key, default in config_dbt_used -%}
     {# weird type testing with enum, would be much easier to write this logic in Python! #}
-    {%- if key == "language" -%}
+    {%- if key in ("language", "compiled_language") -%}
     {%- set value = "python" -%}
     {%- endif -%}
     {%- set value = model.config.get(key, default) -%}
@@ -87,5 +87,8 @@ class RunHookType(StrEnum):


 class ModelLanguage(StrEnum):
+    # TODO: how to make this dynamic?
     python = "python"
     sql = "sql"
+    ibis = "ibis"
+    prql = "prql"
@@ -23,6 +23,8 @@ from dbt import hooks
 from dbt.node_types import NodeType, ModelLanguage
 from dbt.parser.search import FileBlock

+from dbt.parser.languages import get_language_providers, get_language_provider_by_name
+
 # internally, the parser may store a less-restrictive type that will be
 # transformed into the final type. But it will have to be derived from
 # ParsedNode to be operable.
@@ -157,7 +159,7 @@ class ConfiguredParser(
             config[key] = [hooks.get_hook_dict(h) for h in config[key]]

     def _create_error_node(
-        self, name: str, path: str, original_file_path: str, raw_code: str, language: str = "sql"
+        self, name: str, path: str, original_file_path: str, raw_code: str, language: str
     ) -> UnparsedNode:
         """If we hit an error before we've actually parsed a node, provide some
         level of useful information by attaching this to the exception.
@@ -189,12 +191,24 @@ class ConfiguredParser(
         """
         if name is None:
             name = block.name
-        if block.path.relative_path.endswith(".py"):
-            language = ModelLanguage.python
+
+        # this is pretty silly, but we need "sql" to be the default
+        # even for seeds etc (.csv) -- otherwise this breaks a lot of tests
+        language = ModelLanguage.sql
+
+        for provider in get_language_providers():
+            # TODO: decouple 1:1 mapping between file extension and modeling language
+            # e.g. ibis models also want to be '.py', and non-Jinja SQL models want to be '.sql'
+            # I could imagine supporting IPython-style 'magic', e.g. `%ibis` or `%prql`
+            if block.contents.startswith(f"%{provider.name()}"):
+                language = ModelLanguage[provider.name()]
+                break
+            elif block.path.relative_path.endswith(provider.file_ext()):
+                language = ModelLanguage[provider.name()]
+
+        # Standard Python models are materialized as 'table' by default
+        if language == ModelLanguage.python:
+            config.add_config_call({"materialized": "table"})
-        else:
-            # this is not ideal but we have a lot of tests to adjust if don't do it
-            language = ModelLanguage.sql

         dct = {
             "alias": name,
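Reviewer's note: the detection loop above gives a `%<language>` magic line priority over the file extension, with "sql" as the fallback. A standalone sketch of that precedence; the provider list is illustrative, and it deliberately reproduces the ambiguity the TODO calls out (two providers can both claim ".py"):

    # Standalone sketch of the detection precedence above: a magic
    # comment beats the file extension, which beats the "sql" default.
    PROVIDERS = [("python", ".py"), ("ibis", ".py"), ("prql", ".prql")]

    def detect_language(relative_path: str, contents: str) -> str:
        language = "sql"  # default, even for seeds etc.
        for name, ext in PROVIDERS:
            if contents.startswith(f"%{name}"):
                return name       # magic wins outright
            elif relative_path.endswith(ext):
                language = name   # extension match; keep scanning for magic
        return language

    # with two providers sharing ".py", the last extension match wins
    assert detect_language("models/m.py", "import pandas") in ("python", "ibis")
    assert detect_language("models/m.py", "%ibis\nimport ibis") == "ibis"
    assert detect_language("models/m.csv", "id,name") == "sql"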
@@ -223,23 +237,13 @@ class ConfiguredParser(
                 path=path,
                 original_file_path=block.path.original_file_path,
                 raw_code=block.contents,
+                language=language,
             )
             raise DictParseError(exc, node=node)

     def _context_for(self, parsed_node: IntermediateNode, config: ContextConfig) -> Dict[str, Any]:
         return generate_parser_model_context(parsed_node, self.root_project, self.manifest, config)

-    def render_with_context(self, parsed_node: IntermediateNode, config: ContextConfig):
-        # Given the parsed node and a ContextConfig to use during parsing,
-        # render the node's sql with macro capture enabled.
-        # Note: this mutates the config object when config calls are rendered.
-        context = self._context_for(parsed_node, config)
-
-        # this goes through the process of rendering, but just throws away
-        # the rendered result. The "macro capture" is the point?
-        get_rendered(parsed_node.raw_code, context, parsed_node, capture_macros=True)
-        return context
-
     # This is taking the original config for the node, converting it to a dict,
     # updating the config with new config passed in, then re-creating the
     # config from the dict in the node.
@@ -367,7 +371,10 @@ class ConfiguredParser(
     def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
         try:
-            context = self.render_with_context(node, config)
+            provider = get_language_provider_by_name(node.language)
+            provider.validate_raw_code(node)
+            context = self._context_for(node, config)
+            context = provider.update_context(node, config, context)
             self.update_parsed_node_config(node, config, context=context)
         except ValidationError as exc:
             # we got a ValidationError - probably bad types in config()
@@ -426,6 +433,18 @@ class SimpleParser(
         return node


+# TODO: rename these to be more generic (not just SQL)
+# The full inheritance order for models is:
+# dbt.parser.models.ModelParser,
+# dbt.parser.base.SimpleSQLParser,
+# dbt.parser.base.SQLParser,
+# dbt.parser.base.ConfiguredParser,
+# dbt.parser.base.Parser,
+# dbt.parser.base.BaseParser,
+# These fine-grained class distinctions exist to support other parsers
+# e.g. SnapshotParser overrides both 'parse_file' + 'transform'
+
+
 class SQLParser(
     ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode]
 ):
core/dbt/parser/languages/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from .provider import LanguageProvider  # noqa
+from .jinja_sql import JinjaSQLProvider  # noqa
+from .python import PythonProvider  # noqa
+
+# TODO: how to make this discovery/registration pluggable?
+from .prql import PrqlProvider  # noqa
+from .ibis import IbisProvider  # noqa
+
+
+def get_language_providers():
+    return LanguageProvider.__subclasses__()
+
+
+def get_language_names():
+    return [provider.name() for provider in get_language_providers()]
+
+
+def get_file_extensions():
+    return [provider.file_ext() for provider in get_language_providers()]
+
+
+def get_language_provider_by_name(language_name: str) -> LanguageProvider:
+    return next(
+        iter(provider for provider in get_language_providers() if provider.name() == language_name)
+    )
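Reviewer's note: discovery here hinges on `LanguageProvider.__subclasses__()` — the provider modules are imported purely for the side effect of defining subclasses. A minimal standalone illustration of that registration mechanism (toy classes, not the real providers):

    # Standalone illustration of subclass-based registration: defining
    # (i.e. importing) a subclass is enough to make it discoverable.
    class LanguageProvider:
        @classmethod
        def name(cls) -> str:
            return ""

    class SqlProvider(LanguageProvider):
        @classmethod
        def name(cls) -> str:
            return "sql"

    class PrqlProvider(LanguageProvider):
        @classmethod
        def name(cls) -> str:
            return "prql"

    def get_language_provider_by_name(language_name: str):
        # NB: like the version above, this raises a bare StopIteration
        # for an unknown language rather than a friendly error.
        return next(
            p for p in LanguageProvider.__subclasses__() if p.name() == language_name
        )

    assert get_language_provider_by_name("prql") is PrqlProvider

One caveat worth flagging in review: `__subclasses__()` only returns direct subclasses, so a provider that subclasses another provider would not be discovered.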
core/dbt/parser/languages/ibis.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+import ibis
+import ast
+
+from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
+from dbt.parser.languages.python import PythonParseVisitor
+from dbt.contracts.graph.nodes import ManifestNode
+
+from dbt.exceptions import PythonParsingError
+
+from typing import Any, Dict
+
+
+class IbisProvider(LanguageProvider):
+    @classmethod
+    def name(self) -> str:
+        return "ibis"
+
+    # TODO: how can we differentiate from python models?
+    # can we support IPython-style magic, e.g. `%ibis`, at the top of the file?
+    @classmethod
+    def file_ext(self) -> str:
+        return ".py"
+
+    @classmethod
+    def compiled_language(self) -> str:
+        return "sql"
+
+    @classmethod
+    def validate_raw_code(self, node) -> None:
+        # don't require the 'model' function for now
+        pass
+
+    @classmethod
+    def extract_dbt_function_calls(self, node) -> dbt_function_calls:
+        """
+        List all references (refs, sources, configs) in a given block.
+        """
+        try:
+            tree = ast.parse(node.raw_code, filename=node.original_file_path)
+        except SyntaxError as exc:
+            raise PythonParsingError(exc, node=node) from exc
+
+        # Only parse if AST tree has instructions in body
+        if tree.body:
+            # don't worry about the 'model' function for now
+            # dbt_validator = PythonValidationVisitor()
+            # dbt_validator.visit(tree)
+            # dbt_validator.check_error(node)
+
+            dbt_parser = PythonParseVisitor(node)
+            dbt_parser.visit(tree)
+            return dbt_parser.dbt_function_calls
+        else:
+            return []
+
+    @classmethod
+    def needs_compile_time_connection(self) -> bool:
+        # TODO: this is technically true, but Ibis won't actually use dbt's connection, it will make its own
+        return True
+
+    @classmethod
+    def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
+        resolved_references = self.get_resolved_references(node, context)
+
+        def ref(*args, dbt_load_df_function):
+            refs = resolved_references["refs"]
+            key = tuple(args)
+            return dbt_load_df_function(refs[key])
+
+        def source(*args, dbt_load_df_function):
+            sources = resolved_references["sources"]
+            key = tuple(args)
+            return dbt_load_df_function(sources[key])
+
+        config_dict = {}
+        for key in node.config.get("config_keys_used", []):
+            value = node.config[key]
+            config_dict.update({key: value})
+
+        class config:
+            def __init__(self, *args, **kwargs):
+                pass
+
+            @staticmethod
+            def get(key, default=None):
+                return config_dict.get(key, default)
+
+        class this:
+            """dbt.this() or dbt.this.identifier"""
+
+            database = node.database
+            schema = node.schema
+            identifier = node.identifier
+
+            def __repr__(self):
+                return node.relation_name
+
+        class dbtObj:
+            def __init__(self, load_df_function) -> None:
+                self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function)
+                self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function)
+                self.config = config
+                self.this = this()
+                # self.is_incremental = TODO
+
+        # https://ibis-project.org/docs/dev/backends/PostgreSQL/#ibis.backends.postgres.Backend.do_connect
+        # TODO: this would need to live in the adapter somehow
+        target = context["target"]
+        con = ibis.postgres.connect(
+            database=target["database"],
+            user=target["user"],
+        )
+
+        # use for dbt.ref(), dbt.source(), etc
+        dbt = dbtObj(con.table)  # noqa
+
+        # TODO: this is unsafe in so many ways
+        exec(node.raw_code)
+        compiled = str(eval(f"ibis.{context['target']['type']}.compile(model)"))
+
+        return compiled
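Reviewer's note: get_compiled_code above builds a small `dbt` namespace (`ref`, `source`, `config`, `this`) around the model source and then `exec`s it — unsafely, as the TODO admits. A stripped-down sketch of that injection pattern without the ibis dependency; `DbtObj` and the fake loader are illustrative:

    # Stripped-down sketch of the exec-based injection pattern above,
    # with a fake table loader instead of an Ibis backend connection.
    class DbtObj:
        def __init__(self, load_table):
            self.ref = lambda *args: load_table(("ref",) + args)
            self.source = lambda *args: load_table(("source",) + args)

    # An "ibis-style" model script: no model() function required, it just
    # binds a top-level `model` expression.
    model_source = 'model = dbt.ref("orders")'

    namespace = {"dbt": DbtObj(lambda key: f"<table {key}>")}
    exec(model_source, namespace)  # binds `model` inside the namespace
    print(namespace["model"])      # <table ('ref', 'orders')>

Passing an explicit namespace dict to exec, as in this sketch, makes the resulting bindings easy to read back; the bare `exec(node.raw_code)` above instead leans on the enclosing function's scope.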
core/dbt/parser/languages/jinja_sql.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+from dbt.clients import jinja
+from dbt.context.context_config import ContextConfig
+from dbt.parser.languages.provider import LanguageProvider
+from dbt.contracts.graph.nodes import ManifestNode
+
+from typing import Dict, Any
+
+
+class JinjaSQLProvider(LanguageProvider):
+    @classmethod
+    def name(self) -> str:
+        return "sql"
+
+    @classmethod
+    def update_context(
+        cls, node: Any, config: ContextConfig, context: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # this goes through the process of rendering, but we don't keep the rendered result
+        # the goal is to capture macros + update context as side effect
+        jinja.get_rendered(node.raw_code, context, node, capture_macros=True)
+        return context
+
+    @classmethod
+    def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
+        compiled_code = jinja.get_rendered(
+            node.raw_code,
+            context,
+            node,
+        )
+        return compiled_code
+
+    @classmethod
+    def needs_compile_time_connection(self) -> bool:
+        return True
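Reviewer's note: JinjaSQLProvider.update_context renders and discards the output — the render exists only for its side effects. A rough standalone analogue using plain jinja2 (not dbt's wrapper) that renders a template just to record which relations it references:

    from jinja2 import Environment

    # Rough analogue of "render only for side effects": the rendered
    # string is discarded; we only care that ref() was called and recorded.
    captured_refs = []

    def ref(name):
        captured_refs.append(name)
        return f'"analytics"."{name}"'

    env = Environment()
    template = env.from_string("select * from {{ ref('orders') }}")
    template.render(ref=ref)  # rendered output thrown away

    assert captured_refs == ["orders"]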
core/dbt/parser/languages/provider.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+from __future__ import annotations
+from typing import Dict, Tuple, List, Any
+import abc
+
+# for type hints
+from dbt.contracts.graph.nodes import RefArgs, ManifestNode
+from dbt.context.providers import RelationProxy
+from dbt.context.context_config import ContextConfig
+
+# TODO rework these types now that 'ref' accepts a keyword argument ('v' or 'version')
+dbt_function_calls = List[Tuple[str, List[str], Dict[str, Any]]]
+references_type = Dict[str, Dict[Tuple[str, ...], RelationProxy]]
+
+
+class LanguageProvider(metaclass=abc.ABCMeta):
+    """
+    A LanguageProvider is a class that can parse & compile a given language.
+    """
+
+    @classmethod
+    def name(self) -> str:
+        return ""
+
+    @classmethod
+    def file_ext(self) -> str:
+        return f".{self.name()}"
+
+    @classmethod
+    def compiled_language(self) -> str:
+        return self.name()
+
+    @classmethod
+    @abc.abstractmethod
+    # TODO add type hints
+    def extract_dbt_function_calls(self, node: Any) -> dbt_function_calls:
+        """
+        List all dbt function calls (ref, source, config) and their args/kwargs
+        """
+        raise NotImplementedError("extract_dbt_function_calls")
+
+    @classmethod
+    def validate_raw_code(self, node: Any) -> None:
+        pass
+
+    @classmethod
+    def update_context(
+        cls, node: Any, config: ContextConfig, context: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        dbt_function_calls = cls.extract_dbt_function_calls(node)
+        config_keys_used = []
+        for (func, args, kwargs) in dbt_function_calls:
+            if func == "get":
+                config_keys_used.append(args[0])
+                continue
+
+            context[func](*args, **kwargs)
+        if config_keys_used:
+            # this is being used in macro build_config_dict
+            context["config"](config_keys_used=config_keys_used)
+        return context
+
+    @classmethod
+    @abc.abstractmethod
+    def needs_compile_time_connection(self) -> bool:
+        """
+        Does this modeling language support introspective queries (requiring a database connection)
+        at compile time?
+        """
+        raise NotImplementedError("needs_compile_time_connection")
+
+    @classmethod
+    def get_resolved_references(
+        self, node: ManifestNode, context: Dict[str, Any]
+    ) -> references_type:
+        resolved_references: references_type = {
+            "sources": {},
+            "refs": {},
+        }
+        # TODO: do we need to support custom 'ref' + 'source' resolution logic for non-JinjaSQL languages?
+        # i.e. user-defined 'ref' + 'source' macros -- this approach will not work for that
+        refs: List[RefArgs] = node.refs
+        sources: List[List[str]] = node.sources
+        for ref in refs:
+            resolved_ref: RelationProxy = context["ref"](*ref)
+            resolved_references["refs"].update({tuple(ref): resolved_ref})
+        for source in sources:
+            resolved_src: RelationProxy = context["source"](*source)
+            resolved_references["sources"].update({tuple(source): resolved_src})
+        return resolved_references
+
+    @classmethod
+    @abc.abstractmethod
+    def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
+        """
+        For a given ManifestNode, return its compiled code.
+        """
+        raise NotImplementedError("get_compiled_code")
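Reviewer's note: update_context is the parse-time half of the provider contract — statically extracted (func, args, kwargs) tuples are replayed against the parse context, so refs/sources/configs get recorded the same way a Jinja render would record them. A small sketch with plain callables standing in for dbt's real parse context:

    # Sketch of the parse-time replay in update_context above; the context
    # here is a plain dict of recording callables, not dbt's real context.
    recorded = {"refs": [], "sources": [], "config": []}

    context = {
        "ref": lambda *a, **kw: recorded["refs"].append(a),
        "source": lambda *a, **kw: recorded["sources"].append(a),
        "config": lambda *a, **kw: recorded["config"].append((a, kw)),
    }

    # what extract_dbt_function_calls might return for some model source
    dbt_function_calls = [
        ("ref", ["orders"], {}),
        ("source", ["salesforce", "in_process"], {}),
        ("get", ["materialized"], {}),  # config.get: recorded, not replayed
    ]

    config_keys_used = []
    for func, args, kwargs in dbt_function_calls:
        if func == "get":
            config_keys_used.append(args[0])
            continue
        context[func](*args, **kwargs)

    if config_keys_used:
        context["config"](config_keys_used=config_keys_used)

    assert recorded["refs"] == [("orders",)]
    assert recorded["config"] == [((), {"config_keys_used": ["materialized"]})]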
core/dbt/parser/languages/prql.py (new file, 174 lines)
@@ -0,0 +1,174 @@
+"""
+This will be in the `dbt-prql` package, but including here during initial code review, so
+we can test it without coordinating dependencies.
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+
+from typing import Dict, Any
+
+from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls, references_type
+from dbt.contracts.graph.nodes import ManifestNode
+
+
+# import prql_python
+# This mocks the prqlc output for two cases which we currently use in tests, so we can
+# test this without configuring dependencies. (Obv fix as we expand the tests, way
+# before we merge.)
+class prql_python:  # type: ignore
+    @staticmethod
+    def to_sql(prql) -> str:
+
+        query_1 = "from employees"
+
+        query_1_compiled = """
+SELECT
+  employees.*
+FROM
+  employees
+""".strip()
+
+        query_2 = """
+from (dbt source.whatever.some_tbl)
+join (dbt ref.test.foo) [id]
+filter salary > 100
+""".strip()
+
+        # hard coded for Jerco's Postgres database
+        query_2_resolved = """
+from ("jerco"."salesforce"."in_process")
+join ("jerco"."dbt_jcohen"."foo") [id]
+filter salary > 100
+""".strip()
+
+        query_2_compiled = """
+SELECT
+  "jerco"."whatever"."some_tbl".*,
+  "jerco"."dbt_jcohen"."foo".*,
+  id
+FROM
+  "jerco"."salesforce"."in_process"
+  JOIN "jerco"."dbt_jcohen"."foo" USING(id)
+WHERE
+  salary > 100
+""".strip()
+
+        lookup = dict(
+            {
+                query_1: query_1_compiled,
+                query_2: query_2_compiled,
+                query_2_resolved: query_2_compiled,
+            }
+        )
+        return lookup[prql]
+
+
+logger = logging.getLogger(__name__)
+
+word_regex = r"[\w\.\-_]+"
+# TODO support single-argument form of 'ref'
+references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?"
+
+
+def hack_compile(prql: str, references: references_type, dialect: str) -> str:
+    """
+    >>> print(hack_compile(
+    ...     "from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]",
+    ...     references=dict(
+    ...         sources={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'},
+    ...         refs={('foo', 'bar'): 'foo_schema.bar_tbl'}
+    ...     )
+    ... ))
+    SELECT
+      "{{ source('salesforce', 'in_process') }}".*,
+      "{{ ref('foo', 'bar') }}".*,
+      id
+    FROM
+      {{ source('salesforce', 'in_process') }}
+      JOIN {{ ref('foo', 'bar') }} USING(id)
+    """

+    subs = []
+    for k, v in references["sources"].items():
+        key = ".".join(k)
+        lookup = f"dbt source.{key}"
+        subs.append((lookup, str(v)))
+
+    for k, v in references["refs"].items():
+        key = ".".join(k)
+        lookup = f"dbt ref.{key}"
+        subs.append((lookup, str(v)))
+
+    for lookup, resolved in subs:
+        prql = prql.replace(lookup, resolved)
+
+    sql = prql_python.to_sql(prql)
+    return sql
+
+
+def hack_list_references(prql):
+    """
+    List all references (e.g. sources / refs) in a given block.
+
+    We need to decide:
+
+    — What should prqlc return given `dbt source.foo.bar`, so dbt-prql can find the
+      references?
+    — Should it just fill in something that looks like jinja for expediency? (We
+      don't support jinja though)
+
+    >>> references = hack_list_references("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
+    >>> dict(references)
+    {'source': [('salesforce', 'in_process')], 'ref': [('foo', 'bar')]}
+    """
+    out = []
+    for t, package, model in _hack_references_of_prql_query(prql):
+        out.append((t, [package, model], {}))
+    return out
+
+
+def _hack_references_of_prql_query(prql) -> list[tuple[str, str, str]]:
+    """
+    List the references in a prql query.
+
+    This would be implemented by prqlc.
+
+    >>> _hack_references_of_prql_query("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
+    [('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')]
+    """
+    return re.findall(references_regex, prql)
+
+
+class PrqlProvider(LanguageProvider):
+    def __init__(self) -> None:
+        # TODO: Uncomment when dbt-prql is released
+        # if not dbt_prql:
+        #     raise ImportError(
+        #         "dbt_prql is required and not found; try running `pip install dbt_prql`"
+        #     )
+        pass
+
+    @classmethod
+    def name(self) -> str:
+        return "prql"
+
+    @classmethod
+    def compiled_language(self) -> str:
+        return "sql"
+
+    @classmethod
+    def extract_dbt_function_calls(self, node) -> dbt_function_calls:
+        return hack_list_references(node.raw_code)
+
+    @classmethod
+    def needs_compile_time_connection(self) -> bool:
+        return False
+
+    @classmethod
+    def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
+        dialect = context["target"]["type"]
+        resolved_references = self.get_resolved_references(node, context)
+        return hack_compile(node.raw_code, references=resolved_references, dialect=dialect)
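Reviewer's note: PRQL reference discovery is pure regex over the source text — no compiler involvement yet, as the comments concede. The same pattern applied standalone:

    import re

    # Same reference pattern as references_regex above.
    word_regex = r"[\w\.\-_]+"
    references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?"

    prql = "from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]"
    print(re.findall(references_regex, prql))
    # [('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')]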
core/dbt/parser/languages/python.py (new file, 219 lines)
@@ -0,0 +1,219 @@
+import ast
+
+from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
+from dbt.exceptions import (
+    UndefinedMacroError,
+    ParsingError,
+    PythonLiteralEvalError,
+    PythonParsingError,
+)
+from dbt.contracts.graph.nodes import ManifestNode
+
+from typing import Dict, Any
+
+dbt_function_key_words = set(["ref", "source", "config", "get"])
+dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
+
+
+class PythonValidationVisitor(ast.NodeVisitor):
+    def __init__(self):
+        super().__init__()
+        self.dbt_errors = []
+        self.num_model_def = 0
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+        if node.name == "model":
+            self.num_model_def += 1
+            if node.args.args and not node.args.args[0].arg == "dbt":
+                self.dbt_errors.append("'dbt' not provided for model as the first argument")
+            if len(node.args.args) != 2:
+                self.dbt_errors.append(
+                    "model function should have two args, `dbt` and a session to current warehouse"
+                )
+            # check we have a return and only one
+            if not isinstance(node.body[-1], ast.Return) or isinstance(
+                node.body[-1].value, ast.Tuple
+            ):
+                self.dbt_errors.append(
+                    "In current version, model function should return only one dataframe object"
+                )
+
+    def check_error(self, node):
+        if self.num_model_def != 1:
+            raise ParsingError(
+                f"dbt allows exactly one model defined per python file, found {self.num_model_def}",
+                node=node,
+            )
+
+        if len(self.dbt_errors) != 0:
+            raise ParsingError("\n".join(self.dbt_errors), node=node)
+
+
+class PythonParseVisitor(ast.NodeVisitor):
+    def __init__(self, dbt_node):
+        super().__init__()
+
+        self.dbt_node = dbt_node
+        self.dbt_function_calls = []
+        self.packages = []
+
+    @classmethod
+    def _flatten_attr(cls, node):
+        if isinstance(node, ast.Attribute):
+            return str(cls._flatten_attr(node.value)) + "." + node.attr
+        elif isinstance(node, ast.Name):
+            return str(node.id)
+        else:
+            pass
+
+    def _safe_eval(self, node):
+        try:
+            return ast.literal_eval(node)
+        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
+            raise PythonLiteralEvalError(exc, node=self.dbt_node) from exc
+
+    def _get_call_literals(self, node):
+        # List of literals
+        arg_literals = []
+        kwarg_literals = {}
+
+        # TODO: Make sure this throws (and that we catch it)
+        # for non-literal inputs
+        for arg in node.args:
+            rendered = self._safe_eval(arg)
+            arg_literals.append(rendered)
+
+        for keyword in node.keywords:
+            key = keyword.arg
+            rendered = self._safe_eval(keyword.value)
+            kwarg_literals[key] = rendered
+
+        return arg_literals, kwarg_literals
+
+    def visit_Call(self, node: ast.Call) -> None:
+        # check whether the current call could be a dbt function call
+        if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
+            func_name = self._flatten_attr(node.func)
+            # check whether the current call really is a dbt function call
+            if func_name in dbt_function_full_names:
+                # drop the dot-dbt prefix
+                func_name = func_name.split(".")[-1]
+                args, kwargs = self._get_call_literals(node)
+                self.dbt_function_calls.append((func_name, args, kwargs))
+
+        # no matter what happened above, we should keep visiting the rest of the tree
+        # visit args and kwargs to see if there's a call in them
+        for obj in node.args + [kwarg.value for kwarg in node.keywords]:
+            if isinstance(obj, ast.Call):
+                self.visit_Call(obj)
+            # support dbt.ref in list args, kwargs
+            elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
+                for el in obj.elts:
+                    if isinstance(el, ast.Call):
+                        self.visit_Call(el)
+            # support dbt.ref in dict args, kwargs
+            elif isinstance(obj, ast.Dict):
+                for value in obj.values:
+                    if isinstance(value, ast.Call):
+                        self.visit_Call(value)
+        # visit node.func.value if we are at a call attr
+        if isinstance(node.func, ast.Attribute):
+            self.attribute_helper(node.func)
+
+    def attribute_helper(self, node: ast.Attribute) -> None:
+        while isinstance(node, ast.Attribute):
+            node = node.value  # type: ignore
+        if isinstance(node, ast.Call):
+            self.visit_Call(node)
+
+    def visit_Import(self, node: ast.Import) -> None:
+        for n in node.names:
+            self.packages.append(n.name.split(".")[0])
+
+    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
+        if node.module:
+            self.packages.append(node.module.split(".")[0])
+
+
+def verify_python_model_code(node):
+    from dbt.clients.jinja import get_rendered
+
+    # TODO: add a test for this
+    try:
+        rendered_python = get_rendered(
+            node.raw_code,
+            {},
+            node,
+        )
+        if rendered_python != node.raw_code:
+            raise ParsingError("")
+    except (UndefinedMacroError, ParsingError):
+        raise ParsingError("No jinja in python model code is allowed", node=node)
+
+
+class PythonProvider(LanguageProvider):
+    @classmethod
+    def name(self) -> str:
+        return "python"
+
+    @classmethod
+    def file_ext(self) -> str:
+        return ".py"
+
+    @classmethod
+    def extract_dbt_function_calls(self, node) -> dbt_function_calls:
+        """
+        List all references (refs, sources, configs) in a given block.
+        """
+        try:
+            tree = ast.parse(node.raw_code, filename=node.original_file_path)
+        except SyntaxError as exc:
+            raise PythonParsingError(exc, node=node) from exc
+
+        # Only parse if AST tree has instructions in body
+        if tree.body:
+            # We need both a validator and a parser because running visit_FunctionDef
+            # in the parser would stop it from doing the visit_Call visits
+            dbt_validator = PythonValidationVisitor()
+            dbt_validator.visit(tree)
+            dbt_validator.check_error(node)
+
+            dbt_parser = PythonParseVisitor(node)
+            dbt_parser.visit(tree)
+            return dbt_parser.dbt_function_calls
+        else:
+            return []
+
+    @classmethod
+    def validate_raw_code(self, node) -> None:
+        from dbt.clients.jinja import get_rendered
+
+        # TODO: add a test for this
+        try:
+            rendered_python = get_rendered(
+                node.raw_code,
+                {},
+                node,
+            )
+            if rendered_python != node.raw_code:
+                raise ParsingError("")
+        except (UndefinedMacroError, ParsingError):
+            raise ParsingError("No jinja in python model code is allowed", node=node)
+
+    @classmethod
+    def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
+        # needed for compilation - bad!!
+        from dbt.clients import jinja
+
+        # TODO: rewrite 'py_script_postfix' in Python instead of Jinja, use get_resolved_references
+        postfix = jinja.get_rendered(
+            "{{ py_script_postfix(model) }}",
+            context,
+            node,
+        )
+        # we should NOT jinja render the python model's 'raw code'
+        return f"{node.raw_code}\n\n{postfix}"
+
+    @classmethod
+    def needs_compile_time_connection(self) -> bool:
+        return False
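Reviewer's note: PythonParseVisitor recognizes calls whose flattened dotted name is one of dbt.ref / dbt.source / dbt.config / dbt.config.get and literal-evals their arguments. A compact stdlib-only condensation of the same idea (ast.walk instead of the full visitor, so it skips the nested list/dict special cases above):

    import ast

    # Compact standalone version of the AST walk above: collect
    # (func, args, kwargs) for dbt.ref / dbt.source / dbt.config calls.
    DBT_FULL_NAMES = {"dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"}

    def flatten_attr(node):
        if isinstance(node, ast.Attribute):
            return flatten_attr(node.value) + "." + node.attr
        return node.id  # ast.Name

    def extract_calls(source: str):
        calls = []
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
                name = flatten_attr(node.func)
                if name in DBT_FULL_NAMES:
                    args = [ast.literal_eval(a) for a in node.args]
                    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
                    # drop the dot-dbt prefix, as the visitor does
                    calls.append((name.split(".")[-1], args, kwargs))
        return calls

    src = """
    def model(dbt, session):
        orders = dbt.ref("orders")
        dbt.config(materialized="table")
        return orders
    """
    import textwrap
    print(extract_calls(textwrap.dedent(src)))
    # [('ref', ['orders'], {}), ('config', [], {'materialized': 'table'})]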
@@ -8,7 +8,6 @@ from dbt.flags import get_flags
 from dbt.node_types import NodeType, ModelLanguage
 from dbt.parser.base import SimpleSQLParser
 from dbt.parser.search import FileBlock
-from dbt.clients.jinja import get_rendered
 import dbt.tracking as tracking
 from dbt import utils
 from dbt_extractor import ExtractionError, py_extract_from_source  # type: ignore

@@ -17,154 +16,6 @@ from itertools import chain
 import random
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

-# New for Python models :p
-import ast
-from dbt.dataclass_schema import ValidationError
-from dbt.exceptions import (
-    ModelConfigError,
-    ParsingError,
-    PythonLiteralEvalError,
-    PythonParsingError,
-    UndefinedMacroError,
-)
-
-dbt_function_key_words = set(["ref", "source", "config", "get"])
-dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
-
-
-class PythonValidationVisitor(ast.NodeVisitor):
-    def __init__(self):
-        super().__init__()
-        self.dbt_errors = []
-        self.num_model_def = 0
-
-    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
-        if node.name == "model":
-            self.num_model_def += 1
-            if node.args.args and not node.args.args[0].arg == "dbt":
-                self.dbt_errors.append("'dbt' not provided for model as the first argument")
-            if len(node.args.args) != 2:
-                self.dbt_errors.append(
-                    "model function should have two args, `dbt` and a session to current warehouse"
-                )
-            # check we have a return and only one
-            if not isinstance(node.body[-1], ast.Return) or isinstance(
-                node.body[-1].value, ast.Tuple
-            ):
-                self.dbt_errors.append(
-                    "In current version, model function should return only one dataframe object"
-                )
-
-    def check_error(self, node):
-        if self.num_model_def != 1:
-            raise ParsingError(
-                f"dbt allows exactly one model defined per python file, found {self.num_model_def}",
-                node=node,
-            )
-
-        if len(self.dbt_errors) != 0:
-            raise ParsingError("\n".join(self.dbt_errors), node=node)
-
-
-class PythonParseVisitor(ast.NodeVisitor):
-    def __init__(self, dbt_node):
-        super().__init__()
-
-        self.dbt_node = dbt_node
-        self.dbt_function_calls = []
-        self.packages = []
-
-    @classmethod
-    def _flatten_attr(cls, node):
-        if isinstance(node, ast.Attribute):
-            return str(cls._flatten_attr(node.value)) + "." + node.attr
-        elif isinstance(node, ast.Name):
-            return str(node.id)
-        else:
-            pass
-
-    def _safe_eval(self, node):
-        try:
-            return ast.literal_eval(node)
-        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
-            raise PythonLiteralEvalError(exc, node=self.dbt_node) from exc
-
-    def _get_call_literals(self, node):
-        # List of literals
-        arg_literals = []
-        kwarg_literals = {}
-
-        # TODO : Make sure this throws (and that we catch it)
-        # for non-literal inputs
-        for arg in node.args:
-            rendered = self._safe_eval(arg)
-            arg_literals.append(rendered)
-
-        for keyword in node.keywords:
-            key = keyword.arg
-            rendered = self._safe_eval(keyword.value)
-            kwarg_literals[key] = rendered
-
-        return arg_literals, kwarg_literals
-
-    def visit_Call(self, node: ast.Call) -> None:
-        # check weather the current call could be a dbt function call
-        if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
-            func_name = self._flatten_attr(node.func)
-            # check weather the current call really is a dbt function call
-            if func_name in dbt_function_full_names:
-                # drop the dot-dbt prefix
-                func_name = func_name.split(".")[-1]
-                args, kwargs = self._get_call_literals(node)
-                self.dbt_function_calls.append((func_name, args, kwargs))
-
-        # no matter what happened above, we should keep visiting the rest of the tree
-        # visit args and kwargs to see if there's call in it
-        for obj in node.args + [kwarg.value for kwarg in node.keywords]:
-            if isinstance(obj, ast.Call):
-                self.visit_Call(obj)
-            # support dbt.ref in list args, kwargs
-            elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
-                for el in obj.elts:
-                    if isinstance(el, ast.Call):
-                        self.visit_Call(el)
-            # support dbt.ref in dict args, kwargs
-            elif isinstance(obj, ast.Dict):
-                for value in obj.values:
-                    if isinstance(value, ast.Call):
-                        self.visit_Call(value)
-        # visit node.func.value if we are at an call attr
-        if isinstance(node.func, ast.Attribute):
-            self.attribute_helper(node.func)
-
-    def attribute_helper(self, node: ast.Attribute) -> None:
-        while isinstance(node, ast.Attribute):
-            node = node.value  # type: ignore
-        if isinstance(node, ast.Call):
-            self.visit_Call(node)
-
-    def visit_Import(self, node: ast.Import) -> None:
-        for n in node.names:
-            self.packages.append(n.name.split(".")[0])
-
-    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
-        if node.module:
-            self.packages.append(node.module.split(".")[0])
-
-
-def verify_python_model_code(node):
-    # TODO: add a test for this
-    try:
-        rendered_python = get_rendered(
-            node.raw_code,
-            {},
-            node,
-        )
-        if rendered_python != node.raw_code:
-            raise ParsingError("")
-    except (UndefinedMacroError, ParsingError):
-        raise ParsingError("No jinja in python model code is allowed", node=node)


 class ModelParser(SimpleSQLParser[ModelNode]):
     def parse_from_dict(self, dct, validate=True) -> ModelNode:

@@ -180,70 +31,16 @@ class ModelParser(SimpleSQLParser[ModelNode]):
     def get_compiled_path(cls, block: FileBlock):
         return block.path.relative_path

-    def parse_python_model(self, node, config, context):
-        config_keys_used = []
-        config_keys_defaults = []
-
-        try:
-            tree = ast.parse(node.raw_code, filename=node.original_file_path)
-        except SyntaxError as exc:
-            raise PythonParsingError(exc, node=node) from exc
-
-        # Only parse if AST tree has instructions in body
-        if tree.body:
-            # We are doing a validator and a parser because visit_FunctionDef in parser
-            # would actually make the parser not doing the visit_Calls any more
-            dbt_validator = PythonValidationVisitor()
-            dbt_validator.visit(tree)
-            dbt_validator.check_error(node)
-
-            dbt_parser = PythonParseVisitor(node)
-            dbt_parser.visit(tree)
-
-            for (func, args, kwargs) in dbt_parser.dbt_function_calls:
-                if func == "get":
-                    num_args = len(args)
-                    if num_args == 0:
-                        raise ParsingError(
-                            "dbt.config.get() requires at least one argument",
-                            node=node,
-                        )
-                    if num_args > 2:
-                        raise ParsingError(
-                            f"dbt.config.get() takes at most 2 arguments ({num_args} given)",
-                            node=node,
-                        )
-                    key = args[0]
-                    default_value = args[1] if num_args == 2 else None
-                    config_keys_used.append(key)
-                    config_keys_defaults.append(default_value)
-                    continue
-
-                context[func](*args, **kwargs)
-
-            if config_keys_used:
-                # this is being used in macro build_config_dict
-                context["config"](
-                    config_keys_used=config_keys_used,
-                    config_keys_defaults=config_keys_defaults,
-                )
-
     def render_update(self, node: ModelNode, config: ContextConfig) -> None:
         self.manifest._parsing_info.static_analysis_path_count += 1
+        # TODO
+        if node.language != ModelLanguage.sql:
+            super().render_update(node, config)
+
+        # TODO move all the logic below into JinjaSQL provider
         flags = get_flags()
-        if node.language == ModelLanguage.python:
-            try:
-                verify_python_model_code(node)
-                context = self._context_for(node, config)
-                self.parse_python_model(node, config, context)
-                self.update_parsed_node_config(node, config, context=context)
-                self.manifest._parsing_info.static_analysis_path_count += 1
-
-            except ValidationError as exc:
-                # we got a ValidationError - probably bad types in config()
-                raise ModelConfigError(exc, node=node) from exc
-            return
-
-        elif not flags.STATIC_PARSER:
+        if not flags.STATIC_PARSER:
             # jinja rendering
             super().render_update(node, config)
             fire_event(
@@ -13,6 +13,7 @@ from dbt.contracts.files import (
 )
 from dbt.config import Project
 from dbt.dataclass_schema import dbtClassMixin
+from dbt.parser.languages import get_file_extensions
 from dbt.parser.schemas import yaml_from_file, schema_file_keys
 from dbt.exceptions import ParsingError
 from dbt.parser.search import filesystem_search

@@ -366,6 +367,7 @@ class ReadFilesFromDiff:


 def get_file_types_for_project(project):
+    model_extensions = get_file_extensions()
     file_types = {
         ParseFileType.Macro: {
             "paths": project.macro_paths,

@@ -374,7 +376,7 @@ def get_file_types_for_project(project):
         },
         ParseFileType.Model: {
             "paths": project.model_paths,
-            "extensions": [".sql", ".py"],
+            "extensions": model_extensions,
             "parser": "ModelParser",
         },
         ParseFileType.Snapshot: {
@@ -275,6 +275,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, GenericTestNode]):
                 path=path,
                 original_file_path=target.original_file_path,
                 raw_code=raw_code,
+                language="sql",
             )
             raise TestConfigError(exc, node)
@@ -3,7 +3,6 @@ import threading
 import time
 import traceback
 from abc import ABCMeta, abstractmethod
-from contextlib import nullcontext
 from datetime import datetime
 from typing import Type, Union, Dict, Any, Optional

@@ -309,9 +308,18 @@ class BaseRunner(metaclass=ABCMeta):
             failures=None,
         )

+    # some modeling languages don't need database connections for compilation,
+    # only for runtime (materialization)
+    def needs_connection(self):
+        return True
+
     def compile_and_execute(self, manifest, ctx):
+        from contextlib import nullcontext
+
         result = None
-        with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext():
+        with self.adapter.connection_for(
+            self.node
+        ) if self.needs_connection() and get_flags().INTROSPECT else nullcontext():
             ctx.node.update_event_status(node_status=RunningStatus.Compiling)
             fire_event(
                 NodeCompiling(
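Reviewer's note: the `with ... if cond else nullcontext():` construct is what lets compilation skip acquiring a warehouse connection for languages that compile offline. A tiny standalone demonstration of the pattern:

    from contextlib import contextmanager, nullcontext

    # Tiny demonstration of the conditional-context pattern above.
    @contextmanager
    def connection_for(node):
        print(f"opened connection for {node}")
        try:
            yield
        finally:
            print("closed connection")

    def compile_node(node, needs_connection: bool):
        with connection_for(node) if needs_connection else nullcontext():
            print(f"compiling {node}")

    compile_node("jinja_sql_model", needs_connection=True)   # opens a connection
    compile_node("prql_model", needs_connection=False)       # compiles offline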
@@ -22,6 +22,12 @@ class CompileRunner(BaseRunner):
     def after_execute(self, result):
         pass

+    def needs_connection(self):
+        from dbt.parser.languages import get_language_provider_by_name
+
+        provider = get_language_provider_by_name(self.node.language)
+        return provider.needs_compile_time_connection()
+
     def execute(self, compiled_node, manifest):
         return RunResult(
             node=compiled_node,
@@ -173,6 +173,9 @@ def _validate_materialization_relations_dict(inp: Dict[Any, Any], model) -> List


 class ModelRunner(CompileRunner):
+    def needs_connection(self):
+        return True
+
     def get_node_representation(self):
         display_quote_policy = {"database": False, "schema": False, "identifier": False}
         relation = self.adapter.Relation.create_from(

@@ -278,12 +281,12 @@ class ModelRunner(CompileRunner):
         context_config = context["config"]

         mat_has_supported_langs = hasattr(materialization_macro, "supported_languages")
-        model_lang_supported = model.language in materialization_macro.supported_languages
+        model_lang_supported = model.compiled_language in materialization_macro.supported_languages
         if mat_has_supported_langs and not model_lang_supported:
             str_langs = [str(lang) for lang in materialization_macro.supported_languages]
             raise DbtValidationError(
                 f'Materialization "{materialization_macro.name}" only supports languages {str_langs}; '
-                f'got "{model.language}"'
+                f'got "{model.language}" which compiles to "{model.compiled_language}"'
             )

         hook_ctx = self.adapter.pre_model_hook(context_config)
@@ -63,6 +63,9 @@ class GraphTest(unittest.TestCase):
         self.filesystem_search = patch("dbt.parser.read_files.filesystem_search")

         def mock_filesystem_search(project, relative_dirs, extension, ignore_spec):
+            # Adding in `and "prql" not in extension` will cause a bunch of tests to
+            # fail; need to understand more on how these are constructed to debug.
+            # Possibly `sql not in extension` is a way of having it only run once.
             if "sql" not in extension:
                 return []
             if "models" not in relative_dirs:

@@ -147,16 +150,18 @@ class GraphTest(unittest.TestCase):
         return dbt.compilation.Compiler(project)

     def use_models(self, models):
-        for k, v in models.items():
+        for k, (source, lang) in models.items():
             path = FilePath(
                 searched_path="models",
                 project_root=os.path.normcase(os.getcwd()),
-                relative_path="{}.sql".format(k),
+                relative_path=f"{k}.{lang}",
                 modification_time=0.0,
             )
             # FileHash can't be empty or 'search_key' will be None
-            source_file = SourceFile(path=path, checksum=FileHash.from_contents("abc"))
-            source_file.contents = v
+            source_file = SourceFile(
+                path=path, checksum=FileHash.from_contents("abc"), language=lang
+            )
+            source_file.contents = source
             self.mock_models.append(source_file)

     def load_manifest(self, config):

@@ -170,7 +175,7 @@ class GraphTest(unittest.TestCase):
     def test__single_model(self):
         self.use_models(
             {
-                "model_one": "select * from events",
+                "model_one": ("select * from events", "sql"),
             }
         )

@@ -187,8 +192,8 @@ class GraphTest(unittest.TestCase):
     def test__two_models_simple_ref(self):
         self.use_models(
             {
-                "model_one": "select * from events",
-                "model_two": "select * from {{ref('model_one')}}",
+                "model_one": ("select * from events", "sql"),
+                "model_two": ("select * from {{ref('model_one')}}", "sql"),
             }
         )

@@ -218,10 +223,10 @@ class GraphTest(unittest.TestCase):
     def test__model_materializations(self):
         self.use_models(
             {
-                "model_one": "select * from events",
-                "model_two": "select * from {{ref('model_one')}}",
-                "model_three": "select * from events",
-                "model_four": "select * from events",
+                "model_one": ("select * from events", "sql"),
+                "model_two": ("select * from {{ref('model_one')}}", "sql"),
+                "model_three": ("select * from events", "sql"),
+                "model_four": ("select * from events", "sql"),
             }
         )

@@ -252,7 +257,11 @@ class GraphTest(unittest.TestCase):
         self.assertEqual(actual, expected)

     def test__model_incremental(self):
-        self.use_models({"model_one": "select * from events"})
+        self.use_models(
+            {
+                "model_one": ("select * from events", "sql"),
+            }
+        )

         cfg = {
             "models": {

@@ -277,14 +286,17 @@ class GraphTest(unittest.TestCase):
     def test__dependency_list(self):
         self.use_models(
             {
-                "model_1": "select * from events",
-                "model_2": 'select * from {{ ref("model_1") }}',
-                "model_3": """
+                "model_1": ("select * from events", "sql"),
+                "model_2": ('select * from {{ ref("model_1") }}', "sql"),
+                "model_3": (
+                    """
                     select * from {{ ref("model_1") }}
                     union all
                     select * from {{ ref("model_2") }}
                 """,
-                "model_4": 'select * from {{ ref("model_3") }}',
+                    "sql",
+                ),
+                "model_4": ('select * from {{ ref("model_3") }}', "sql"),
             }
         )

@@ -344,3 +356,20 @@ class GraphTest(unittest.TestCase):
         manifest.metadata.dbt_version = "99999.99.99"
         is_partial_parsable, _ = loader.is_partial_parsable(manifest)
         self.assertFalse(is_partial_parsable)
+
+    def test_models_prql(self):
+        self.use_models(
+            {
+                "model_prql": ("from employees", "prql"),
+            }
+        )
+
+        config = self.get_config()
+        manifest = self.load_manifest(config)
+
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(manifest)
+
+        self.assertEqual(list(linker.nodes()), ["model.test_models_compile.model_prql"])
+
+        self.assertEqual(list(linker.edges()), [])
@@ -49,6 +49,7 @@ import dbt.contracts.graph.nodes
 from .utils import replace_config


+# TODO: possibly change `sql` arg to `code`
 def make_model(
     pkg,
     name,

@@ -63,6 +64,7 @@ def make_model(
     depends_on_macros=None,
     version=None,
     latest_version=None,
+    language="sql",
 ):
     if refs is None:
         refs = []

@@ -71,7 +73,7 @@ def make_model(
     if tags is None:
         tags = []
     if path is None:
-        path = f"{name}.sql"
+        path = f"{name}.{language}"
     if alias is None:
         alias = name
     if config_kwargs is None:

@@ -97,7 +99,7 @@ def make_model(
         depends_on_nodes.append(src.unique_id)

     return ModelNode(
-        language="sql",
+        language=language,
         raw_code=sql,
         database="dbt",
         schema="dbt_schema",

@@ -511,6 +513,19 @@ def table_model(ephemeral_model):
     )


+@pytest.fixture
+def table_model_prql(seed):
+    return make_model(
+        "pkg",
+        "table_model_prql",
+        "from (dbt source employees)",
+        config_kwargs={"materialized": "table"},
+        refs=[seed],
+        tags=[],
+        path="subdirectory/table_model.prql",
+    )
+
+
 @pytest.fixture
 def table_model_py(seed):
     return make_model(

@@ -728,6 +743,7 @@ def manifest(
     ephemeral_model,
     view_model,
     table_model,
+    table_model_prql,
     table_model_py,
     table_model_csv,
     ext_source,

@@ -828,6 +844,7 @@ def test_select_fqn(manifest):
         "versioned_model.v3",
         "versioned_model.v4",
         "table_model",
+        "table_model_prql",
         "table_model_py",
         "table_model_csv",
         "view_model",

@@ -864,6 +881,7 @@ def test_select_fqn(manifest):
     # single wildcard
     assert search_manifest_using_method(manifest, method, "pkg.t*") == {
         "table_model",
+        "table_model_prql",
         "table_model_py",
         "table_model_csv",
     }

@@ -1001,6 +1019,9 @@ def test_select_file(manifest):
     assert search_manifest_using_method(manifest, method, "table_model.sql") == {"table_model"}
     assert search_manifest_using_method(manifest, method, "table_model.py") == {"table_model_py"}
     assert search_manifest_using_method(manifest, method, "table_model.csv") == {"table_model_csv"}
+    assert search_manifest_using_method(manifest, method, "table_model.prql") == {
+        "table_model_prql"
+    }
     assert search_manifest_using_method(manifest, method, "union_model.sql") == {
         "union_model",
         "mynamespace.union_model",

@@ -1023,6 +1044,7 @@ def test_select_package(manifest):
         "versioned_model.v3",
         "versioned_model.v4",
         "table_model",
+        "table_model_prql",
         "table_model_py",
         "table_model_csv",
         "view_model",
@@ -991,8 +991,72 @@ class ModelParserTest(BaseParserTest):
         node = list(self.parser.manifest.nodes.values())[0]
         self.assertEqual(node.get_materialization(), "table")

-    def test_python_model_custom_materialization(self):
-        block = self.file_block_for(python_model_custom_materialization, "nested/py_model.py")
     def test_parse_error(self):
         block = self.file_block_for("{{ SYNTAX ERROR }}", "nested/model_1.sql")
         with self.assertRaises(CompilationError):
             self.parser.parse_file(block)

+    def test_parse_prql_file(self):
+        prql_code = """
+from (dbt source.salesforce.in_process)
+join (dbt ref.foo.bar) [id]
+filter salary > 100
+""".strip()
+        block = self.file_block_for(prql_code, "nested/prql_model.prql")
+        self.parser.manifest.files[block.file.file_id] = block.file
+        self.parser.parse_file(block)
+        self.assert_has_manifest_lengths(self.parser.manifest, nodes=1)
+        node = list(self.parser.manifest.nodes.values())[0]
+        compiled_sql = """
+SELECT
+  "{{ source('salesforce', 'in_process') }}".*,
+  "{{ ref('foo', 'bar') }}".*,
+  id
+FROM
+  {{ source('salesforce', 'in_process') }}
+  JOIN {{ ref('foo', 'bar') }} USING(id)
+WHERE
+  salary > 100
+""".strip()
+        expected = ModelNode(
+            alias="prql_model",
+            name="prql_model",
+            database="test",
+            schema="analytics",
+            resource_type=NodeType.Model,
+            unique_id="model.snowplow.prql_model",
+            fqn=["snowplow", "nested", "prql_model"],
+            package_name="snowplow",
+            original_file_path=normalize("models/nested/prql_model.prql"),
+            root_path=get_abs_os_path("./dbt_packages/snowplow"),
+            config=NodeConfig(materialized="view"),
+            path=normalize("nested/prql_model.prql"),
+            language="sql",  # It's compiled into SQL
+            raw_code=compiled_sql,
+            checksum=block.file.checksum,
+            unrendered_config={"packages": set()},
+            config_call_dict={},
+            refs=[["foo", "bar"], ["foo", "bar"]],
+            sources=[["salesforce", "in_process"]],
+        )
+        assertEqualNodes(node, expected)
+        file_id = "snowplow://" + normalize("models/nested/prql_model.prql")
+        self.assertIn(file_id, self.parser.manifest.files)
+        self.assertEqual(self.parser.manifest.files[file_id].nodes, ["model.snowplow.prql_model"])
+
+    def test_parse_ref_with_non_string(self):
+        py_code = """
+def model(dbt, session):
+
+    model_names = ["orders", "customers"]
+    models = []
+
+    for model_name in model_names:
+        models.extend(dbt.ref(model_name))
+
+    return models[0]
+"""
+        block = self.file_block_for(py_code, "nested/py_model.py")
+        self.parser.manifest.files[block.file.file_id] = block.file
+        self.parser.parse_file(block)
+        node = list(self.parser.manifest.nodes.values())[0]