mirror of
https://github.com/dbt-labs/dbt-core
synced 2025-12-18 22:31:28 +00:00
Compare commits
1 Commits
enable-pos
...
jerco/expe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f13f92d280 |
@@ -171,7 +171,7 @@ def register_adapter(config: AdapterRequiredConfig) -> None:
|
||||
FACTORY.register_adapter(config)
|
||||
|
||||
|
||||
def get_adapter(config: AdapterRequiredConfig):
|
||||
def get_or_create_adapter(config: AdapterRequiredConfig):
|
||||
return FACTORY.lookup_adapter(config.credentials.type)
|
||||
|
||||
|
||||
@@ -217,3 +217,16 @@ def get_adapter_package_names(name: Optional[str]) -> List[str]:
|
||||
|
||||
def get_adapter_type_names(name: Optional[str]) -> List[str]:
|
||||
return FACTORY.get_adapter_type_names(name)
|
||||
|
||||
# this doesn't muck with global adapters
|
||||
# it does use AdapterFactory.plugins (which is global)
|
||||
# to instantiate a new Adapter object, using the AdapterPlugin object
|
||||
# so accessing AdapterFactory.plugins requires a lock
|
||||
# but this method does not actually mutate AdapterFactory
|
||||
def create_adapter(config) -> Adapter:
|
||||
adapter_name = config.credentials.type
|
||||
Plugin = FACTORY.get_adapter_plugins(adapter_name)[0]
|
||||
# TODO is this necessary?
|
||||
import copy
|
||||
adapter = copy.deepcopy(Plugin.adapter)
|
||||
return adapter(config)
|
||||
|
||||
@@ -7,7 +7,6 @@ import pickle
|
||||
import sqlparse
|
||||
|
||||
from dbt import flags
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.clients import jinja
|
||||
from dbt.clients.system import make_directory
|
||||
from dbt.context.providers import generate_runtime_model_context
|
||||
@@ -190,14 +189,14 @@ class Compiler:
|
||||
return context
|
||||
|
||||
def add_ephemeral_prefix(self, name: str):
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
relation_cls = adapter.Relation
|
||||
return relation_cls.add_ephemeral_prefix(name)
|
||||
|
||||
def _get_relation_name(self, node: ParsedNode):
|
||||
relation_name = None
|
||||
if node.is_relational and not node.is_ephemeral_model:
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
relation_cls = adapter.Relation
|
||||
relation_name = str(relation_cls.create_from(self.config, node))
|
||||
return relation_name
|
||||
|
||||
@@ -10,7 +10,7 @@ from .project import Project
|
||||
from .renderer import DbtProjectYamlRenderer, ProfileRenderer
|
||||
from .utils import parse_cli_vars
|
||||
from dbt import flags
|
||||
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
|
||||
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths, create_adapter, Adapter
|
||||
from dbt.helper_types import FQNPath, PathSet, DictDefaultEmptyStr
|
||||
from dbt.config.profile import read_user_config
|
||||
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
|
||||
@@ -47,6 +47,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
||||
profile_name: str
|
||||
cli_vars: Dict[str, Any]
|
||||
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None
|
||||
adapter: Optional[Adapter] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
@@ -373,6 +374,20 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
|
||||
if path.is_dir() and not path.name.startswith("__"):
|
||||
yield path
|
||||
|
||||
@classmethod
|
||||
def get_or_create_adapter(self, config, force_recreate=False) -> Adapter:
|
||||
if self.adapter is None or force_recreate:
|
||||
self.adapter = create_adapter(config)
|
||||
return self.adapter
|
||||
|
||||
|
||||
# mutates this object
|
||||
@classmethod
|
||||
def update_credentials(self, credential_updates: dict):
|
||||
for k,v in credential_updates.items():
|
||||
self.credentials[k] = v
|
||||
config.get_or_create_adapter(config, force=True)
|
||||
|
||||
|
||||
class UnsetCredentials(Credentials):
|
||||
def __init__(self):
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import (
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from dbt.adapters.base.column import Column
|
||||
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
|
||||
from dbt.adapters.factory import get_adapter_package_names, get_adapter_type_names
|
||||
from dbt.clients import agate_helper
|
||||
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
|
||||
from dbt.config import RuntimeConfig, Project
|
||||
@@ -694,7 +694,7 @@ class ProviderContext(ManifestContext):
|
||||
self.sql_results: Dict[str, AttrDict] = {}
|
||||
self.context_config: Optional[ContextConfig] = context_config
|
||||
self.provider: Provider = provider
|
||||
self.adapter = get_adapter(self.config)
|
||||
self.adapter = self.config.get_or_create_adapter(self.config)
|
||||
# The macro namespace is used in creating the DatabaseWrapper
|
||||
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)
|
||||
|
||||
|
||||
@@ -20,22 +20,33 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
|
||||
profile = args.profile if hasattr(args, "profile") else None
|
||||
target = args.target if hasattr(args, "target") else None
|
||||
|
||||
# Construct a phony config
|
||||
config = RuntimeConfig.from_args(
|
||||
# TODO: do this without reading files
|
||||
config = RuntimeConfig.from_parts(
|
||||
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
|
||||
)
|
||||
# Clear previously registered adapters--
|
||||
# this fixes cacheing behavior on the dbt-server
|
||||
|
||||
# we just want the side-effect here of caching the adapter on the RuntimeConfig
|
||||
adapter = config.get_or_create_adapter(config)
|
||||
|
||||
# TODO use new cli.flags.Flags here to de-globalize
|
||||
flags.set_from_args(args, config)
|
||||
dbt.adapters.factory.reset_adapters()
|
||||
# Load the relevant adapter
|
||||
dbt.adapters.factory.register_adapter(config)
|
||||
|
||||
# Set invocation id
|
||||
# TODO: make this not global. We could store this in Flags or RuntimeConfig instead
|
||||
dbt.events.functions.set_invocation_id()
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# this has side effects, but only for this instance of RuntimeConfig
|
||||
# zero global side effects
|
||||
def reload_adapter(config, new_credentials):
|
||||
if new_credentials:
|
||||
config.update_credentials(new_credentials)
|
||||
adapter = config.get_or_create_adapter(config, force_reload=True)
|
||||
return adapter
|
||||
|
||||
|
||||
def get_task_by_type(type):
|
||||
# TODO: we need to tell dbt-server what tasks are available
|
||||
from dbt.task.run import RunTask
|
||||
@@ -82,7 +93,6 @@ def create_task(type, args, manifest, config):
|
||||
def _get_operation_node(manifest, project_path, sql):
|
||||
from dbt.parser.manifest import process_node
|
||||
from dbt.parser.sql import SqlBlockParser
|
||||
import dbt.adapters.factory
|
||||
|
||||
config = get_dbt_config(project_path)
|
||||
block_parser = SqlBlockParser(
|
||||
@@ -91,7 +101,7 @@ def _get_operation_node(manifest, project_path, sql):
|
||||
root_project=config,
|
||||
)
|
||||
|
||||
adapter = dbt.adapters.factory.get_adapter(config)
|
||||
adapter = config.get_or_create_adapter(config)
|
||||
# TODO : This needs a real name?
|
||||
sql_node = block_parser.parse_remote(sql, "name")
|
||||
process_node(config, manifest, sql_node)
|
||||
|
||||
@@ -11,7 +11,6 @@ from dbt.context.providers import (
|
||||
generate_parser_model_context,
|
||||
generate_generate_name_macro_context,
|
||||
)
|
||||
from dbt.adapters.factory import get_adapter # noqa: F401
|
||||
from dbt.clients.jinja import get_rendered
|
||||
from dbt.config import Project, RuntimeConfig
|
||||
from dbt.context.context_config import ContextConfig
|
||||
|
||||
@@ -12,7 +12,6 @@ import dbt.tracking
|
||||
import dbt.flags as flags
|
||||
|
||||
from dbt.adapters.factory import (
|
||||
get_adapter,
|
||||
get_relation_class_by_name,
|
||||
get_adapter_package_names,
|
||||
)
|
||||
@@ -202,9 +201,7 @@ class ManifestLoader:
|
||||
reset: bool = False,
|
||||
) -> Manifest:
|
||||
|
||||
adapter = get_adapter(config) # type: ignore
|
||||
# reset is set in a TaskManager load_manifest call, since
|
||||
# the config and adapter may be persistent.
|
||||
adapter = config.get_or_create_adapter(config) # type: ignore
|
||||
if reset:
|
||||
config.clear_dependencies()
|
||||
adapter.clear_macro_manifest()
|
||||
@@ -520,7 +517,7 @@ class ManifestLoader:
|
||||
def macro_depends_on(self):
|
||||
macro_ctx = generate_macro_context(self.root_project)
|
||||
macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), [])
|
||||
adapter = get_adapter(self.root_project)
|
||||
adapter = self.root_project.get_or_create_adapter(self.root_project)
|
||||
db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace)
|
||||
for macro in self.manifest.macros.values():
|
||||
if macro.created_at < self.started_at:
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar,
|
||||
|
||||
from dbt.dataclass_schema import ValidationError, dbtClassMixin
|
||||
|
||||
from dbt.adapters.factory import get_adapter, get_adapter_package_names
|
||||
from dbt.adapters.factory import get_adapter_package_names
|
||||
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
|
||||
from dbt.clients.yaml_helper import load_yaml_text
|
||||
from dbt.parser.schema_renderer import SchemaYamlRenderer
|
||||
@@ -473,7 +473,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, ParsedGenericTestNode]):
|
||||
column_name = column.name
|
||||
should_quote = column.quote or (column.quote is None and target_block.quote_columns)
|
||||
if should_quote:
|
||||
column_name = get_adapter(self.root_project).quote(column_name)
|
||||
column_name = self.root_project.get_or_create_adapter(self.root_project).quote(column_name)
|
||||
column_tags = column.tags
|
||||
|
||||
block = GenericTestBlock.from_test_block(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Dict, Optional, Set, Any
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.config import RuntimeConfig
|
||||
from dbt.context.context_config import (
|
||||
BaseContextConfigGenerator,
|
||||
@@ -240,7 +239,7 @@ class SourcePatcher:
|
||||
column_name = column.name
|
||||
should_quote = column.quote or (column.quote is None and target.quote_columns)
|
||||
if should_quote:
|
||||
column_name = get_adapter(self.root_project).quote(column_name)
|
||||
column_name = self.root_project.get_or_create_adapter(self.root_project).quote(column_name)
|
||||
|
||||
tags_sources = [target.source.tags, target.table.tags]
|
||||
if column is not None:
|
||||
@@ -285,7 +284,7 @@ class SourcePatcher:
|
||||
)
|
||||
|
||||
def _get_relation_name(self, node: ParsedSourceDefinition):
|
||||
adapter = get_adapter(self.root_project)
|
||||
adapter = get_or_create_adapter(self.root_project)
|
||||
relation_cls = adapter.Relation
|
||||
return str(relation_cls.create_from(self.root_project, node))
|
||||
|
||||
|
||||
@@ -175,10 +175,6 @@ def move_to_nearest_project_dir(args):
|
||||
class ConfiguredTask(BaseTask):
|
||||
ConfigType = RuntimeConfig
|
||||
|
||||
def __init__(self, args, config):
|
||||
super().__init__(args, config)
|
||||
register_adapter(self.config)
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args):
|
||||
move_to_nearest_project_dir(args)
|
||||
|
||||
@@ -3,7 +3,6 @@ from .snapshot import SnapshotRunner as snapshot_model_runner
|
||||
from .seed import SeedRunner as seed_runner
|
||||
from .test import TestRunner as test_runner
|
||||
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.contracts.results import NodeStatus
|
||||
from dbt.exceptions import InternalException
|
||||
from dbt.graph import ResourceTypeSelector
|
||||
@@ -67,6 +66,6 @@ class BuildTask(RunTask):
|
||||
def compile_manifest(self):
|
||||
if self.manifest is None:
|
||||
raise InternalException("compile_manifest called before manifest was loaded")
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
compiler = adapter.get_compiler()
|
||||
self.graph = compiler.compile(self.manifest, add_test_edges=True)
|
||||
|
||||
@@ -9,7 +9,7 @@ from dbt.events.types import OpenCommand
|
||||
from dbt import flags
|
||||
import dbt.clients.system
|
||||
import dbt.exceptions
|
||||
from dbt.adapters.factory import get_adapter, register_adapter
|
||||
from dbt.adapters.factory import create_adapter
|
||||
from dbt.config import Project, Profile
|
||||
from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer
|
||||
from dbt.config.utils import parse_cli_vars
|
||||
@@ -328,8 +328,7 @@ class DebugTask(BaseTask):
|
||||
"""Return a string containing the error message, or None if there was
|
||||
no error.
|
||||
"""
|
||||
register_adapter(profile)
|
||||
adapter = get_adapter(profile)
|
||||
adapter = create_adapter(profile)
|
||||
try:
|
||||
with adapter.connection_named("debug"):
|
||||
adapter.debug_query()
|
||||
|
||||
@@ -7,7 +7,6 @@ from dbt.dataclass_schema import ValidationError
|
||||
|
||||
from .compile import CompileTask
|
||||
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.contracts.graph.compiled import CompileResultNode
|
||||
from dbt.contracts.graph.manifest import Manifest
|
||||
from dbt.contracts.results import (
|
||||
@@ -234,7 +233,7 @@ class GenerateTask(CompileTask):
|
||||
if self.manifest is None:
|
||||
raise InternalException("self.manifest was None in run!")
|
||||
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
with adapter.connection_named("generate_catalog"):
|
||||
fire_event(BuildingCatalog())
|
||||
catalog_table, exceptions = adapter.get_catalog(self.manifest)
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
# Use a visualizer such as snakeviz to look at the output:
|
||||
# snakeviz dbt.cprof
|
||||
from dbt.task.base import ConfiguredTask
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.parser.manifest import Manifest, ManifestLoader, _check_manifest
|
||||
from dbt.logger import DbtProcessState
|
||||
from dbt.clients.system import write_file
|
||||
@@ -62,7 +61,7 @@ class ParseTask(ConfiguredTask):
|
||||
# ManifestLoader.load_all
|
||||
|
||||
def get_full_manifest(self):
|
||||
adapter = get_adapter(self.config) # type: ignore
|
||||
adapter = self.config.get_or_create_adapter(self.config) # type: ignore
|
||||
root_config = self.config
|
||||
macro_hook = adapter.connections.set_query_header
|
||||
with PARSING_STATE:
|
||||
@@ -84,7 +83,7 @@ class ParseTask(ConfiguredTask):
|
||||
fire_event(ManifestLoaded())
|
||||
|
||||
def compile_manifest(self):
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
compiler = adapter.get_compiler()
|
||||
self.graph = compiler.compile(self.manifest)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import agate
|
||||
from .runnable import ManifestTask
|
||||
|
||||
import dbt.exceptions
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.config.utils import parse_cli_vars
|
||||
from dbt.contracts.results import RunOperationResultsArtifact
|
||||
from dbt.exceptions import InternalException
|
||||
@@ -36,7 +35,7 @@ class RunOperationTask(ManifestTask):
|
||||
raise InternalException("manifest was None in compile_manifest")
|
||||
|
||||
def _run_unsafe(self) -> agate.Table:
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
|
||||
package_name, macro_name = self._get_macro_parts()
|
||||
macro_kwargs = self._get_kwargs()
|
||||
|
||||
@@ -16,7 +16,6 @@ from .printer import (
|
||||
from dbt.clients.system import write_file
|
||||
from dbt.task.base import ConfiguredTask
|
||||
from dbt.adapters.base import BaseRelation
|
||||
from dbt.adapters.factory import get_adapter
|
||||
from dbt.logger import (
|
||||
DbtProcessState,
|
||||
TextOnly,
|
||||
@@ -86,7 +85,7 @@ class ManifestTask(ConfiguredTask):
|
||||
raise InternalException("compile_manifest called before manifest was loaded")
|
||||
|
||||
# we cannot get adapter in init since it will break rpc #5579
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
compiler = adapter.get_compiler()
|
||||
self.graph = compiler.compile(self.manifest)
|
||||
|
||||
@@ -188,7 +187,7 @@ class GraphRunnableTask(ManifestTask):
|
||||
return os.path.join(self.config.target_path, RESULT_FILE_NAME)
|
||||
|
||||
def get_runner(self, node):
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
run_count: int = 0
|
||||
num_nodes: int = 0
|
||||
|
||||
@@ -350,7 +349,7 @@ class GraphRunnableTask(ManifestTask):
|
||||
pool.close()
|
||||
pool.terminate()
|
||||
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.adapter
|
||||
|
||||
if not adapter.is_cancelable():
|
||||
fire_event(QueryCancelationUnsupported(type=adapter.type()))
|
||||
@@ -427,7 +426,7 @@ class GraphRunnableTask(ManifestTask):
|
||||
pass
|
||||
|
||||
def execute_with_hooks(self, selected_uids: AbstractSet[str]):
|
||||
adapter = get_adapter(self.config)
|
||||
adapter = self.config.get_or_create_adapter(self.config)
|
||||
try:
|
||||
self.before_hooks(adapter)
|
||||
started = time.time()
|
||||
|
||||
6
core/dbt/tests/fixtures/project.py
vendored
6
core/dbt/tests/fixtures/project.py
vendored
@@ -8,7 +8,7 @@ import yaml
|
||||
|
||||
import dbt.flags as flags
|
||||
from dbt.config.runtime import RuntimeConfig
|
||||
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters, get_adapter_by_type
|
||||
from dbt.adapters.factory import get_adapter_by_type
|
||||
from dbt.events.functions import setup_event_logger
|
||||
from dbt.tests.util import (
|
||||
write_file,
|
||||
@@ -246,14 +246,12 @@ def adapter(unique_schema, project_root, profiles_root, profiles_yml, dbt_projec
|
||||
)
|
||||
flags.set_from_args(args, {})
|
||||
runtime_config = RuntimeConfig.from_args(args)
|
||||
register_adapter(runtime_config)
|
||||
adapter = get_adapter(runtime_config)
|
||||
adapter = runtime_config.get_or_create_adapter(runtime_config)
|
||||
# We only need the base macros, not macros from dependencies, and don't want
|
||||
# to run 'dbt deps' here.
|
||||
adapter.load_macro_manifest(base_macros_only=True)
|
||||
yield adapter
|
||||
adapter.cleanup_connections()
|
||||
reset_adapters()
|
||||
|
||||
|
||||
# Start at directory level.
|
||||
|
||||
Reference in New Issue
Block a user