Compare commits

...

3 Commits

Author SHA1 Message Date
Gerda Shank
d629bbf35e Merge branch 'main' into arieldbt/pyodide 2022-09-21 11:59:46 -04:00
Gerda Shank
bad12220fd run pre-commit formatting 2022-09-21 11:58:56 -04:00
Ariel Marcus
4e79436607 Run dbt on WebAssembly using Pyodide 2022-09-09 15:58:28 -04:00
16 changed files with 248 additions and 65 deletions

View File

@@ -0,0 +1,7 @@
kind: Features
body: Run dbt on WebAssembly using Pyodide
time: 2022-09-09T15:47:22.228524-04:00
custom:
Author: arieldbt
Issue: "1970"
PR: "5803"

View File

@@ -3,8 +3,8 @@ import os
from time import sleep
import sys
# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
# dbt.clients.parallel.RLock is a function returning this type
from dbt.clients.parallel import RLock
from threading import get_ident
from typing import (
Any,

107
core/dbt/clients/http.py Normal file
View File

@@ -0,0 +1,107 @@
import json
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional
from urllib.parse import urlencode

import requests
from requests import Response

from dbt import flags
class Http(metaclass=ABCMeta):
    """Abstract HTTP client interface.

    Concrete implementations (``PyodideHttp``, ``Requests``) let dbt issue
    HTTP requests both in a regular CPython process and inside a Pyodide
    (WebAssembly) runtime, where the ``requests`` library cannot open
    sockets.  Callers should use the module-level ``http`` singleton.
    """

    @abstractmethod
    def get_json(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Perform a GET request and return the response body parsed as JSON.

        :param url: target URL
        :param params: optional query parameters appended to the URL
        :param timeout: optional request timeout in seconds
        """
        raise NotImplementedError

    @abstractmethod
    def get_response(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Perform a GET request and return the raw ``requests.Response``."""
        raise NotImplementedError

    @abstractmethod
    def post(
        self,
        url: str,
        data: Any = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Perform a POST request and return the raw ``requests.Response``."""
        raise NotImplementedError
class PyodideHttp(Http):
    """HTTP client for the Pyodide (WebAssembly) runtime.

    ``requests`` cannot open sockets under Pyodide, so GETs are routed
    through ``pyodide.http.open_url`` instead.  ``get_response`` and
    ``post`` are not supported: ``open_url`` does not produce a
    ``requests.Response`` object.
    """

    def __init__(self) -> None:
        super().__init__()
        # Imported lazily: the pyodide package only exists inside a
        # Pyodide runtime, and this class is only instantiated there.
        from pyodide.http import open_url

        self._open_url = open_url

    def get_json(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """GET ``url`` and parse the response body as JSON.

        ``timeout`` is accepted for interface compatibility but ignored
        by this implementation (``open_url`` exposes no timeout control).
        """
        if params is not None:
            url += f"?{urlencode(params)}"
        response = self._open_url(url=url)
        return json.load(response)

    def get_response(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Not supported under Pyodide; ``open_url`` yields no ``Response``."""
        raise NotImplementedError

    def post(
        self,
        url: str,
        data: Any = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """Not supported under Pyodide."""
        raise NotImplementedError
class Requests(Http):
    """HTTP client backed by the ``requests`` library (regular CPython)."""

    def get_json(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """GET ``url`` and return the response body parsed as JSON."""
        return self.get_response(url=url, params=params, timeout=timeout).json()

    def get_response(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """GET ``url`` and return the raw ``requests.Response``."""
        return requests.get(url=url, params=params, timeout=timeout)

    def post(
        self,
        url: str,
        data: Any = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[int] = None,
    ) -> Response:
        """POST ``data`` to ``url`` and return the raw ``requests.Response``."""
        return requests.post(url=url, data=data, headers=headers, timeout=timeout)
# Module-level singleton: pick the client implementation that matches the
# runtime dbt is executing under.
http: Http = PyodideHttp() if flags.IS_PYODIDE else Requests()

View File

@@ -0,0 +1,34 @@
from dbt import flags
from threading import Lock as PyodideLock
from threading import RLock as PyodideRLock
if flags.IS_PYODIDE:
pass # multiprocessing doesn't work in pyodide
else:
from multiprocessing.dummy import Pool as MultiprocessingThreadPool
from multiprocessing.synchronize import Lock as MultiprocessingLock
from multiprocessing.synchronize import RLock as MultiprocessingRLock
class PyodideThreadPool:
    """No-op stand-in for ``multiprocessing.dummy.Pool`` under Pyodide.

    Pyodide has no working ``multiprocessing``, and dbt runs
    single-threaded there, so every lifecycle method does nothing.
    """

    def __init__(self, num_threads: int) -> None:
        """Accept (and ignore) the requested thread count."""

    def close(self) -> None:
        """No-op: there are no workers to close."""

    def join(self) -> None:
        """No-op: there are no workers to wait for."""

    def terminate(self) -> None:
        """No-op: there are no workers to terminate."""
# Export the concurrency primitives that match the runtime: threading /
# no-op stand-ins under Pyodide (multiprocessing is unavailable there),
# multiprocessing-backed objects everywhere else.
if flags.IS_PYODIDE:
    Lock, RLock, ThreadPool = PyodideLock, PyodideRLock, PyodideThreadPool
else:
    Lock, RLock, ThreadPool = (
        MultiprocessingLock,
        MultiprocessingRLock,
        MultiprocessingThreadPool,
    )

View File

@@ -1,6 +1,7 @@
import functools
from typing import Any, Dict, List
import requests
from dbt.clients.http import http
from dbt.events.functions import fire_event
from dbt.events.types import (
RegistryProgressMakingGETRequest,
@@ -40,7 +41,7 @@ def _get(package_name, registry_base_url=None):
url = _get_url(package_name, registry_base_url)
fire_event(RegistryProgressMakingGETRequest(url=url))
# all exceptions from requests get caught in the retry logic so no need to wrap this here
resp = requests.get(url, timeout=30)
resp = http.get_response(url, timeout=30)
fire_event(RegistryProgressGETResponse(url=url, resp_code=resp.status_code))
resp.raise_for_status()
@@ -164,7 +165,7 @@ def _get_index(registry_base_url=None):
url = _get_url("index", registry_base_url)
fire_event(RegistryIndexProgressMakingGETRequest(url=url))
# all exceptions from requests get caught in the retry logic so no need to wrap this here
resp = requests.get(url, timeout=30)
resp = http.get_response(url, timeout=30)
fire_event(RegistryIndexProgressGETResponse(url=url, resp_code=resp.status_code))
resp.raise_for_status()

View File

@@ -9,7 +9,6 @@ import shutil
import subprocess
import sys
import tarfile
import requests
import stat
from typing import Type, NoReturn, List, Optional, Dict, Any, Tuple, Callable, Union
@@ -22,6 +21,7 @@ from dbt.events.types import (
SystemStdErrMsg,
SystemReportReturnCode,
)
from dbt.clients.http import http
import dbt.exceptions
from dbt.utils import _connection_exception_retry as connection_exception_retry
@@ -451,7 +451,7 @@ def download(
) -> None:
path = convert_path(path)
connection_timeout = timeout or float(os.getenv("DBT_HTTP_TIMEOUT", 10))
response = requests.get(url, timeout=connection_timeout)
response = http.get_response(url, timeout=connection_timeout)
with open(path, "wb") as handle:
for block in response.iter_content(1024 * 64):
handle.write(block)

View File

@@ -2,7 +2,8 @@ import enum
from dataclasses import dataclass, field
from itertools import chain, islice
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from dbt.clients.parallel import Lock
from typing import (
Dict,
List,
@@ -641,10 +642,19 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
default_factory=ParsingInfo,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_lock: Lock = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
if flags.IS_PYODIDE:
# Not sure how to avoid this change
# Fails with this error:
# mashumaro.exceptions.UnserializableDataError: <built-in function allocate_lock> as a field type is not supported by mashumaro
_lock: Callable = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
else:
_lock: Lock = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
def __pre_serialize__(self):
# serialization won't work with anything except an empty source_patches because

View File

@@ -1,10 +1,6 @@
import os
import multiprocessing
if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore
from pathlib import Path
import sys
from typing import Optional
# PROFILES_DIR must be set before the other flags
@@ -45,6 +41,7 @@ NO_PRINT = None
CACHE_SELECTED_ONLY = None
TARGET_PATH = None
LOG_PATH = None
IS_PYODIDE = "pyodide" in sys.modules # whether dbt is running via pyodide
_NON_BOOLEAN_FLAGS = [
"LOG_FORMAT",
@@ -117,13 +114,25 @@ ARTIFACT_STATE_PATH = env_set_path("DBT_ARTIFACT_STATE_PATH")
ENABLE_LEGACY_LOGGER = env_set_truthy("DBT_ENABLE_LEGACY_LOGGER")
def _get_context():
# TODO: change this back to use fork() on linux when we have made that safe
return multiprocessing.get_context("spawn")
# This is not a flag, it's a place to store the lock
MP_CONTEXT = _get_context()
if IS_PYODIDE:
from typing import NamedTuple
from threading import Lock as PyodideLock
from threading import RLock as PyodideRLock
class PyodideContext(NamedTuple):
Lock = PyodideLock
RLock = PyodideRLock
MP_CONTEXT = PyodideContext()
else:
import multiprocessing
if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore
# TODO: change this back to use fork() on linux when we have made that safe
MP_CONTEXT = multiprocessing.get_context("spawn")
def set_from_args(args, user_config):

View File

@@ -22,7 +22,9 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
# Construct a phony config
config = RuntimeConfig.from_args(
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
RuntimeArgs(
project_dir, profiles_dir, single_threaded or flags.IS_PYODIDE, profile, target
)
)
# Clear previously registered adapters--
# this fixes cacheing behavior on the dbt-server

View File

@@ -20,12 +20,16 @@ from dbt.parser.search import FileBlock
from dbt.clients.jinja import get_rendered
import dbt.tracking as tracking
from dbt import utils
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
from functools import reduce
from itertools import chain
import random
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
# No support for compiled dependencies on pyodide
if flags.IS_PYODIDE:
pass
else:
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
# New for Python models :p
import ast
from dbt.dataclass_schema import ValidationError
@@ -283,12 +287,15 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
exp_sample_node = deepcopy(node)
exp_sample_config = deepcopy(config)
model_parser_copy.populate(exp_sample_node, exp_sample_config, experimental_sample)
# use the experimental parser exclusively if the flag is on
if flags.USE_EXPERIMENTAL_PARSER:
statically_parsed = self.run_experimental_parser(node)
# run the stable static parser unless it is explicitly turned off
if flags.IS_PYODIDE:
pass
else:
statically_parsed = self.run_static_parser(node)
# use the experimental parser exclusively if the flag is on
if flags.USE_EXPERIMENTAL_PARSER:
statically_parsed = self.run_experimental_parser(node)
# run the stable static parser unless it is explicitly turned off
else:
statically_parsed = self.run_static_parser(node)
# if the static parser succeeded, extract some data in easy-to-compare formats
if isinstance(statically_parsed, dict):

View File

@@ -76,7 +76,7 @@ class BaseTask(metaclass=ABCMeta):
def __init__(self, args, config):
self.args = args
self.args.single_threaded = False
self.args.single_threaded = False or flags.IS_PYODIDE
self.config = config
@classmethod

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from abc import abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet
from .printer import (
@@ -13,6 +12,7 @@ from .printer import (
print_run_end_messages,
)
from dbt.clients.parallel import ThreadPool
from dbt.clients.system import write_file
from dbt.task.base import ConfiguredTask
from dbt.adapters.base import BaseRelation
@@ -266,7 +266,7 @@ class GraphRunnableTask(ManifestTask):
This does still go through the callback path for result collection.
"""
if self.config.args.single_threaded:
if self.config.args.single_threaded or flags.IS_PYODIDE:
callback(self.call_runner(*args))
else:
pool.apply_async(self.call_runner, args=args, callback=callback)

View File

@@ -1,5 +1,6 @@
from typing import Optional
from dbt.clients.http import http
from dbt.clients.yaml_helper import ( # noqa:F401
yaml,
safe_load,
@@ -25,7 +26,6 @@ import logbook
import pytz
import platform
import uuid
import requests
import os
sp_logger.setLevel(100)
@@ -81,7 +81,7 @@ class TimeoutEmitter(Emitter):
def http_post(self, payload):
self._log_request("POST", payload)
r = requests.post(
r = http.post(
self.endpoint,
data=payload,
headers={"content-type": "application/json; charset=utf-8"},
@@ -94,7 +94,7 @@ class TimeoutEmitter(Emitter):
def http_get(self, payload):
self._log_request("GET", payload)
r = requests.get(self.endpoint, params=payload, timeout=5.0)
r = http.get_response(self.endpoint, params=payload, timeout=5.0)
self._log_result("GET", r.status_code)
return r
@@ -257,7 +257,7 @@ def get_dbt_env_context():
def track(user, *args, **kwargs):
if user.do_not_track:
if user.do_not_track or flags.IS_PYODIDE:
return
else:
fire_event(SendingEvent(kwargs=str(kwargs)))
@@ -472,7 +472,7 @@ class InvocationProcessor(logbook.Processor):
def initialize_from_flags():
# Setting these used to be in UserConfig, but had to be moved here
if flags.SEND_ANONYMOUS_USAGE_STATS:
initialize_tracking(flags.PROFILES_DIR)
else:
if not flags.SEND_ANONYMOUS_USAGE_STATS or flags.IS_PYODIDE:
do_not_track()
else:
initialize_tracking(flags.PROFILES_DIR)

View File

@@ -531,7 +531,7 @@ class HasThreadingConfig(Protocol):
def executor(config: HasThreadingConfig) -> ConnectingExecutor:
if config.args.single_threaded:
if config.args.single_threaded or flags.IS_PYODIDE:
return SingleThreadedExecutor()
else:
return MultiThreadedExecutor(max_workers=config.threads)

View File

@@ -10,6 +10,7 @@ import requests
import dbt.exceptions
import dbt.semver
from dbt.clients.http import http
from dbt.ui import green, red, yellow
from dbt import flags
@@ -45,8 +46,7 @@ def get_latest_version(
version_url: str = PYPI_VERSION_URL,
) -> Optional[dbt.semver.VersionSpecifier]:
try:
resp = requests.get(version_url)
data = resp.json()
data = http.get_json(version_url)
version_string = data["info"]["version"]
except (json.JSONDecodeError, KeyError, requests.RequestException):
return None

View File

@@ -29,6 +29,34 @@ package_version = "1.3.0b2"
description = """With dbt, data analysts and engineers can build analytics \
the way engineers build applications."""
_install_requires = [
"Jinja2==3.1.2",
"agate>=1.6,<1.6.4",
"click>=7.0,<9",
"colorama>=0.3.9,<0.4.6",
"hologram>=0.0.14,<=0.0.15",
"isodate>=0.6,<0.7",
"logbook>=1.5,<1.6",
"mashumaro[msgpack]==3.0.4",
"minimal-snowplow-tracker==0.0.2",
"networkx>=2.3,<2.8.1;python_version<'3.8'",
"networkx>=2.3,<3;python_version>='3.8'",
"packaging>=20.9,<22.0",
"sqlparse>=0.2.3,<0.5",
"typing-extensions>=3.7.4",
"werkzeug>=1,<3",
# the following are all to match snowflake-connector-python
"requests<3.0.0",
"idna>=2.5,<4",
"cffi>=1.9,<2.0.0",
"pyyaml>=6.0",
]
if "DBT_WASM_BUILD" in os.environ and int(os.environ["DBT_WASM_BUILD"]) == 1:
# binary dependency not supported in pyodide
pass
else:
_install_requires.insert(14, "dbt-extractor~=0.4.1")
setup(
name=package_name,
@@ -45,29 +73,7 @@ setup(
entry_points={
"console_scripts": ["dbt = dbt.main:main"],
},
install_requires=[
"Jinja2==3.1.2",
"agate>=1.6,<1.6.4",
"click>=7.0,<9",
"colorama>=0.3.9,<0.4.6",
"hologram>=0.0.14,<=0.0.15",
"isodate>=0.6,<0.7",
"logbook>=1.5,<1.6",
"mashumaro[msgpack]==3.0.4",
"minimal-snowplow-tracker==0.0.2",
"networkx>=2.3,<2.8.1;python_version<'3.8'",
"networkx>=2.3,<3;python_version>='3.8'",
"packaging>=20.9,<22.0",
"sqlparse>=0.2.3,<0.5",
"dbt-extractor~=0.4.1",
"typing-extensions>=3.7.4",
"werkzeug>=1,<3",
# the following are all to match snowflake-connector-python
"requests<3.0.0",
"idna>=2.5,<4",
"cffi>=1.9,<2.0.0",
"pyyaml>=6.0",
],
install_requires=_install_requires,
zip_safe=False,
classifiers=[
"Development Status :: 5 - Production/Stable",