Compare commits

...

1 Commits

Author SHA1 Message Date
Jeremy Cohen
f13f92d280 Store adapter on RuntimeConfig, rather than global 2022-09-20 13:02:49 +02:00
17 changed files with 72 additions and 52 deletions

View File

@@ -171,7 +171,7 @@ def register_adapter(config: AdapterRequiredConfig) -> None:
FACTORY.register_adapter(config)
def get_adapter(config: AdapterRequiredConfig):
def get_or_create_adapter(config: AdapterRequiredConfig):
return FACTORY.lookup_adapter(config.credentials.type)
@@ -217,3 +217,16 @@ def get_adapter_package_names(name: Optional[str]) -> List[str]:
def get_adapter_type_names(name: Optional[str]) -> List[str]:
return FACTORY.get_adapter_type_names(name)
# this doesn't muck with global adapters
# it does use AdapterFactory.plugins (which is global)
# to instantiate a new Adapter object, using the AdapterPlugin object
# so accessing AdapterFactory.plugins requires a lock
# but this method does not actually mutate AdapterFactory
def create_adapter(config) -> Adapter:
adapter_name = config.credentials.type
Plugin = FACTORY.get_adapter_plugins(adapter_name)[0]
# TODO is this necessary?
import copy
adapter = copy.deepcopy(Plugin.adapter)
return adapter(config)

View File

@@ -7,7 +7,6 @@ import pickle
import sqlparse
from dbt import flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context
@@ -190,14 +189,14 @@ class Compiler:
return context
def add_ephemeral_prefix(self, name: str):
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
relation_cls = adapter.Relation
return relation_cls.add_ephemeral_prefix(name)
def _get_relation_name(self, node: ParsedNode):
relation_name = None
if node.is_relational and not node.is_ephemeral_model:
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
relation_cls = adapter.Relation
relation_name = str(relation_cls.create_from(self.config, node))
return relation_name

View File

@@ -10,7 +10,7 @@ from .project import Project
from .renderer import DbtProjectYamlRenderer, ProfileRenderer
from .utils import parse_cli_vars
from dbt import flags
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths, create_adapter, Adapter
from dbt.helper_types import FQNPath, PathSet, DictDefaultEmptyStr
from dbt.config.profile import read_user_config
from dbt.contracts.connection import AdapterRequiredConfig, Credentials
@@ -47,6 +47,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
profile_name: str
cli_vars: Dict[str, Any]
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None
adapter: Optional[Adapter] = None
def __post_init__(self):
self.validate()
@@ -373,6 +374,20 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
if path.is_dir() and not path.name.startswith("__"):
yield path
@classmethod
def get_or_create_adapter(self, config, force_recreate=False) -> Adapter:
if self.adapter is None or force_recreate:
self.adapter = create_adapter(config)
return self.adapter
# mutates this object
@classmethod
def update_credentials(self, credential_updates: dict):
for k,v in credential_updates.items():
self.credentials[k] = v
config.get_or_create_adapter(config, force=True)
class UnsetCredentials(Credentials):
def __init__(self):

View File

@@ -15,7 +15,7 @@ from typing import (
from typing_extensions import Protocol
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.adapters.factory import get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.config import RuntimeConfig, Project
@@ -694,7 +694,7 @@ class ProviderContext(ManifestContext):
self.sql_results: Dict[str, AttrDict] = {}
self.context_config: Optional[ContextConfig] = context_config
self.provider: Provider = provider
self.adapter = get_adapter(self.config)
self.adapter = self.config.get_or_create_adapter(self.config)
# The macro namespace is used in creating the DatabaseWrapper
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)

View File

