Compare commits

...

1 Commits

Author SHA1 Message Date
Jeremy Cohen
f13f92d280 Store adapter on RuntimeConfig, rather than global 2022-09-20 13:02:49 +02:00
17 changed files with 72 additions and 52 deletions

View File

@@ -171,7 +171,7 @@ def register_adapter(config: AdapterRequiredConfig) -> None:
FACTORY.register_adapter(config) FACTORY.register_adapter(config)
def get_adapter(config: AdapterRequiredConfig): def get_or_create_adapter(config: AdapterRequiredConfig):
return FACTORY.lookup_adapter(config.credentials.type) return FACTORY.lookup_adapter(config.credentials.type)
@@ -217,3 +217,16 @@ def get_adapter_package_names(name: Optional[str]) -> List[str]:
def get_adapter_type_names(name: Optional[str]) -> List[str]: def get_adapter_type_names(name: Optional[str]) -> List[str]:
return FACTORY.get_adapter_type_names(name) return FACTORY.get_adapter_type_names(name)
# this doesn't muck with global adapters
# it does use AdapterFactory.plugins (which is global)
# to instantiate a new Adapter object, using the AdapterPlugin object
# so accessing AdapterFactory.plugins requires a lock
# but this method does not actually mutate AdapterFactory
def create_adapter(config) -> Adapter:
adapter_name = config.credentials.type
Plugin = FACTORY.get_adapter_plugins(adapter_name)[0]
# TODO is this necessary?
import copy
adapter = copy.deepcopy(Plugin.adapter)
return adapter(config)

View File

@@ -7,7 +7,6 @@ import pickle
import sqlparse import sqlparse
from dbt import flags from dbt import flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja from dbt.clients import jinja
from dbt.clients.system import make_directory from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context from dbt.context.providers import generate_runtime_model_context
@@ -190,14 +189,14 @@ class Compiler:
return context return context
def add_ephemeral_prefix(self, name: str): def add_ephemeral_prefix(self, name: str):
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
relation_cls = adapter.Relation relation_cls = adapter.Relation
return relation_cls.add_ephemeral_prefix(name) return relation_cls.add_ephemeral_prefix(name)
def _get_relation_name(self, node: ParsedNode): def _get_relation_name(self, node: ParsedNode):
relation_name = None relation_name = None
if node.is_relational and not node.is_ephemeral_model: if node.is_relational and not node.is_ephemeral_model:
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
relation_cls = adapter.Relation relation_cls = adapter.Relation
relation_name = str(relation_cls.create_from(self.config, node)) relation_name = str(relation_cls.create_from(self.config, node))
return relation_name return relation_name

View File

@@ -10,7 +10,7 @@ from .project import Project
from .renderer import DbtProjectYamlRenderer, ProfileRenderer from .renderer import DbtProjectYamlRenderer, ProfileRenderer
from .utils import parse_cli_vars from .utils import parse_cli_vars
from dbt import flags from dbt import flags
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths from dbt.adapters.factory import get_relation_class_by_name, get_include_paths, create_adapter, Adapter
from dbt.helper_types import FQNPath, PathSet, DictDefaultEmptyStr from dbt.helper_types import FQNPath, PathSet, DictDefaultEmptyStr
from dbt.config.profile import read_user_config from dbt.config.profile import read_user_config
from dbt.contracts.connection import AdapterRequiredConfig, Credentials from dbt.contracts.connection import AdapterRequiredConfig, Credentials
@@ -47,6 +47,7 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
profile_name: str profile_name: str
cli_vars: Dict[str, Any] cli_vars: Dict[str, Any]
dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None dependencies: Optional[Mapping[str, "RuntimeConfig"]] = None
adapter: Optional[Adapter] = None
def __post_init__(self): def __post_init__(self):
self.validate() self.validate()
@@ -373,6 +374,20 @@ class RuntimeConfig(Project, Profile, AdapterRequiredConfig):
if path.is_dir() and not path.name.startswith("__"): if path.is_dir() and not path.name.startswith("__"):
yield path yield path
@classmethod
def get_or_create_adapter(self, config, force_recreate=False) -> Adapter:
if self.adapter is None or force_recreate:
self.adapter = create_adapter(config)
return self.adapter
# mutates this object
@classmethod
def update_credentials(self, credential_updates: dict):
for k,v in credential_updates.items():
self.credentials[k] = v
config.get_or_create_adapter(config, force=True)
class UnsetCredentials(Credentials): class UnsetCredentials(Credentials):
def __init__(self): def __init__(self):

