Compare commits

...

1 Commits

Author SHA1 Message Date
Jacob Beck
6762538d36 improve the type annotations so we can extract json schemas 2019-10-03 15:25:49 -04:00
2 changed files with 32 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
# never name this package "types", or mypy will crash in ugly ways
from datetime import timedelta
from numbers import Real
from typing import NewType
from hologram import (
@@ -37,7 +38,15 @@ class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
return {'type': 'number'}
class RealEncoder(FieldEncoder):
@property
def json_schema(self):
return {'type': 'number'}
JsonSchemaMixin.register_field_encoders({
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
Real: RealEncoder(),
})

View File

@@ -9,7 +9,10 @@ from hologram import JsonSchemaMixin
from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import extract_toplevel_blocks
from dbt.compilation import compile_manifest
from dbt.contracts.results import RemoteCatalogResults
from dbt.contracts.results import (
RemoteCatalogResults, RemoteCompileResult, RemoteRunResult,
RemoteExecutionResult
)
from dbt.parser.results import ParseResult
from dbt.parser.rpc import RPCCallParser, RPCMacroParser
from dbt.parser.util import ParserUtils
@@ -23,16 +26,20 @@ from dbt.task.run import RunTask
from dbt.task.seed import SeedTask
from dbt.task.test import TestTask
@dataclass
class RPCParameters(JsonSchemaMixin):
tiemout: Optional[Union[float, int]]
@dataclass
class RPCExecParameters(JsonSchemaMixin):
class RPCExecParameters(RPCParameters):
name: str
sql: str
macros: Optional[str]
@dataclass
class RPCCompileProjectParameters(JsonSchemaMixin):
class RPCCompileProjectParameters(RPCParameters):
models: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None
@@ -44,12 +51,12 @@ class RPCTestProjectParameters(RPCCompileProjectParameters):
@dataclass
class RPCSeedProjectParameters(JsonSchemaMixin):
class RPCSeedProjectParameters(RPCParameters):
show: bool = False
@dataclass
class RPCDocsGenerateProjectParameters(JsonSchemaMixin):
class RPCDocsGenerateProjectParameters(RPCParameters):
compile: bool = True
@@ -172,6 +179,9 @@ class _RPCExecTask(RPCTask):
class RemoteCompileTask(_RPCExecTask):
METHOD_NAME = 'compile_sql'
def handle_request(self, params: RPCExecParameters) -> RemoteCompileResult:
return super().handle_request(params)
def get_runner_type(self):
return RPCCompileRunner
@@ -179,6 +189,9 @@ class RemoteCompileTask(_RPCExecTask):
class RemoteRunTask(_RPCExecTask, RunTask):
METHOD_NAME = 'run_sql'
def handle_request(self, params: RPCExecParameters) -> RemoteRunResult:
return super().handle_request(params)
def get_runner_type(self):
return RPCExecuteRunner
@@ -196,7 +209,7 @@ class RemoteCompileProjectTask(RPCTask):
def handle_request(
self, params: RPCCompileProjectParameters
) -> RemoteCallableResult:
) -> RemoteExecutionResult:
self.args.models = self._listify(params.models)
self.args.exclude = self._listify(params.exclude)
@@ -217,7 +230,7 @@ class RemoteRunProjectTask(RPCTask, RunTask):
def handle_request(
self, params: RPCCompileProjectParameters
) -> RemoteCallableResult:
) -> RemoteExecutionResult:
self.args.models = self._listify(params.models)
self.args.exclude = self._listify(params.exclude)
@@ -238,7 +251,7 @@ class RemoteSeedProjectTask(RPCTask, SeedTask):
def handle_request(
self, params: RPCSeedProjectParameters
) -> RemoteCallableResult:
) -> RemoteExecutionResult:
self.args.show = params.show
results = self.run()
@@ -258,7 +271,7 @@ class RemoteTestProjectTask(RPCTask, TestTask):
def handle_request(
self, params: RPCTestProjectParameters,
) -> RemoteCallableResult:
) -> RemoteExecutionResult:
self.args.models = self._listify(params.models)
self.args.exclude = self._listify(params.exclude)
self.args.data = params.data
@@ -281,13 +294,12 @@ class RemoteDocsGenerateProjectTask(RPCTask, GenerateTask):
def handle_request(
self, params: RPCDocsGenerateProjectParameters,
) -> RemoteCallableResult:
) -> RemoteCatalogResults:
self.args.models = None
self.args.exclude = None
self.args.compile = params.compile
results = self.run()
assert isinstance(results, RemoteCatalogResults)
return results
def get_catalog_results(