@@ -20,22 +20,33 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
profile = args.profile if hasattr(args, "profile") else None
target = args.target if hasattr(args, "target") else None
# Construct a phony config
config = RuntimeConfig.from_args(
# TODO: do this without reading files
config = RuntimeConfig.from_parts(
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
)
# Clear previously registered adapters--
# this fixes cacheing behavior on the dbt-server
# we just want the side-effect here of caching the adapter on the RuntimeConfig
adapter = config.get_or_create_adapter(config)
# TODO use new cli.flags.Flags here to de-globalize
flags.set_from_args(args, config)
dbt.adapters.factory.reset_adapters()
# Load the relevant adapter
dbt.adapters.factory.register_adapter(config)
# Set invocation id
# TODO: make this not global. We could store this in Flags or RuntimeConfig instead
dbt.events.functions.set_invocation_id()
return config
# this has side effects, but only for this instance of RuntimeConfig
# zero global side effects
def reload_adapter(config, new_credentials):
if new_credentials:
config.update_credentials(new_credentials)
adapter = config.get_or_create_adapter(config, force_reload=True)
return adapter
def get_task_by_type(type):
# TODO: we need to tell dbt-server what tasks are available
from dbt.task.run import RunTask
@@ -82,7 +93,6 @@ def create_task(type, args, manifest, config):
def _get_operation_node(manifest, project_path, sql):
from dbt.parser.manifest import process_node
from dbt.parser.sql import SqlBlockParser
import dbt.adapters.factory
config = get_dbt_config(project_path)
block_parser = SqlBlockParser(
@@ -91,7 +101,7 @@ def _get_operation_node(manifest, project_path, sql):
root_project=config,
)
adapter = dbt.adapters.factory.get_adapter(config)
adapter = config.get_or_create_adapter(config)
# TODO : This needs a real name?
sql_node = block_parser.parse_remote(sql, "name")
process_node(config, manifest, sql_node)

View File

@@ -11,7 +11,6 @@ from dbt.context.providers import (
generate_parser_model_context,
generate_generate_name_macro_context,
)
from dbt.adapters.factory import get_adapter # noqa: F401
from dbt.clients.jinja import get_rendered
from dbt.config import Project, RuntimeConfig
from dbt.context.context_config import ContextConfig

View File

@@ -12,7 +12,6 @@ import dbt.tracking
import dbt.flags as flags
from dbt.adapters.factory import (
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
)
@@ -202,9 +201,7 @@ class ManifestLoader:
reset: bool = False,
) -> Manifest:
adapter = get_adapter(config) # type: ignore
# reset is set in a TaskManager load_manifest call, since
# the config and adapter may be persistent.
adapter = config.get_or_create_adapter(config) # type: ignore
if reset:
config.clear_dependencies()
adapter.clear_macro_manifest()
@@ -520,7 +517,7 @@ class ManifestLoader:
def macro_depends_on(self):
macro_ctx = generate_macro_context(self.root_project)
macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), [])
adapter = get_adapter(self.root_project)
adapter = self.root_project.get_or_create_adapter(self.root_project)
db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace)
for macro in self.manifest.macros.values():
if macro.created_at < self.started_at:

View File

@@ -8,7 +8,7 @@ from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar,
from dbt.dataclass_schema import ValidationError, dbtClassMixin
from dbt.adapters.factory import get_adapter, get_adapter_package_names
from dbt.adapters.factory import get_adapter_package_names
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
from dbt.clients.yaml_helper import load_yaml_text
from dbt.parser.schema_renderer import SchemaYamlRenderer
@@ -473,7 +473,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, ParsedGenericTestNode]):
column_name = column.name
should_quote = column.quote or (column.quote is None and target_block.quote_columns)
if should_quote:
column_name = get_adapter(self.root_project).quote(column_name)
column_name = self.root_project.get_or_create_adapter(self.root_project).quote(column_name)
column_tags = column.tags
block = GenericTestBlock.from_test_block(

View File

@@ -1,7 +1,6 @@
import itertools
from pathlib import Path
from typing import Iterable, Dict, Optional, Set, Any
from dbt.adapters.factory import get_adapter
from dbt.config import RuntimeConfig
from dbt.context.context_config import (
BaseContextConfigGenerator,
@@ -240,7 +239,7 @@ class SourcePatcher:
column_name = column.name
should_quote = column.quote or (column.quote is None and target.quote_columns)
if should_quote:
column_name = get_adapter(self.root_project).quote(column_name)
column_name = self.root_project.get_or_create_adapter(self.root_project).quote(column_name)
tags_sources = [target.source.tags, target.table.tags]
if column is not None:
@@ -285,7 +284,7 @@ class SourcePatcher:
)
def _get_relation_name(self, node: ParsedSourceDefinition):
adapter = get_adapter(self.root_project)
adapter = get_or_create_adapter(self.root_project)
relation_cls = adapter.Relation
return str(relation_cls.create_from(self.root_project, node))

View File

@@ -175,10 +175,6 @@ def move_to_nearest_project_dir(args):
class ConfiguredTask(BaseTask):
ConfigType = RuntimeConfig
def __init__(self, args, config):
super().__init__(args, config)
register_adapter(self.config)
@classmethod
def from_args(cls, args):
move_to_nearest_project_dir(args)

View File

@@ -3,7 +3,6 @@ from .snapshot import SnapshotRunner as snapshot_model_runner
from .seed import SeedRunner as seed_runner
from .test import TestRunner as test_runner
from dbt.adapters.factory import get_adapter
from dbt.contracts.results import NodeStatus
from dbt.exceptions import InternalException
from dbt.graph import ResourceTypeSelector
@@ -67,6 +66,6 @@ class BuildTask(RunTask):
def compile_manifest(self):
if self.manifest is None:
raise InternalException("compile_manifest called before manifest was loaded")
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest, add_test_edges=True)

View File