View File

@@ -15,7 +15,7 @@ from typing import (
from typing_extensions import Protocol from typing_extensions import Protocol
from dbt.adapters.base.column import Column from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names from dbt.adapters.factory import get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.config import RuntimeConfig, Project from dbt.config import RuntimeConfig, Project
@@ -694,7 +694,7 @@ class ProviderContext(ManifestContext):
self.sql_results: Dict[str, AttrDict] = {} self.sql_results: Dict[str, AttrDict] = {}
self.context_config: Optional[ContextConfig] = context_config self.context_config: Optional[ContextConfig] = context_config
self.provider: Provider = provider self.provider: Provider = provider
self.adapter = get_adapter(self.config) self.adapter = self.config.get_or_create_adapter(self.config)
# The macro namespace is used in creating the DatabaseWrapper # The macro namespace is used in creating the DatabaseWrapper
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace) self.db_wrapper = self.provider.DatabaseWrapper(self.adapter, self.namespace)

View File

@@ -20,22 +20,33 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
profile = args.profile if hasattr(args, "profile") else None profile = args.profile if hasattr(args, "profile") else None
target = args.target if hasattr(args, "target") else None target = args.target if hasattr(args, "target") else None
# Construct a phony config # TODO: do this without reading files
config = RuntimeConfig.from_args( config = RuntimeConfig.from_parts(
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target) RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
) )
# Clear previously registered adapters--
# this fixes cacheing behavior on the dbt-server # we just want the side-effect here of caching the adapter on the RuntimeConfig
adapter = config.get_or_create_adapter(config)
# TODO use new cli.flags.Flags here to de-globalize
flags.set_from_args(args, config) flags.set_from_args(args, config)
dbt.adapters.factory.reset_adapters()
# Load the relevant adapter
dbt.adapters.factory.register_adapter(config)
# Set invocation id # Set invocation id
# TODO: make this not global. We could store this in Flags or RuntimeConfig instead
dbt.events.functions.set_invocation_id() dbt.events.functions.set_invocation_id()
return config return config
# this has side effects, but only for this instance of RuntimeConfig
# zero global side effects
def reload_adapter(config, new_credentials):
if new_credentials:
config.update_credentials(new_credentials)
adapter = config.get_or_create_adapter(config, force_reload=True)
return adapter
def get_task_by_type(type): def get_task_by_type(type):
# TODO: we need to tell dbt-server what tasks are available # TODO: we need to tell dbt-server what tasks are available
from dbt.task.run import RunTask from dbt.task.run import RunTask
@@ -82,7 +93,6 @@ def create_task(type, args, manifest, config):
def _get_operation_node(manifest, project_path, sql): def _get_operation_node(manifest, project_path, sql):
from dbt.parser.manifest import process_node from dbt.parser.manifest import process_node
from dbt.parser.sql import SqlBlockParser from dbt.parser.sql import SqlBlockParser
import dbt.adapters.factory
config = get_dbt_config(project_path) config = get_dbt_config(project_path)
block_parser = SqlBlockParser( block_parser = SqlBlockParser(
@@ -91,7 +101,7 @@ def _get_operation_node(manifest, project_path, sql):
root_project=config, root_project=config,
) )
adapter = dbt.adapters.factory.get_adapter(config) adapter = config.get_or_create_adapter(config)
# TODO : This needs a real name? # TODO : This needs a real name?
sql_node = block_parser.parse_remote(sql, "name") sql_node = block_parser.parse_remote(sql, "name")
process_node(config, manifest, sql_node) process_node(config, manifest, sql_node)

View File

@@ -11,7 +11,6 @@ from dbt.context.providers import (
generate_parser_model_context, generate_parser_model_context,
generate_generate_name_macro_context, generate_generate_name_macro_context,
) )
from dbt.adapters.factory import get_adapter # noqa: F401
from dbt.clients.jinja import get_rendered from dbt.clients.jinja import get_rendered
from dbt.config import Project, RuntimeConfig from dbt.config import Project, RuntimeConfig
from dbt.context.context_config import ContextConfig from dbt.context.context_config import ContextConfig

View File

