mirror of
https://github.com/dbt-labs/dbt-core
synced 2025-12-17 19:31:34 +00:00
Compare commits
16 Commits
cl/update_
...
jerco/hack
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5cce911842 | ||
|
|
158aa81b0c | ||
|
|
5ddb088049 | ||
|
|
3edc9e53ad | ||
|
|
e0c32f425d | ||
|
|
90223ed279 | ||
|
|
472940423c | ||
|
|
dddb0bff5a | ||
|
|
bc8b65095e | ||
|
|
86eb68f40d | ||
|
|
8eece383ea | ||
|
|
c9572c3106 | ||
|
|
ebff2ceb72 | ||
|
|
5a8fd1e90d | ||
|
|
fa3f17200f | ||
|
|
506f2c939a |
@@ -29,10 +29,12 @@ from dbt.exceptions import (
|
||||
from dbt.graph import Graph
|
||||
from dbt.events.functions import fire_event
|
||||
from dbt.events.types import FoundStats, CompilingNode, WritingInjectedSQLForNode
|
||||
from dbt.node_types import NodeType, ModelLanguage
|
||||
from dbt.node_types import NodeType
|
||||
from dbt.events.format import pluralize
|
||||
import dbt.tracking
|
||||
|
||||
from dbt.parser.languages import get_language_provider_by_name
|
||||
|
||||
graph_file_name = "graph.gpickle"
|
||||
|
||||
|
||||
@@ -363,42 +365,19 @@ class Compiler:
|
||||
{
|
||||
"compiled": False,
|
||||
"compiled_code": None,
|
||||
"compiled_language": None,
|
||||
"extra_ctes_injected": False,
|
||||
"extra_ctes": [],
|
||||
}
|
||||
)
|
||||
compiled_node = _compiled_type_for(node).from_dict(data)
|
||||
|
||||
if compiled_node.language == ModelLanguage.python:
|
||||
# TODO could we also 'minify' this code at all? just aesthetic, not functional
|
||||
|
||||
# quoating seems like something very specific to sql so far
|
||||
# for all python implementations we are seeing there's no quating.
|
||||
# TODO try to find better way to do this, given that
|
||||
original_quoting = self.config.quoting
|
||||
self.config.quoting = {key: False for key in original_quoting.keys()}
|
||||
context = self._create_node_context(compiled_node, manifest, extra_context)
|
||||
|
||||
postfix = jinja.get_rendered(
|
||||
"{{ py_script_postfix(model) }}",
|
||||
context,
|
||||
node,
|
||||
)
|
||||
# we should NOT jinja render the python model's 'raw code'
|
||||
compiled_node.compiled_code = f"{node.raw_code}\n\n{postfix}"
|
||||
# restore quoting settings in the end since context is lazy evaluated
|
||||
self.config.quoting = original_quoting
|
||||
|
||||
else:
|
||||
context = self._create_node_context(compiled_node, manifest, extra_context)
|
||||
compiled_node.compiled_code = jinja.get_rendered(
|
||||
node.raw_code,
|
||||
context,
|
||||
node,
|
||||
)
|
||||
context = self._create_node_context(compiled_node, manifest, extra_context)
|
||||
provider = get_language_provider_by_name(node.language)
|
||||
|
||||
compiled_node.compiled_code = provider.get_compiled_code(node, context)
|
||||
compiled_node.relation_name = self._get_relation_name(node)
|
||||
|
||||
compiled_node.compiled_language = provider.compiled_language()
|
||||
compiled_node.compiled = True
|
||||
|
||||
return compiled_node
|
||||
@@ -514,6 +493,8 @@ class Compiler:
|
||||
fire_event(WritingInjectedSQLForNode(unique_id=node.unique_id))
|
||||
|
||||
if node.compiled_code:
|
||||
# TODO: should compiled_path depend on the compiled_language?
|
||||
# e.g. "model.prql" (source) -> "model.sql" (compiled)
|
||||
node.compiled_path = node.write_node(
|
||||
self.config.target_path, "compiled", node.compiled_code
|
||||
)
|
||||
|
||||
@@ -1316,7 +1316,7 @@ class ModelContext(ProviderContext):
|
||||
# only doing this in sql model for backward compatible
|
||||
if (
|
||||
getattr(self.model, "extra_ctes_injected", None)
|
||||
and self.model.language == ModelLanguage.sql # type: ignore[union-attr]
|
||||
and self.model.compiled_language == ModelLanguage.sql # type: ignore[union-attr]
|
||||
):
|
||||
# TODO CT-211
|
||||
return self.model.compiled_code # type: ignore[union-attr]
|
||||
|
||||
@@ -22,6 +22,7 @@ class ParseFileType(StrEnum):
|
||||
Documentation = "docs"
|
||||
Schema = "schema"
|
||||
Hook = "hook" # not a real filetype, from dbt_project.yml
|
||||
language: str = "sql"
|
||||
|
||||
|
||||
parse_file_type_to_parser = {
|
||||
@@ -194,6 +195,7 @@ class SourceFile(BaseSourceFile):
|
||||
docs: List[str] = field(default_factory=list)
|
||||
macros: List[str] = field(default_factory=list)
|
||||
env_vars: List[str] = field(default_factory=list)
|
||||
language: str = "sql"
|
||||
|
||||
@classmethod
|
||||
def big_seed(cls, path: FilePath) -> "SourceFile":
|
||||
|
||||
@@ -42,6 +42,7 @@ class CompiledNodeMixin(dbtClassMixin):
|
||||
@dataclass
|
||||
class CompiledNode(ParsedNode, CompiledNodeMixin):
|
||||
compiled_code: Optional[str] = None
|
||||
compiled_language: Optional[str] = None # TODO: ModelLanguage
|
||||
extra_ctes_injected: bool = False
|
||||
extra_ctes: List[InjectedCTE] = field(default_factory=list)
|
||||
relation_name: Optional[str] = None
|
||||
|
||||
@@ -78,9 +78,12 @@ class SelectionCriteria:
|
||||
|
||||
@classmethod
|
||||
def default_method(cls, value: str) -> MethodName:
|
||||
from dbt.parser.languages import get_file_extensions
|
||||
|
||||
extensions = tuple(get_file_extensions() + [".csv"])
|
||||
if _probably_path(value):
|
||||
return MethodName.Path
|
||||
elif value.lower().endswith((".sql", ".py", ".csv")):
|
||||
elif value.lower().endswith(extensions):
|
||||
return MethodName.File
|
||||
else:
|
||||
return MethodName.FQN
|
||||
|
||||
@@ -32,7 +32,7 @@ def source(*args, dbt_load_df_function):
|
||||
{%- set config_dict = {} -%}
|
||||
{%- for key in model.config.config_keys_used -%}
|
||||
{# weird type testing with enum, would be much easier to write this logic in Python! #}
|
||||
{%- if key == 'language' -%}
|
||||
{%- if key in ('language', 'compiled_language') -%}
|
||||
{%- set value = 'python' -%}
|
||||
{%- endif -%}
|
||||
{%- set value = model.config[key] -%}
|
||||
|
||||
@@ -66,5 +66,8 @@ class RunHookType(StrEnum):
|
||||
|
||||
|
||||
class ModelLanguage(StrEnum):
|
||||
# TODO: how to make this dynamic?
|
||||
python = "python"
|
||||
sql = "sql"
|
||||
ibis = "ibis"
|
||||
prql = "prql"
|
||||
|
||||
@@ -23,6 +23,8 @@ from dbt import hooks
|
||||
from dbt.node_types import NodeType, ModelLanguage
|
||||
from dbt.parser.search import FileBlock
|
||||
|
||||
from dbt.parser.languages import get_language_providers, get_language_provider_by_name
|
||||
|
||||
# internally, the parser may store a less-restrictive type that will be
|
||||
# transformed into the final type. But it will have to be derived from
|
||||
# ParsedNode to be operable.
|
||||
@@ -157,7 +159,7 @@ class ConfiguredParser(
|
||||
config[key] = [hooks.get_hook_dict(h) for h in config[key]]
|
||||
|
||||
def _create_error_node(
|
||||
self, name: str, path: str, original_file_path: str, raw_code: str, language: str = "sql"
|
||||
self, name: str, path: str, original_file_path: str, raw_code: str, language: str
|
||||
) -> UnparsedNode:
|
||||
"""If we hit an error before we've actually parsed a node, provide some
|
||||
level of useful information by attaching this to the exception.
|
||||
@@ -189,11 +191,14 @@ class ConfiguredParser(
|
||||
"""
|
||||
if name is None:
|
||||
name = block.name
|
||||
if block.path.relative_path.endswith(".py"):
|
||||
language = ModelLanguage.python
|
||||
else:
|
||||
# this is not ideal but we have a lot of tests to adjust if don't do it
|
||||
language = ModelLanguage.sql
|
||||
|
||||
# this is pretty silly, but we need "sql" to be the default
|
||||
# even for seeds etc (.csv)
|
||||
# otherwise this breaks a lot of tests
|
||||
language = ModelLanguage.sql
|
||||
for provider in get_language_providers():
|
||||
if block.path.relative_path.endswith(provider.file_ext()):
|
||||
language = ModelLanguage[provider.name()]
|
||||
|
||||
dct = {
|
||||
"alias": name,
|
||||
@@ -223,23 +228,13 @@ class ConfiguredParser(
|
||||
path=path,
|
||||
original_file_path=block.path.original_file_path,
|
||||
raw_code=block.contents,
|
||||
language=language,
|
||||
)
|
||||
raise ParsingException(msg, node=node)
|
||||
|
||||
def _context_for(self, parsed_node: IntermediateNode, config: ContextConfig) -> Dict[str, Any]:
|
||||
return generate_parser_model_context(parsed_node, self.root_project, self.manifest, config)
|
||||
|
||||
def render_with_context(self, parsed_node: IntermediateNode, config: ContextConfig):
|
||||
# Given the parsed node and a ContextConfig to use during parsing,
|
||||
# render the node's sql with macro capture enabled.
|
||||
# Note: this mutates the config object when config calls are rendered.
|
||||
context = self._context_for(parsed_node, config)
|
||||
|
||||
# this goes through the process of rendering, but just throws away
|
||||
# the rendered result. The "macro capture" is the point?
|
||||
get_rendered(parsed_node.raw_code, context, parsed_node, capture_macros=True)
|
||||
return context
|
||||
|
||||
# This is taking the original config for the node, converting it to a dict,
|
||||
# updating the config with new config passed in, then re-creating the
|
||||
# config from the dict in the node.
|
||||
@@ -358,7 +353,10 @@ class ConfiguredParser(
|
||||
|
||||
def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
|
||||
try:
|
||||
context = self.render_with_context(node, config)
|
||||
provider = get_language_provider_by_name(node.language)
|
||||
provider.validate_raw_code(node)
|
||||
context = self._context_for(node, config)
|
||||
context = provider.update_context(node, config, context)
|
||||
self.update_parsed_node_config(node, config, context=context)
|
||||
except ValidationError as exc:
|
||||
# we got a ValidationError - probably bad types in config()
|
||||
@@ -405,6 +403,18 @@ class SimpleParser(
|
||||
return node
|
||||
|
||||
|
||||
# TODO: rename these to be more generic (not just SQL)
|
||||
# The full inheritance order for models is:
|
||||
# dbt.parser.models.ModelParser,
|
||||
# dbt.parser.base.SimpleSQLParser,
|
||||
# dbt.parser.base.SQLParser,
|
||||
# dbt.parser.base.ConfiguredParser,
|
||||
# dbt.parser.base.Parser,
|
||||
# dbt.parser.base.BaseParser,
|
||||
# These fine-grained class distinctions exist to support other parsers
|
||||
# e.g. SnapshotParser overrides both 'parse_file' + 'transform'
|
||||
|
||||
|
||||
class SQLParser(
|
||||
ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode]
|
||||
):
|
||||
|
||||
25
core/dbt/parser/languages/__init__.py
Normal file
25
core/dbt/parser/languages/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .provider import LanguageProvider # noqa
|
||||
from .jinja_sql import JinjaSQLProvider # noqa
|
||||
from .python import PythonProvider # noqa
|
||||
|
||||
# TODO: how to make this discovery/registration pluggable?
|
||||
from .prql import PrqlProvider # noqa
|
||||
from .ibis import IbisProvider # noqa
|
||||
|
||||
|
||||
def get_language_providers():
|
||||
return LanguageProvider.__subclasses__()
|
||||
|
||||
|
||||
def get_language_names():
|
||||
return [provider.name() for provider in get_language_providers()]
|
||||
|
||||
|
||||
def get_file_extensions():
|
||||
return [provider.file_ext() for provider in get_language_providers()]
|
||||
|
||||
|
||||
def get_language_provider_by_name(language_name: str) -> LanguageProvider:
|
||||
return next(
|
||||
iter(provider for provider in get_language_providers() if provider.name() == language_name)
|
||||
)
|
||||
116
core/dbt/parser/languages/ibis.py
Normal file
116
core/dbt/parser/languages/ibis.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import ibis
|
||||
import ast
|
||||
|
||||
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
|
||||
from dbt.parser.languages.python import PythonParseVisitor
|
||||
from dbt.contracts.graph.compiled import ManifestNode
|
||||
|
||||
from dbt.exceptions import ParsingException, validator_error_message
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class IbisProvider(LanguageProvider):
|
||||
@classmethod
|
||||
def name(self) -> str:
|
||||
return "ibis"
|
||||
|
||||
@classmethod
|
||||
def file_ext(self) -> str:
|
||||
return ".ibis"
|
||||
|
||||
@classmethod
|
||||
def compiled_language(self) -> str:
|
||||
return "sql"
|
||||
|
||||
@classmethod
|
||||
def validate_raw_code(self, node) -> None:
|
||||
# don't require the 'model' function for now
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def extract_dbt_function_calls(self, node: Any) -> dbt_function_calls:
|
||||
"""
|
||||
List all references (refs, sources, configs) in a given block.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(node.raw_code, filename=node.original_file_path)
|
||||
except SyntaxError as exc:
|
||||
msg = validator_error_message(exc)
|
||||
raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc
|
||||
|
||||
# don't worry about the 'model' function for now
|
||||
# dbtValidator = PythonValidationVisitor()
|
||||
# dbtValidator.visit(tree)
|
||||
# dbtValidator.check_error(node)
|
||||
|
||||
dbtParser = PythonParseVisitor(node)
|
||||
dbtParser.visit(tree)
|
||||
return dbtParser.dbt_function_calls
|
||||
|
||||
@classmethod
|
||||
def needs_compile_time_connection(self) -> bool:
|
||||
# TODO: this is technically true, but Ibis won't actually use dbt's connection, it will make its own
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
|
||||
resolved_references = self.get_resolved_references(node, context)
|
||||
|
||||
def ref(*args, dbt_load_df_function):
|
||||
refs = resolved_references["refs"]
|
||||
key = tuple(args)
|
||||
return dbt_load_df_function(refs[key])
|
||||
|
||||
def source(*args, dbt_load_df_function):
|
||||
sources = resolved_references["sources"]
|
||||
key = tuple(args)
|
||||
return dbt_load_df_function(sources[key])
|
||||
|
||||
config_dict = {}
|
||||
for key in node.config.get("config_keys_used", []):
|
||||
value = node.config[key]
|
||||
config_dict.update({key: value})
|
||||
|
||||
class config:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get(key, default=None):
|
||||
return config_dict.get(key, default)
|
||||
|
||||
class this:
|
||||
"""dbt.this() or dbt.this.identifier"""
|
||||
|
||||
database = node.database
|
||||
schema = node.schema
|
||||
identifier = node.identifier
|
||||
|
||||
def __repr__(self):
|
||||
return node.relation_name
|
||||
|
||||
class dbtObj:
|
||||
def __init__(self, load_df_function) -> None:
|
||||
self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function)
|
||||
self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function)
|
||||
self.config = config
|
||||
self.this = this()
|
||||
# self.is_incremental = TODO
|
||||
|
||||
# https://ibis-project.org/docs/dev/backends/PostgreSQL/#ibis.backends.postgres.Backend.do_connect
|
||||
# TODO: this would need to live in the adapter somehow
|
||||
target = context["target"]
|
||||
con = ibis.postgres.connect(
|
||||
database=target["database"],
|
||||
user=target["user"],
|
||||
)
|
||||
|
||||
# use for dbt.ref(), dbt.source(), etc
|
||||
dbt = dbtObj(con.table) # noqa
|
||||
|
||||
# TODO: this is unsafe in so many ways
|
||||
exec(node.raw_code)
|
||||
compiled = str(eval(f"ibis.{context['target']['type']}.compile(model)"))
|
||||
|
||||
return compiled
|
||||
34
core/dbt/parser/languages/jinja_sql.py
Normal file
34
core/dbt/parser/languages/jinja_sql.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from dbt.clients import jinja
|
||||
from dbt.context.context_config import ContextConfig
|
||||
from dbt.parser.languages.provider import LanguageProvider
|
||||
from dbt.contracts.graph.compiled import ManifestNode
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class JinjaSQLProvider(LanguageProvider):
|
||||
@classmethod
|
||||
def name(self) -> str:
|
||||
return "sql"
|
||||
|
||||
@classmethod
|
||||
def update_context(
|
||||
cls, node: Any, config: ContextConfig, context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
# this goes through the process of rendering, but we don't keep the rendered result
|
||||
# the goal is to capture macros + update context as side effect
|
||||
jinja.get_rendered(node.raw_code, context, node, capture_macros=True)
|
||||
return context
|
||||
|
||||
@classmethod
|
||||
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
|
||||
compiled_code = jinja.get_rendered(
|
||||
node.raw_code,
|
||||
context,
|
||||
node,
|
||||
)
|
||||
return compiled_code
|
||||
|
||||
@classmethod
|
||||
def needs_compile_time_connection(self) -> bool:
|
||||
return True
|
||||
97
core/dbt/parser/languages/provider.py
Normal file
97
core/dbt/parser/languages/provider.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Tuple, List, Any
|
||||
import abc
|
||||
|
||||
# for type hints
|
||||
from dbt.contracts.graph.compiled import ManifestNode
|
||||
from dbt.context.providers import RelationProxy
|
||||
from dbt.context.context_config import ContextConfig
|
||||
|
||||
dbt_function_calls = List[Tuple[str, List[str], Dict[str, Any]]]
|
||||
references_type = Dict[str, Dict[Tuple[str, ...], RelationProxy]]
|
||||
|
||||
|
||||
class LanguageProvider(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
A LanguageProvider is a class that can parse & compile a given language.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def name(self) -> str:
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def file_ext(self) -> str:
|
||||
return f".{self.name()}"
|
||||
|
||||
@classmethod
|
||||
def compiled_language(self) -> str:
|
||||
return self.name()
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
# TODO add type hints
|
||||
def extract_dbt_function_calls(self, node: Any) -> dbt_function_calls:
|
||||
"""
|
||||
List all dbt function calls (ref, source, config) and their args/kwargs
|
||||
"""
|
||||
raise NotImplementedError("extract_dbt_function_calls")
|
||||
|
||||
@classmethod
|
||||
def validate_raw_code(self, node: Any) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def update_context(
|
||||
cls, node: Any, config: ContextConfig, context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
dbt_function_calls = cls.extract_dbt_function_calls(node)
|
||||
config_keys_used = []
|
||||
for (func, args, kwargs) in dbt_function_calls:
|
||||
if func == "get":
|
||||
config_keys_used.append(args[0])
|
||||
continue
|
||||
|
||||
context[func](*args, **kwargs)
|
||||
if config_keys_used:
|
||||
# this is being used in macro build_config_dict
|
||||
context["config"](config_keys_used=config_keys_used)
|
||||
return context
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def needs_compile_time_connection(self) -> bool:
|
||||
"""
|
||||
Does this modeling language support introspective queries (requiring a database connection)
|
||||
at compile time?
|
||||
"""
|
||||
raise NotImplementedError("needs_compile_time_connection")
|
||||
|
||||
@classmethod
|
||||
def get_resolved_references(
|
||||
self, node: ManifestNode, context: Dict[str, Any]
|
||||
) -> references_type:
|
||||
resolved_references: references_type = {
|
||||
"sources": {},
|
||||
"refs": {},
|
||||
}
|
||||
# TODO: do we need to support custom 'ref' + 'source' resolution logic for non-JinjaSQL languages?
|
||||
# (i.e. user-defined 'ref' + 'source' macros)
|
||||
# this approach will not work for that
|
||||
refs: List[List[str]] = node.refs
|
||||
sources: List[List[str]] = node.sources
|
||||
for ref in refs:
|
||||
resolved_ref: RelationProxy = context["ref"](*ref)
|
||||
resolved_references["refs"].update({tuple(ref): resolved_ref})
|
||||
for source in sources:
|
||||
resolved_src: RelationProxy = context["source"](*source)
|
||||
resolved_references["sources"].update({tuple(source): resolved_src})
|
||||
return resolved_references
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
|
||||
"""
|
||||
For a given ManifestNode, return its compiled code.
|
||||
"""
|
||||
raise NotImplementedError("get_compiled_code")
|
||||
170
core/dbt/parser/languages/prql.py
Normal file
170
core/dbt/parser/languages/prql.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
This will be in the `dbt-prql` package, but including here during inital code review, so
|
||||
we can test it without coordinating dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls, references_type
|
||||
|
||||
|
||||
# import prql_python
|
||||
# This mocks the prqlc output for two cases which we currently use in tests, so we can
|
||||
# test this without configuring dependencies. (Obv fix as we expand the tests, way
|
||||
# before we merge.)
|
||||
class prql_python: # type: ignore
|
||||
@staticmethod
|
||||
def to_sql(prql) -> str:
|
||||
|
||||
query_1 = "from employees"
|
||||
|
||||
query_1_compiled = """
|
||||
SELECT
|
||||
employees.*
|
||||
FROM
|
||||
employees
|
||||
""".strip()
|
||||
|
||||
query_2 = """
|
||||
from (dbt source.whatever.some_tbl)
|
||||
join (dbt ref.test.foo) [id]
|
||||
filter salary > 100
|
||||
""".strip()
|
||||
|
||||
# hard coded for Jerco's Postgres database
|
||||
query_2_resolved = """
|
||||
from ("jerco"."salesforce"."in_process")
|
||||
join ("jerco"."dbt_jcohen"."foo") [id]
|
||||
filter salary > 100
|
||||
""".strip()
|
||||
|
||||
query_2_compiled = """
|
||||
SELECT
|
||||
"jerco"."whatever"."some_tbl".*,
|
||||
"jerco"."dbt_jcohen"."foo".*,
|
||||
id
|
||||
FROM
|
||||
"jerco"."salesforce"."in_process"
|
||||
JOIN "jerco"."dbt_jcohen"."foo" USING(id)
|
||||
WHERE
|
||||
salary > 100
|
||||
""".strip()
|
||||
|
||||
lookup = dict(
|
||||
{
|
||||
query_1: query_1_compiled,
|
||||
query_2: query_2_compiled,
|
||||
query_2_resolved: query_2_compiled,
|
||||
}
|
||||
)
|
||||
return lookup[prql]
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
word_regex = r"[\w\.\-_]+"
|
||||
# TODO support single-argument form of 'ref'
|
||||
references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?"
|
||||
|
||||
|
||||
def hack_compile(prql: str, references: references_type) -> str:
|
||||
"""
|
||||
>>> print(compile(
|
||||
... "from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]",
|
||||
... references=dict(
|
||||
... sources={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'},
|
||||
... refs={('foo', 'bar'): 'foo_schema.bar_tbl'}
|
||||
... )
|
||||
... ))
|
||||
SELECT
|
||||
"{{ source('salesforce', 'in_process') }}".*,
|
||||
"{{ ref('foo', 'bar') }}".*,
|
||||
id
|
||||
FROM
|
||||
{{ source('salesforce', 'in_process') }}
|
||||
JOIN {{ ref('foo', 'bar') }} USING(id)
|
||||
"""
|
||||
|
||||
subs = []
|
||||
for k, v in references["sources"].items():
|
||||
key = ".".join(k)
|
||||
lookup = f"dbt source.{key}"
|
||||
subs.append((lookup, str(v)))
|
||||
|
||||
for k, v in references["refs"].items():
|
||||
key = ".".join(k)
|
||||
lookup = f"dbt ref.{key}"
|
||||
subs.append((lookup, str(v)))
|
||||
|
||||
for lookup, resolved in subs:
|
||||
prql = prql.replace(lookup, resolved)
|
||||
|
||||
sql = prql_python.to_sql(prql)
|
||||
return sql
|
||||
|
||||
|
||||
def hack_list_references(prql):
|
||||
"""
|
||||
List all references (e.g. sources / refs) in a given block.
|
||||
|
||||
We need to decide:
|
||||
|
||||
— What should prqlc return given `dbt source.foo.bar`, so dbt-prql can find the
|
||||
references?
|
||||
— Should it just fill in something that looks like jinja for expediancy? (We
|
||||
don't support jinja though)
|
||||
|
||||
>>> references = list_references("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
|
||||
>>> dict(references)
|
||||
{'source': [('salesforce', 'in_process')], 'ref': [('foo', 'bar')]}
|
||||
"""
|
||||
out = []
|
||||
for t, package, model in _hack_references_of_prql_query(prql):
|
||||
out.append((t, [package, model], {}))
|
||||
return out
|
||||
|
||||
|
||||
def _hack_references_of_prql_query(prql) -> list[tuple[str, str, str]]:
|
||||
"""
|
||||
List the references in a prql query.
|
||||
|
||||
This would be implemented by prqlc.
|
||||
|
||||
>>> _hack_references_of_prql_query("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
|
||||
[('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')]
|
||||
"""
|
||||
return re.findall(references_regex, prql)
|
||||
|
||||
|
||||
class PrqlProvider(LanguageProvider):
|
||||
def __init__(self) -> None:
|
||||
# TODO: Uncomment when dbt-prql is released
|
||||
# if not dbt_prql:
|
||||
# raise ImportError(
|
||||
# "dbt_prql is required and not found; try running `pip install dbt_prql`"
|
||||
# )
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def name(self) -> str:
|
||||
return "prql"
|
||||
|
||||
@classmethod
|
||||
def compiled_language(self) -> str:
|
||||
return "sql"
|
||||
|
||||
@classmethod
|
||||
def extract_dbt_function_calls(self, node) -> dbt_function_calls:
|
||||
return hack_list_references(node.raw_code)
|
||||
|
||||
@classmethod
|
||||
def needs_compile_time_connection(self) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_compiled_code(self, node, context) -> str:
|
||||
resolved_references = self.get_resolved_references(node, context)
|
||||
return hack_compile(node.raw_code, references=resolved_references)
|
||||
195
core/dbt/parser/languages/python.py
Normal file
195
core/dbt/parser/languages/python.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import ast
|
||||
|
||||
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
|
||||
from dbt.exceptions import UndefinedMacroException, ParsingException, validator_error_message
|
||||
from dbt.contracts.graph.compiled import ManifestNode
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
dbt_function_key_words = set(["ref", "source", "config", "get"])
|
||||
dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
|
||||
|
||||
|
||||
class PythonValidationVisitor(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dbt_errors = []
|
||||
self.num_model_def = 0
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
if node.name == "model":
|
||||
self.num_model_def += 1
|
||||
if node.args.args and not node.args.args[0].arg == "dbt":
|
||||
self.dbt_errors.append("'dbt' not provided for model as the first argument")
|
||||
if len(node.args.args) != 2:
|
||||
self.dbt_errors.append(
|
||||
"model function should have two args, `dbt` and a session to current warehouse"
|
||||
)
|
||||
# check we have a return and only one
|
||||
if not isinstance(node.body[-1], ast.Return) or isinstance(
|
||||
node.body[-1].value, ast.Tuple
|
||||
):
|
||||
self.dbt_errors.append(
|
||||
"In current version, model function should return only one dataframe object"
|
||||
)
|
||||
|
||||
def check_error(self, node):
|
||||
if self.num_model_def != 1:
|
||||
raise ParsingException("dbt only allow one model defined per python file", node=node)
|
||||
if len(self.dbt_errors) != 0:
|
||||
raise ParsingException("\n".join(self.dbt_errors), node=node)
|
||||
|
||||
|
||||
class PythonParseVisitor(ast.NodeVisitor):
|
||||
def __init__(self, dbt_node):
|
||||
super().__init__()
|
||||
|
||||
self.dbt_node = dbt_node
|
||||
self.dbt_function_calls = []
|
||||
self.packages = []
|
||||
|
||||
@classmethod
|
||||
def _flatten_attr(cls, node):
|
||||
if isinstance(node, ast.Attribute):
|
||||
return str(cls._flatten_attr(node.value)) + "." + node.attr
|
||||
elif isinstance(node, ast.Name):
|
||||
return str(node.id)
|
||||
else:
|
||||
pass
|
||||
|
||||
def _safe_eval(self, node):
|
||||
try:
|
||||
return ast.literal_eval(node)
|
||||
except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
|
||||
msg = validator_error_message(
|
||||
f"Error when trying to literal_eval an arg to dbt.ref(), dbt.source(), dbt.config() or dbt.config.get() \n{exc}\n"
|
||||
"https://docs.python.org/3/library/ast.html#ast.literal_eval\n"
|
||||
"In dbt python model, `dbt.ref`, `dbt.source`, `dbt.config`, `dbt.config.get` function args only support Python literal structures"
|
||||
)
|
||||
raise ParsingException(msg, node=self.dbt_node) from exc
|
||||
|
||||
def _get_call_literals(self, node):
|
||||
# List of literals
|
||||
arg_literals = []
|
||||
kwarg_literals = {}
|
||||
|
||||
# TODO : Make sure this throws (and that we catch it)
|
||||
# for non-literal inputs
|
||||
for arg in node.args:
|
||||
rendered = self._safe_eval(arg)
|
||||
arg_literals.append(rendered)
|
||||
|
||||
for keyword in node.keywords:
|
||||
key = keyword.arg
|
||||
rendered = self._safe_eval(keyword.value)
|
||||
kwarg_literals[key] = rendered
|
||||
|
||||
return arg_literals, kwarg_literals
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> None:
|
||||
# check weather the current call could be a dbt function call
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
|
||||
func_name = self._flatten_attr(node.func)
|
||||
# check weather the current call really is a dbt function call
|
||||
if func_name in dbt_function_full_names:
|
||||
# drop the dot-dbt prefix
|
||||
func_name = func_name.split(".")[-1]
|
||||
args, kwargs = self._get_call_literals(node)
|
||||
self.dbt_function_calls.append((func_name, args, kwargs))
|
||||
|
||||
# no matter what happened above, we should keep visiting the rest of the tree
|
||||
# visit args and kwargs to see if there's call in it
|
||||
for obj in node.args + [kwarg.value for kwarg in node.keywords]:
|
||||
if isinstance(obj, ast.Call):
|
||||
self.visit_Call(obj)
|
||||
# support dbt.ref in list args, kwargs
|
||||
elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
|
||||
for el in obj.elts:
|
||||
if isinstance(el, ast.Call):
|
||||
self.visit_Call(el)
|
||||
# support dbt.ref in dict args, kwargs
|
||||
elif isinstance(obj, ast.Dict):
|
||||
for value in obj.values:
|
||||
if isinstance(value, ast.Call):
|
||||
self.visit_Call(value)
|
||||
# visit node.func.value if we are at an call attr
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
self.attribute_helper(node.func)
|
||||
|
||||
def attribute_helper(self, node: ast.Attribute) -> None:
|
||||
while isinstance(node, ast.Attribute):
|
||||
node = node.value # type: ignore
|
||||
if isinstance(node, ast.Call):
|
||||
self.visit_Call(node)
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
for n in node.names:
|
||||
self.packages.append(n.name.split(".")[0])
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if node.module:
|
||||
self.packages.append(node.module.split(".")[0])
|
||||
|
||||
|
||||
class PythonProvider(LanguageProvider):
|
||||
@classmethod
|
||||
def name(self) -> str:
|
||||
return "python"
|
||||
|
||||
@classmethod
|
||||
def file_ext(self) -> str:
|
||||
return ".py"
|
||||
|
||||
@classmethod
|
||||
def extract_dbt_function_calls(self, node) -> dbt_function_calls:
|
||||
"""
|
||||
List all references (refs, sources, configs) in a given block.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(node.raw_code, filename=node.original_file_path)
|
||||
except SyntaxError as exc:
|
||||
msg = validator_error_message(exc)
|
||||
raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc
|
||||
|
||||
# We are doing a validator and a parser because visit_FunctionDef in parser
|
||||
# would actually make the parser not doing the visit_Calls any more
|
||||
dbtValidator = PythonValidationVisitor()
|
||||
dbtValidator.visit(tree)
|
||||
dbtValidator.check_error(node)
|
||||
|
||||
dbtParser = PythonParseVisitor(node)
|
||||
dbtParser.visit(tree)
|
||||
return dbtParser.dbt_function_calls
|
||||
|
||||
@classmethod
|
||||
def validate_raw_code(self, node) -> None:
|
||||
from dbt.clients.jinja import get_rendered
|
||||
|
||||
# TODO: add a test for this
|
||||
try:
|
||||
rendered_python = get_rendered(
|
||||
node.raw_code,
|
||||
{},
|
||||
node,
|
||||
)
|
||||
if rendered_python != node.raw_code:
|
||||
raise ParsingException("")
|
||||
except (UndefinedMacroException, ParsingException):
|
||||
raise ParsingException("No jinja in python model code is allowed", node=node)
|
||||
|
||||
@classmethod
|
||||
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
|
||||
# needed for compilation - bad!!
|
||||
from dbt.clients import jinja
|
||||
|
||||
postfix = jinja.get_rendered(
|
||||
"{{ py_script_postfix(model) }}",
|
||||
context,
|
||||
node,
|
||||
)
|
||||
# we should NOT jinja render the python model's 'raw code'
|
||||
return f"{node.raw_code}\n\n{postfix}"
|
||||
|
||||
@classmethod
|
||||
def needs_compile_time_connection(self) -> bool:
|
||||
return False
|
||||
@@ -17,7 +17,6 @@ from dbt.events.types import (
|
||||
from dbt.node_types import NodeType, ModelLanguage
|
||||
from dbt.parser.base import SimpleSQLParser
|
||||
from dbt.parser.search import FileBlock
|
||||
from dbt.clients.jinja import get_rendered
|
||||
import dbt.tracking as tracking
|
||||
from dbt import utils
|
||||
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
|
||||
@@ -26,156 +25,6 @@ from itertools import chain
|
||||
import random
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
# New for Python models :p
|
||||
import ast
|
||||
from dbt.dataclass_schema import ValidationError
|
||||
from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException
|
||||
|
||||
|
||||
dbt_function_key_words = set(["ref", "source", "config", "get"])
|
||||
dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
|
||||
|
||||
|
||||
class PythonValidationVisitor(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dbt_errors = []
|
||||
self.num_model_def = 0
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
if node.name == "model":
|
||||
self.num_model_def += 1
|
||||
if node.args.args and not node.args.args[0].arg == "dbt":
|
||||
self.dbt_errors.append("'dbt' not provided for model as the first argument")
|
||||
if len(node.args.args) != 2:
|
||||
self.dbt_errors.append(
|
||||
"model function should have two args, `dbt` and a session to current warehouse"
|
||||
)
|
||||
# check we have a return and only one
|
||||
if not isinstance(node.body[-1], ast.Return) or isinstance(
|
||||
node.body[-1].value, ast.Tuple
|
||||
):
|
||||
self.dbt_errors.append(
|
||||
"In current version, model function should return only one dataframe object"
|
||||
)
|
||||
|
||||
def check_error(self, node):
|
||||
if self.num_model_def != 1:
|
||||
raise ParsingException("dbt only allow one model defined per python file", node=node)
|
||||
if len(self.dbt_errors) != 0:
|
||||
raise ParsingException("\n".join(self.dbt_errors), node=node)
|
||||
|
||||
|
||||
class PythonParseVisitor(ast.NodeVisitor):
|
||||
def __init__(self, dbt_node):
|
||||
super().__init__()
|
||||
|
||||
self.dbt_node = dbt_node
|
||||
self.dbt_function_calls = []
|
||||
self.packages = []
|
||||
|
||||
@classmethod
|
||||
def _flatten_attr(cls, node):
|
||||
if isinstance(node, ast.Attribute):
|
||||
return str(cls._flatten_attr(node.value)) + "." + node.attr
|
||||
elif isinstance(node, ast.Name):
|
||||
return str(node.id)
|
||||
else:
|
||||
pass
|
||||
|
||||
def _safe_eval(self, node):
|
||||
try:
|
||||
return ast.literal_eval(node)
|
||||
except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
|
||||
msg = validator_error_message(
|
||||
f"Error when trying to literal_eval an arg to dbt.ref(), dbt.source(), dbt.config() or dbt.config.get() \n{exc}\n"
|
||||
"https://docs.python.org/3/library/ast.html#ast.literal_eval\n"
|
||||
"In dbt python model, `dbt.ref`, `dbt.source`, `dbt.config`, `dbt.config.get` function args only support Python literal structures"
|
||||
)
|
||||
raise ParsingException(msg, node=self.dbt_node) from exc
|
||||
|
||||
def _get_call_literals(self, node):
|
||||
# List of literals
|
||||
arg_literals = []
|
||||
kwarg_literals = {}
|
||||
|
||||
# TODO : Make sure this throws (and that we catch it)
|
||||
# for non-literal inputs
|
||||
for arg in node.args:
|
||||
rendered = self._safe_eval(arg)
|
||||
arg_literals.append(rendered)
|
||||
|
||||
for keyword in node.keywords:
|
||||
key = keyword.arg
|
||||
rendered = self._safe_eval(keyword.value)
|
||||
kwarg_literals[key] = rendered
|
||||
|
||||
return arg_literals, kwarg_literals
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> None:
|
||||
# check weather the current call could be a dbt function call
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
|
||||
func_name = self._flatten_attr(node.func)
|
||||
# check weather the current call really is a dbt function call
|
||||
if func_name in dbt_function_full_names:
|
||||
# drop the dot-dbt prefix
|
||||
func_name = func_name.split(".")[-1]
|
||||
args, kwargs = self._get_call_literals(node)
|
||||
self.dbt_function_calls.append((func_name, args, kwargs))
|
||||
|
||||
# no matter what happened above, we should keep visiting the rest of the tree
|
||||
# visit args and kwargs to see if there's call in it
|
||||
for obj in node.args + [kwarg.value for kwarg in node.keywords]:
|
||||
if isinstance(obj, ast.Call):
|
||||
self.visit_Call(obj)
|
||||
# support dbt.ref in list args, kwargs
|
||||
elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
|
||||
for el in obj.elts:
|
||||
if isinstance(el, ast.Call):
|
||||
self.visit_Call(el)
|
||||
# support dbt.ref in dict args, kwargs
|
||||
elif isinstance(obj, ast.Dict):
|
||||
for value in obj.values:
|
||||
if isinstance(value, ast.Call):
|
||||
self.visit_Call(value)
|
||||
# visit node.func.value if we are at an call attr
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
self.attribute_helper(node.func)
|
||||
|
||||
def attribute_helper(self, node: ast.Attribute) -> None:
|
||||
while isinstance(node, ast.Attribute):
|
||||
node = node.value # type: ignore
|
||||
if isinstance(node, ast.Call):
|
||||
self.visit_Call(node)
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
for n in node.names:
|
||||
self.packages.append(n.name.split(".")[0])
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if node.module:
|
||||
self.packages.append(node.module.split(".")[0])
|
||||
|
||||
|
||||
def merge_packages(original_packages_with_version, new_packages):
|
||||
original_packages = [package.split("==")[0] for package in original_packages_with_version]
|
||||
additional_packages = [package for package in new_packages if package not in original_packages]
|
||||
return original_packages_with_version + list(set(additional_packages))
|
||||
|
||||
|
||||
def verify_python_model_code(node):
|
||||
# TODO: add a test for this
|
||||
try:
|
||||
rendered_python = get_rendered(
|
||||
node.raw_code,
|
||||
{},
|
||||
node,
|
||||
)
|
||||
if rendered_python != node.raw_code:
|
||||
raise ParsingException("")
|
||||
except (UndefinedMacroException, ParsingException):
|
||||
raise ParsingException("No jinja in python model code is allowed", node=node)
|
||||
|
||||
|
||||
class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
||||
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
|
||||
@@ -191,49 +40,16 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
|
||||
def get_compiled_path(cls, block: FileBlock):
|
||||
return block.path.relative_path
|
||||
|
||||
def parse_python_model(self, node, config, context):
|
||||
try:
|
||||
tree = ast.parse(node.raw_code, filename=node.original_file_path)
|
||||
except SyntaxError as exc:
|
||||
msg = validator_error_message(exc)
|
||||
raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc
|
||||
|
||||
# We are doing a validator and a parser because visit_FunctionDef in parser
|
||||
# would actually make the parser not doing the visit_Calls any more
|
||||
dbtValidator = PythonValidationVisitor()
|
||||
dbtValidator.visit(tree)
|
||||
dbtValidator.check_error(node)
|
||||
|
||||
dbtParser = PythonParseVisitor(node)
|
||||
dbtParser.visit(tree)
|
||||
config_keys_used = []
|
||||
for (func, args, kwargs) in dbtParser.dbt_function_calls:
|
||||
if func == "get":
|
||||
config_keys_used.append(args[0])
|
||||
continue
|
||||
|
||||
context[func](*args, **kwargs)
|
||||
if config_keys_used:
|
||||
# this is being used in macro build_config_dict
|
||||
context["config"](config_keys_used=config_keys_used)
|
||||
|
||||
def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None:
|
||||
# TODO
|
||||
if node.language != ModelLanguage.sql:
|
||||
super().render_update(node, config)
|
||||
|
||||
# TODO move all the logic below into JinjaSQL provider
|
||||
|
||||
self.manifest._parsing_info.static_analysis_path_count += 1
|
||||
|
||||
if node.language == ModelLanguage.python:
|
||||
try:
|
||||
verify_python_model_code(node)
|
||||
context = self._context_for(node, config)
|
||||
self.parse_python_model(node, config, context)
|
||||
self.update_parsed_node_config(node, config, context=context)
|
||||
|
||||
except ValidationError as exc:
|
||||
# we got a ValidationError - probably bad types in config()
|
||||
msg = validator_error_message(exc)
|
||||
raise ParsingException(msg, node=node) from exc
|
||||
return
|
||||
|
||||
elif not flags.STATIC_PARSER:
|
||||
if not flags.STATIC_PARSER:
|
||||
# jinja rendering
|
||||
super().render_update(node, config)
|
||||
fire_event(StaticParserCausedJinjaRendering(path=node.path))
|
||||
|
||||
@@ -171,11 +171,15 @@ def read_files(project, files, parser_files, saved_files):
|
||||
dbt_ignore_spec,
|
||||
)
|
||||
|
||||
from dbt.parser.languages import get_file_extensions
|
||||
|
||||
model_extensions = get_file_extensions()
|
||||
|
||||
project_files["ModelParser"] = read_files_for_parser(
|
||||
project,
|
||||
files,
|
||||
project.model_paths,
|
||||
[".sql", ".py"],
|
||||
model_extensions,
|
||||
ParseFileType.Model,
|
||||
saved_files,
|
||||
dbt_ignore_spec,
|
||||
|
||||
@@ -270,6 +270,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, ParsedGenericTestNode]):
|
||||
path=path,
|
||||
original_file_path=target.original_file_path,
|
||||
raw_code=raw_code,
|
||||
language="sql",
|
||||
)
|
||||
raise ParsingException(msg, node=node) from exc
|
||||
|
||||
|
||||
@@ -309,9 +309,16 @@ class BaseRunner(metaclass=ABCMeta):
|
||||
failures=None,
|
||||
)
|
||||
|
||||
# some modeling languages don't need database connections for compilation,
|
||||
# only for runtime (materialization)
|
||||
def needs_connection(self):
|
||||
return True
|
||||
|
||||
def compile_and_execute(self, manifest, ctx):
|
||||
from contextlib import nullcontext
|
||||
|
||||
result = None
|
||||
with self.adapter.connection_for(self.node):
|
||||
with self.adapter.connection_for(self.node) if self.needs_connection() else nullcontext():
|
||||
ctx.node._event_status["node_status"] = RunningStatus.Compiling
|
||||
fire_event(
|
||||
NodeCompiling(
|
||||
|
||||
@@ -20,6 +20,12 @@ class CompileRunner(BaseRunner):
|
||||
def after_execute(self, result):
|
||||
pass
|
||||
|
||||
def needs_connection(self):
|
||||
from dbt.parser.languages import get_language_provider_by_name
|
||||
|
||||
provider = get_language_provider_by_name(self.node.language)
|
||||
return provider.needs_compile_time_connection()
|
||||
|
||||
def execute(self, compiled_node, manifest):
|
||||
return RunResult(
|
||||
node=compiled_node,
|
||||
|
||||
@@ -159,6 +159,9 @@ def _validate_materialization_relations_dict(inp: Dict[Any, Any], model) -> List
|
||||
|
||||
|
||||
class ModelRunner(CompileRunner):
|
||||
def needs_connection(self):
|
||||
return True
|
||||
|
||||
def get_node_representation(self):
|
||||
display_quote_policy = {"database": False, "schema": False, "identifier": False}
|
||||
relation = self.adapter.Relation.create_from(
|
||||
@@ -262,12 +265,12 @@ class ModelRunner(CompileRunner):
|
||||
context_config = context["config"]
|
||||
|
||||
mat_has_supported_langs = hasattr(materialization_macro, "supported_languages")
|
||||
model_lang_supported = model.language in materialization_macro.supported_languages
|
||||
model_lang_supported = model.compiled_language in materialization_macro.supported_languages
|
||||
if mat_has_supported_langs and not model_lang_supported:
|
||||
str_langs = [str(lang) for lang in materialization_macro.supported_languages]
|
||||
raise ValidationException(
|
||||
f'Materialization "{materialization_macro.name}" only supports languages {str_langs}; '
|
||||
f'got "{model.language}"'
|
||||
f'got "{model.language}" which compiles to "{model.compiled_language}"'
|
||||
)
|
||||
|
||||
hook_ctx = self.adapter.pre_model_hook(context_config)
|
||||
|
||||
@@ -60,6 +60,9 @@ class GraphTest(unittest.TestCase):
|
||||
# Create file filesystem searcher
|
||||
self.filesystem_search = patch('dbt.parser.read_files.filesystem_search')
|
||||
def mock_filesystem_search(project, relative_dirs, extension, ignore_spec):
|
||||
# Adding in `and "prql" not in extension` will cause a bunch of tests to
|
||||
# fail; need to understand more on how these are constructed to debug.
|
||||
# Possibly `sql not in extension` is a way of having it only run once.
|
||||
if 'sql' not in extension:
|
||||
return []
|
||||
if 'models' not in relative_dirs:
|
||||
@@ -140,16 +143,16 @@ class GraphTest(unittest.TestCase):
|
||||
return dbt.compilation.Compiler(project)
|
||||
|
||||
def use_models(self, models):
|
||||
for k, v in models.items():
|
||||
for k, (source, lang) in models.items():
|
||||
path = FilePath(
|
||||
searched_path='models',
|
||||
project_root=os.path.normcase(os.getcwd()),
|
||||
relative_path='{}.sql'.format(k),
|
||||
relative_path=f'{k}.{lang}',
|
||||
modification_time=0.0,
|
||||
)
|
||||
# FileHash can't be empty or 'search_key' will be None
|
||||
source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc'))
|
||||
source_file.contents = v
|
||||
source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc'), language=lang)
|
||||
source_file.contents = source
|
||||
self.mock_models.append(source_file)
|
||||
|
||||
def load_manifest(self, config):
|
||||
@@ -162,7 +165,7 @@ class GraphTest(unittest.TestCase):
|
||||
|
||||
def test__single_model(self):
|
||||
self.use_models({
|
||||
'model_one': 'select * from events',
|
||||
'model_one':( 'select * from events', 'sql'),
|
||||
})
|
||||
|
||||
config = self.get_config()
|
||||
@@ -181,8 +184,8 @@ class GraphTest(unittest.TestCase):
|
||||
|
||||
def test__two_models_simple_ref(self):
|
||||
self.use_models({
|
||||
'model_one': 'select * from events',
|
||||
'model_two': "select * from {{ref('model_one')}}",
|
||||
'model_one':( 'select * from events', 'sql'),
|
||||
'model_two':( "select * from {{ref('model_one')}}", 'sql'),
|
||||
})
|
||||
|
||||
config = self.get_config()
|
||||
@@ -205,10 +208,10 @@ class GraphTest(unittest.TestCase):
|
||||
|
||||
def test__model_materializations(self):
|
||||
self.use_models({
|
||||
'model_one': 'select * from events',
|
||||
'model_two': "select * from {{ref('model_one')}}",
|
||||
'model_three': "select * from events",
|
||||
'model_four': "select * from events",
|
||||
'model_one':( 'select * from events', 'sql'),
|
||||
'model_two':( "select * from {{ref('model_one')}}", 'sql'),
|
||||
'model_three':( 'select * from events', 'sql'),
|
||||
'model_four':( 'select * from events', 'sql'),
|
||||
})
|
||||
|
||||
cfg = {
|
||||
@@ -241,7 +244,7 @@ class GraphTest(unittest.TestCase):
|
||||
|
||||
def test__model_incremental(self):
|
||||
self.use_models({
|
||||
'model_one': 'select * from events'
|
||||
'model_one':( 'select * from events', 'sql'),
|
||||
})
|
||||
|
||||
cfg = {
|
||||
@@ -269,15 +272,15 @@ class GraphTest(unittest.TestCase):
|
||||
|
||||
def test__dependency_list(self):
|
||||
self.use_models({
|
||||
'model_1': 'select * from events',
|
||||
'model_2': 'select * from {{ ref("model_1") }}',
|
||||
'model_3': '''
|
||||
'model_1':( 'select * from events', 'sql'),
|
||||
'model_2':( 'select * from {{ ref("model_1") }}', 'sql'),
|
||||
'model_3': ('''
|
||||
select * from {{ ref("model_1") }}
|
||||
union all
|
||||
select * from {{ ref("model_2") }}
|
||||
''',
|
||||
'model_4': 'select * from {{ ref("model_3") }}'
|
||||
})
|
||||
''', "sql"),
|
||||
'model_4':( 'select * from {{ ref("model_3") }}', 'sql'),
|
||||
})
|
||||
|
||||
config = self.get_config()
|
||||
manifest = self.load_manifest(config)
|
||||
@@ -328,3 +331,22 @@ class GraphTest(unittest.TestCase):
|
||||
manifest.metadata.dbt_version = '99999.99.99'
|
||||
is_partial_parsable, _ = loader.is_partial_parsable(manifest)
|
||||
self.assertFalse(is_partial_parsable)
|
||||
|
||||
def test_models_prql(self):
|
||||
self.use_models({
|
||||
'model_prql':( 'from employees', 'prql'),
|
||||
})
|
||||
|
||||
config = self.get_config()
|
||||
manifest = self.load_manifest(config)
|
||||
|
||||
compiler = self.get_compiler(config)
|
||||
linker = compiler.compile(manifest)
|
||||
|
||||
self.assertEqual(
|
||||
list(linker.nodes()),
|
||||
['model.test_models_compile.model_prql'])
|
||||
|
||||
self.assertEqual(
|
||||
list(linker.edges()),
|
||||
[])
|
||||
|
||||
@@ -46,7 +46,8 @@ import dbt.contracts.graph.parsed
|
||||
from .utils import replace_config
|
||||
|
||||
|
||||
def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None, depends_on_macros=None):
|
||||
# TODO: possibly change `sql` arg to `code`
|
||||
def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None, depends_on_macros=None, language='sql'):
|
||||
if refs is None:
|
||||
refs = []
|
||||
if sources is None:
|
||||
@@ -54,7 +55,7 @@ def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, al
|
||||
if tags is None:
|
||||
tags = []
|
||||
if path is None:
|
||||
path = f'{name}.sql'
|
||||
path = f'{name}.{language}'
|
||||
if alias is None:
|
||||
alias = name
|
||||
if config_kwargs is None:
|
||||
@@ -78,7 +79,7 @@ def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, al
|
||||
depends_on_nodes.append(src.unique_id)
|
||||
|
||||
return ParsedModelNode(
|
||||
language='sql',
|
||||
language=language,
|
||||
raw_code=sql,
|
||||
database='dbt',
|
||||
schema='dbt_schema',
|
||||
@@ -478,6 +479,18 @@ def table_model(ephemeral_model):
|
||||
path='subdirectory/table_model.sql'
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def table_model_prql(seed):
|
||||
return make_model(
|
||||
'pkg',
|
||||
'table_model_prql',
|
||||
'from (dbt source employees)',
|
||||
config_kwargs={'materialized': 'table'},
|
||||
refs=[seed],
|
||||
tags=[],
|
||||
path='subdirectory/table_model.prql'
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def table_model_py(seed):
|
||||
return make_model(
|
||||
@@ -619,11 +632,11 @@ def namespaced_union_model(seed, ext_source):
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def manifest(seed, source, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, ext_source, ext_model, union_model, ext_source_2,
|
||||
def manifest(seed, source, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, table_model_prql, ext_source, ext_model, union_model, ext_source_2,
|
||||
ext_source_other, ext_source_other_2, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique,
|
||||
view_test_nothing, namespaced_seed, namespace_model, namespaced_union_model, macro_test_unique, macro_default_test_unique,
|
||||
macro_test_not_null, macro_default_test_not_null):
|
||||
nodes = [seed, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, union_model, ext_model,
|
||||
nodes = [seed, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, table_model_prql, union_model, ext_model,
|
||||
table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing,
|
||||
namespaced_seed, namespace_model, namespaced_union_model]
|
||||
sources = [source, ext_source, ext_source_2,
|
||||
@@ -661,7 +674,7 @@ def test_select_fqn(manifest):
|
||||
assert not search_manifest_using_method(manifest, method, 'ext.unions')
|
||||
# sources don't show up, because selection pretends they have no FQN. Should it?
|
||||
assert search_manifest_using_method(manifest, method, 'pkg') == {
|
||||
'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'view_model', 'ephemeral_model', 'seed',
|
||||
'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'view_model', 'ephemeral_model', 'seed',
|
||||
'mynamespace.union_model', 'mynamespace.ephemeral_model', 'mynamespace.seed'}
|
||||
assert search_manifest_using_method(
|
||||
manifest, method, 'ext') == {'ext_model'}
|
||||
@@ -744,6 +757,8 @@ def test_select_file(manifest):
|
||||
manifest, method, 'table_model.py') == {'table_model_py'}
|
||||
assert search_manifest_using_method(
|
||||
manifest, method, 'table_model.csv') == {'table_model_csv'}
|
||||
assert search_manifest_using_method(
|
||||
manifest, method, 'table_model.prql') == {'table_model_prql'}
|
||||
assert search_manifest_using_method(
|
||||
manifest, method, 'union_model.sql') == {'union_model', 'mynamespace.union_model'}
|
||||
assert not search_manifest_using_method(
|
||||
@@ -758,7 +773,7 @@ def test_select_package(manifest):
|
||||
assert isinstance(method, PackageSelectorMethod)
|
||||
assert method.arguments == []
|
||||
|
||||
assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'view_model', 'ephemeral_model',
|
||||
assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'view_model', 'ephemeral_model',
|
||||
'seed', 'raw.seed', 'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'view_test_nothing',
|
||||
'mynamespace.seed', 'mynamespace.ephemeral_model', 'mynamespace.union_model',
|
||||
}
|
||||
@@ -777,7 +792,7 @@ def test_select_config_materialized(manifest):
|
||||
assert search_manifest_using_method(manifest, method, 'view') == {
|
||||
'view_model', 'ext_model'}
|
||||
assert search_manifest_using_method(manifest, method, 'table') == {
|
||||
'table_model', 'table_model_py', 'table_model_csv', 'union_model', 'mynamespace.union_model'}
|
||||
'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'union_model', 'mynamespace.union_model'}
|
||||
|
||||
def test_select_config_meta(manifest):
|
||||
methods = MethodManager(manifest, None)
|
||||
|
||||
@@ -40,6 +40,7 @@ from dbt.parser.models import (
|
||||
import itertools
|
||||
from .utils import config_from_parts_or_dicts, normalize, generate_name_macros, MockNode, MockSource, MockDocumentation
|
||||
|
||||
import dataclasses
|
||||
|
||||
def get_abs_os_path(unix_path):
|
||||
return normalize(os.path.abspath(unix_path))
|
||||
@@ -720,6 +721,54 @@ def model(dbt, session):
|
||||
with self.assertRaises(CompilationException):
|
||||
self.parser.parse_file(block)
|
||||
|
||||
def test_parse_prql_file(self):
|
||||
prql_code = """
|
||||
from (dbt source.salesforce.in_process)
|
||||
join (dbt ref.foo.bar) [id]
|
||||
filter salary > 100
|
||||
""".strip()
|
||||
block = self.file_block_for(prql_code, 'nested/prql_model.prql')
|
||||
self.parser.manifest.files[block.file.file_id] = block.file
|
||||
self.parser.parse_file(block)
|
||||
self.assert_has_manifest_lengths(self.parser.manifest, nodes=1)
|
||||
node = list(self.parser.manifest.nodes.values())[0]
|
||||
compiled_sql = """
|
||||
SELECT
|
||||
"{{ source('salesforce', 'in_process') }}".*,
|
||||
"{{ ref('foo', 'bar') }}".*,
|
||||
id
|
||||
FROM
|
||||
{{ source('salesforce', 'in_process') }}
|
||||
JOIN {{ ref('foo', 'bar') }} USING(id)
|
||||
WHERE
|
||||
salary > 100
|
||||
""".strip()
|
||||
expected = ParsedModelNode(
|
||||
alias='prql_model',
|
||||
name='prql_model',
|
||||
database='test',
|
||||
schema='analytics',
|
||||
resource_type=NodeType.Model,
|
||||
unique_id='model.snowplow.prql_model',
|
||||
fqn=['snowplow', 'nested', 'prql_model'],
|
||||
package_name='snowplow',
|
||||
original_file_path=normalize('models/nested/prql_model.prql'),
|
||||
root_path=get_abs_os_path('./dbt_packages/snowplow'),
|
||||
config=NodeConfig(materialized='view'),
|
||||
path=normalize('nested/prql_model.prql'),
|
||||
language='sql', # It's compiled into SQL
|
||||
raw_code=compiled_sql,
|
||||
checksum=block.file.checksum,
|
||||
unrendered_config={'packages': set()},
|
||||
config_call_dict={},
|
||||
refs=[["foo", "bar"], ["foo", "bar"]],
|
||||
sources=[['salesforce', 'in_process']],
|
||||
)
|
||||
assertEqualNodes(node, expected)
|
||||
file_id = 'snowplow://' + normalize('models/nested/prql_model.prql')
|
||||
self.assertIn(file_id, self.parser.manifest.files)
|
||||
self.assertEqual(self.parser.manifest.files[file_id].nodes, ['model.snowplow.prql_model'])
|
||||
|
||||
def test_parse_ref_with_non_string(self):
|
||||
py_code = """
|
||||
def model(dbt, session):
|
||||
|
||||
Reference in New Issue
Block a user