@@ -9,7 +9,7 @@ from dbt.events.types import OpenCommand
from dbt import flags
import dbt.clients.system
import dbt.exceptions
from dbt.adapters.factory import get_adapter, register_adapter
from dbt.adapters.factory import create_adapter
from dbt.config import Project, Profile
from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer
from dbt.config.utils import parse_cli_vars
@@ -328,8 +328,7 @@ class DebugTask(BaseTask):
"""Return a string containing the error message, or None if there was
no error.
"""
register_adapter(profile)
adapter = get_adapter(profile)
adapter = create_adapter(profile)
try:
with adapter.connection_named("debug"):
adapter.debug_query()

View File

@@ -7,7 +7,6 @@ from dbt.dataclass_schema import ValidationError
from .compile import CompileTask
from dbt.adapters.factory import get_adapter
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import (
@@ -234,7 +233,7 @@ class GenerateTask(CompileTask):
if self.manifest is None:
raise InternalException("self.manifest was None in run!")
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
with adapter.connection_named("generate_catalog"):
fire_event(BuildingCatalog())
catalog_table, exceptions = adapter.get_catalog(self.manifest)

View File

@@ -7,7 +7,6 @@
# Use a visualizer such as snakeviz to look at the output:
# snakeviz dbt.cprof
from dbt.task.base import ConfiguredTask
from dbt.adapters.factory import get_adapter
from dbt.parser.manifest import Manifest, ManifestLoader, _check_manifest
from dbt.logger import DbtProcessState
from dbt.clients.system import write_file
@@ -62,7 +61,7 @@ class ParseTask(ConfiguredTask):
# ManifestLoader.load_all
def get_full_manifest(self):
adapter = get_adapter(self.config) # type: ignore
adapter = self.config.get_or_create_adapter(self.config) # type: ignore
root_config = self.config
macro_hook = adapter.connections.set_query_header
with PARSING_STATE:
@@ -84,7 +83,7 @@ class ParseTask(ConfiguredTask):
fire_event(ManifestLoaded())
def compile_manifest(self):
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest)

View File

@@ -6,7 +6,6 @@ import agate
from .runnable import ManifestTask
import dbt.exceptions
from dbt.adapters.factory import get_adapter
from dbt.config.utils import parse_cli_vars
from dbt.contracts.results import RunOperationResultsArtifact
from dbt.exceptions import InternalException
@@ -36,7 +35,7 @@ class RunOperationTask(ManifestTask):
raise InternalException("manifest was None in compile_manifest")
def _run_unsafe(self) -> agate.Table:
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
package_name, macro_name = self._get_macro_parts()
macro_kwargs = self._get_kwargs()

View File

@@ -16,7 +16,6 @@ from .printer import (
from dbt.clients.system import write_file
from dbt.task.base import ConfiguredTask
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.logger import (
DbtProcessState,
TextOnly,
@@ -86,7 +85,7 @@ class ManifestTask(ConfiguredTask):
raise InternalException("compile_manifest called before manifest was loaded")
# we cannot get adapter in init since it will break rpc #5579
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest)
@@ -188,7 +187,7 @@ class GraphRunnableTask(ManifestTask):
return os.path.join(self.config.target_path, RESULT_FILE_NAME)
def get_runner(self, node):
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
run_count: int = 0
num_nodes: int = 0
@@ -350,7 +349,7 @@ class GraphRunnableTask(ManifestTask):
pool.close()
pool.terminate()
adapter = get_adapter(self.config)
adapter = self.adapter
if not adapter.is_cancelable():
fire_event(QueryCancelationUnsupported(type=adapter.type()))
@@ -427,7 +426,7 @@ class GraphRunnableTask(ManifestTask):
pass
def execute_with_hooks(self, selected_uids: AbstractSet[str]):
adapter = get_adapter(self.config)
adapter = self.config.get_or_create_adapter(self.config)
try:
self.before_hooks(adapter)
started = time.time()

View File

@@ -8,7 +8,7 @@ import yaml
import dbt.flags as flags
from dbt.config.runtime import RuntimeConfig
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters, get_adapter_by_type
from dbt.adapters.factory import get_adapter_by_type
from dbt.events.functions import setup_event_logger
from dbt.tests.util import (
write_file,
@@ -246,14 +246,12 @@ def adapter(unique_schema, project_root, profiles_root, profiles_yml, dbt_projec
)
flags.set_from_args(args, {})
runtime_config = RuntimeConfig.from_args(args)
register_adapter(runtime_config)
adapter = get_adapter(runtime_config)
adapter = runtime_config.get_or_create_adapter(runtime_config)
# We only need the base macros, not macros from dependencies, and don't want
# to run 'dbt deps' here.
adapter.load_macro_manifest(base_macros_only=True)
yield adapter
adapter.cleanup_connections()
reset_adapters()
# Start at directory level.