@@ -12,7 +12,6 @@ import dbt.tracking
import dbt.flags as flags import dbt.flags as flags
from dbt.adapters.factory import ( from dbt.adapters.factory import (
get_adapter,
get_relation_class_by_name, get_relation_class_by_name,
get_adapter_package_names, get_adapter_package_names,
) )
@@ -202,9 +201,7 @@ class ManifestLoader:
reset: bool = False, reset: bool = False,
) -> Manifest: ) -> Manifest:
adapter = get_adapter(config) # type: ignore adapter = config.get_or_create_adapter(config) # type: ignore
# reset is set in a TaskManager load_manifest call, since
# the config and adapter may be persistent.
if reset: if reset:
config.clear_dependencies() config.clear_dependencies()
adapter.clear_macro_manifest() adapter.clear_macro_manifest()
@@ -520,7 +517,7 @@ class ManifestLoader:
def macro_depends_on(self): def macro_depends_on(self):
macro_ctx = generate_macro_context(self.root_project) macro_ctx = generate_macro_context(self.root_project)
macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), []) macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), [])
adapter = get_adapter(self.root_project) adapter = self.root_project.get_or_create_adapter(self.root_project)
db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace) db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace)
for macro in self.manifest.macros.values(): for macro in self.manifest.macros.values():
if macro.created_at < self.started_at: if macro.created_at < self.started_at:

View File

@@ -8,7 +8,7 @@ from typing import Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar,
from dbt.dataclass_schema import ValidationError, dbtClassMixin from dbt.dataclass_schema import ValidationError, dbtClassMixin
from dbt.adapters.factory import get_adapter, get_adapter_package_names from dbt.adapters.factory import get_adapter_package_names
from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs
from dbt.clients.yaml_helper import load_yaml_text from dbt.clients.yaml_helper import load_yaml_text
from dbt.parser.schema_renderer import SchemaYamlRenderer from dbt.parser.schema_renderer import SchemaYamlRenderer
@@ -473,7 +473,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, ParsedGenericTestNode]):
column_name = column.name column_name = column.name
should_quote = column.quote or (column.quote is None and target_block.quote_columns) should_quote = column.quote or (column.quote is None and target_block.quote_columns)
if should_quote: if should_quote:
column_name = get_adapter(self.root_project).quote(column_name) column_name = self.root_project.get_or_create_adapter(self.root_project).quote(column_name)
column_tags = column.tags column_tags = column.tags
block = GenericTestBlock.from_test_block( block = GenericTestBlock.from_test_block(

View File

@@ -1,7 +1,6 @@
import itertools import itertools
from pathlib import Path from pathlib import Path
from typing import Iterable, Dict, Optional, Set, Any from typing import Iterable, Dict, Optional, Set, Any
from dbt.adapters.factory import get_adapter
from dbt.config import RuntimeConfig from dbt.config import RuntimeConfig
from dbt.context.context_config import ( from dbt.context.context_config import (
BaseContextConfigGenerator, BaseContextConfigGenerator,
@@ -240,7 +239,7 @@ class SourcePatcher:
column_name = column.name column_name = column.name
should_quote = column.quote or (column.quote is None and target.quote_columns) should_quote = column.quote or (column.quote is None and target.quote_columns)
if should_quote: if should_quote:
column_name = get_adapter(self.root_project).quote(column_name) column_name = self.root_project.get_or_create_adapter(self.root_project).quote(column_name)
tags_sources = [target.source.tags, target.table.tags] tags_sources = [target.source.tags, target.table.tags]
if column is not None: if column is not None:
@@ -285,7 +284,7 @@ class SourcePatcher:
) )
def _get_relation_name(self, node: ParsedSourceDefinition): def _get_relation_name(self, node: ParsedSourceDefinition):
adapter = get_adapter(self.root_project) adapter = get_or_create_adapter(self.root_project)
relation_cls = adapter.Relation relation_cls = adapter.Relation
return str(relation_cls.create_from(self.root_project, node)) return str(relation_cls.create_from(self.root_project, node))

View File

@@ -175,10 +175,6 @@ def move_to_nearest_project_dir(args):
class ConfiguredTask(BaseTask): class ConfiguredTask(BaseTask):
ConfigType = RuntimeConfig ConfigType = RuntimeConfig
def __init__(self, args, config):
super().__init__(args, config)
register_adapter(self.config)
@classmethod @classmethod
def from_args(cls, args): def from_args(cls, args):
move_to_nearest_project_dir(args) move_to_nearest_project_dir(args)

View File

@@ -3,7 +3,6 @@ from .snapshot import SnapshotRunner as snapshot_model_runner
from .seed import SeedRunner as seed_runner from .seed import SeedRunner as seed_runner
from .test import TestRunner as test_runner from .test import TestRunner as test_runner
from dbt.adapters.factory import get_adapter
from dbt.contracts.results import NodeStatus from dbt.contracts.results import NodeStatus
from dbt.exceptions import InternalException from dbt.exceptions import InternalException
from dbt.graph import ResourceTypeSelector from dbt.graph import ResourceTypeSelector
@@ -67,6 +66,6 @@ class BuildTask(RunTask):
def compile_manifest(self): def compile_manifest(self):
if self.manifest is None: if self.manifest is None:
raise InternalException("compile_manifest called before manifest was loaded") raise InternalException("compile_manifest called before manifest was loaded")
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
compiler = adapter.get_compiler() compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest, add_test_edges=True) self.graph = compiler.compile(self.manifest, add_test_edges=True)

View File

@@ -9,7 +9,7 @@ from dbt.events.types import OpenCommand
from dbt import flags from dbt import flags
import dbt.clients.system import dbt.clients.system
import dbt.exceptions import dbt.exceptions
from dbt.adapters.factory import get_adapter, register_adapter from dbt.adapters.factory import create_adapter
from dbt.config import Project, Profile from dbt.config import Project, Profile
from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer
from dbt.config.utils import parse_cli_vars from dbt.config.utils import parse_cli_vars
@@ -328,8 +328,7 @@ class DebugTask(BaseTask):
"""Return a string containing the error message, or None if there was """Return a string containing the error message, or None if there was
no error. no error.
""" """
register_adapter(profile) adapter = create_adapter(profile)
adapter = get_adapter(profile)
try: try:
with adapter.connection_named("debug"): with adapter.connection_named("debug"):
adapter.debug_query() adapter.debug_query()

View File

@@ -7,7 +7,6 @@ from dbt.dataclass_schema import ValidationError
from .compile import CompileTask from .compile import CompileTask
from dbt.adapters.factory import get_adapter
from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import ( from dbt.contracts.results import (
@@ -234,7 +233,7 @@ class GenerateTask(CompileTask):
if self.manifest is None: if self.manifest is None:
raise InternalException("self.manifest was None in run!") raise InternalException("self.manifest was None in run!")
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
with adapter.connection_named("generate_catalog"): with adapter.connection_named("generate_catalog"):
fire_event(BuildingCatalog()) fire_event(BuildingCatalog())
catalog_table, exceptions = adapter.get_catalog(self.manifest) catalog_table, exceptions = adapter.get_catalog(self.manifest)

View File

@@ -7,7 +7,6 @@
# Use a visualizer such as snakeviz to look at the output: # Use a visualizer such as snakeviz to look at the output:
# snakeviz dbt.cprof # snakeviz dbt.cprof
from dbt.task.base import ConfiguredTask from dbt.task.base import ConfiguredTask
from dbt.adapters.factory import get_adapter
from dbt.parser.manifest import Manifest, ManifestLoader, _check_manifest from dbt.parser.manifest import Manifest, ManifestLoader, _check_manifest
from dbt.logger import DbtProcessState from dbt.logger import DbtProcessState
from dbt.clients.system import write_file from dbt.clients.system import write_file
@@ -62,7 +61,7 @@ class ParseTask(ConfiguredTask):
# ManifestLoader.load_all # ManifestLoader.load_all
def get_full_manifest(self): def get_full_manifest(self):
adapter = get_adapter(self.config) # type: ignore adapter = self.config.get_or_create_adapter(self.config) # type: ignore
root_config = self.config root_config = self.config
macro_hook = adapter.connections.set_query_header macro_hook = adapter.connections.set_query_header
with PARSING_STATE: with PARSING_STATE:
@@ -84,7 +83,7 @@ class ParseTask(ConfiguredTask):
fire_event(ManifestLoaded()) fire_event(ManifestLoaded())
def compile_manifest(self): def compile_manifest(self):
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
compiler = adapter.get_compiler() compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest) self.graph = compiler.compile(self.manifest)

View File

@@ -6,7 +6,6 @@ import agate
from .runnable import ManifestTask from .runnable import ManifestTask
import dbt.exceptions import dbt.exceptions
from dbt.adapters.factory import get_adapter
from dbt.config.utils import parse_cli_vars from dbt.config.utils import parse_cli_vars
from dbt.contracts.results import RunOperationResultsArtifact from dbt.contracts.results import RunOperationResultsArtifact
from dbt.exceptions import InternalException from dbt.exceptions import InternalException
@@ -36,7 +35,7 @@ class RunOperationTask(ManifestTask):
raise InternalException("manifest was None in compile_manifest") raise InternalException("manifest was None in compile_manifest")
def _run_unsafe(self) -> agate.Table: def _run_unsafe(self) -> agate.Table:
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
package_name, macro_name = self._get_macro_parts() package_name, macro_name = self._get_macro_parts()
macro_kwargs = self._get_kwargs() macro_kwargs = self._get_kwargs()

View File

@@ -16,7 +16,6 @@ from .printer import (
from dbt.clients.system import write_file from dbt.clients.system import write_file
from dbt.task.base import ConfiguredTask from dbt.task.base import ConfiguredTask
from dbt.adapters.base import BaseRelation from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.logger import ( from dbt.logger import (
DbtProcessState, DbtProcessState,
TextOnly, TextOnly,
@@ -86,7 +85,7 @@ class ManifestTask(ConfiguredTask):
raise InternalException("compile_manifest called before manifest was loaded") raise InternalException("compile_manifest called before manifest was loaded")
# we cannot get adapter in init since it will break rpc #5579 # we cannot get adapter in init since it will break rpc #5579
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
compiler = adapter.get_compiler() compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest) self.graph = compiler.compile(self.manifest)
@@ -188,7 +187,7 @@ class GraphRunnableTask(ManifestTask):
return os.path.join(self.config.target_path, RESULT_FILE_NAME) return os.path.join(self.config.target_path, RESULT_FILE_NAME)
def get_runner(self, node): def get_runner(self, node):
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
run_count: int = 0 run_count: int = 0
num_nodes: int = 0 num_nodes: int = 0
@@ -350,7 +349,7 @@ class GraphRunnableTask(ManifestTask):
pool.close() pool.close()
pool.terminate() pool.terminate()
adapter = get_adapter(self.config) adapter = self.adapter
if not adapter.is_cancelable(): if not adapter.is_cancelable():
fire_event(QueryCancelationUnsupported(type=adapter.type())) fire_event(QueryCancelationUnsupported(type=adapter.type()))
@@ -427,7 +426,7 @@ class GraphRunnableTask(ManifestTask):
pass pass
def execute_with_hooks(self, selected_uids: AbstractSet[str]): def execute_with_hooks(self, selected_uids: AbstractSet[str]):
adapter = get_adapter(self.config) adapter = self.config.get_or_create_adapter(self.config)
try: try:
self.before_hooks(adapter) self.before_hooks(adapter)
started = time.time() started = time.time()

View File

@@ -8,7 +8,7 @@ import yaml
import dbt.flags as flags import dbt.flags as flags
from dbt.config.runtime import RuntimeConfig from dbt.config.runtime import RuntimeConfig
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters, get_adapter_by_type from dbt.adapters.factory import get_adapter_by_type
from dbt.events.functions import setup_event_logger from dbt.events.functions import setup_event_logger
from dbt.tests.util import ( from dbt.tests.util import (
write_file, write_file,
@@ -246,14 +246,12 @@ def adapter(unique_schema, project_root, profiles_root, profiles_yml, dbt_projec
) )
flags.set_from_args(args, {}) flags.set_from_args(args, {})
runtime_config = RuntimeConfig.from_args(args) runtime_config = RuntimeConfig.from_args(args)
register_adapter(runtime_config) adapter = runtime_config.get_or_create_adapter(runtime_config)
adapter = get_adapter(runtime_config)
# We only need the base macros, not macros from dependencies, and don't want # We only need the base macros, not macros from dependencies, and don't want
# to run 'dbt deps' here. # to run 'dbt deps' here.
adapter.load_macro_manifest(base_macros_only=True) adapter.load_macro_manifest(base_macros_only=True)
yield adapter yield adapter
adapter.cleanup_connections() adapter.cleanup_connections()
reset_adapters()
# Start at directory level. # Start at directory level.