Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 24cd349

Browse files
authored
Merge pull request #696 from nolar/LAB-144-mssql-sv
Support MSSQL for cross-database diffs
2 parents 4ef9cb1 + 400a825 commit 24cd349

18 files changed

+342
-33
lines changed

data_diff/databases/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
from .clickhouse import Clickhouse
1313
from .vertica import Vertica
1414
from .duckdb import DuckDB
15+
from .mssql import MsSql
1516

1617
from ._connect import connect

data_diff/databases/_connect.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .clickhouse import Clickhouse
1515
from .vertica import Vertica
1616
from .duckdb import DuckDB
17+
from .mssql import MsSql
1718

1819

1920
DATABASE_BY_SCHEME = {
@@ -29,6 +30,7 @@
2930
"trino": Trino,
3031
"clickhouse": Clickhouse,
3132
"vertica": Vertica,
33+
"mssql": MsSql,
3234
}
3335

3436

data_diff/databases/mssql.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from data_diff.sqeleton.databases import mssql
2+
from .base import DatadiffDialect
3+
4+
5+
class Dialect(mssql.Dialect, mssql.Mixin_MD5, mssql.Mixin_NormalizeValue, DatadiffDialect):
6+
pass
7+
8+
9+
class MsSql(mssql.MsSQL):
10+
dialect = Dialect()

data_diff/joindiff_tables.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from runtype import dataclass
1212

13-
from data_diff.sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath
13+
from data_diff.sqeleton.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath
1414
from data_diff.sqeleton.abcs import NumericType
1515
from data_diff.sqeleton.queries import (
1616
table,
@@ -25,9 +25,10 @@
2525
leftjoin,
2626
rightjoin,
2727
this,
28+
when,
2829
Compiler,
2930
)
30-
from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
31+
from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Func, Random, TablePath, Code, ITable
3132
from data_diff.sqeleton.queries.extras import NormalizeAsString
3233

3334
from .info_tree import InfoTree
@@ -82,6 +83,12 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List
8283

8384
is_exclusive_a = and_(b[k] == None for k in keys2)
8485
is_exclusive_b = and_(a[k] == None for k in keys1)
86+
87+
if isinstance(db, MsSQL):
88+
# MsSQL has no "IS NULL" or "ISNULL()" usable as a value expression, only as a condition, so turn each flag into an explicit 1/0 value.
89+
is_exclusive_a = when(is_exclusive_a).then(1).else_(0)
90+
is_exclusive_b = when(is_exclusive_b).then(1).else_(0)
91+
8592
if isinstance(db, Oracle):
8693
is_exclusive_a = bool_to_int(is_exclusive_a)
8794
is_exclusive_b = bool_to_int(is_exclusive_b)
@@ -342,7 +349,7 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
342349
self.stats["diff_counts"] = diff_counts
343350

344351
def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
345-
if isinstance(db, Oracle):
352+
if isinstance(db, (Oracle, MsSQL)):
346353
exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
347354
else:
348355
exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)

data_diff/sqeleton/abcs/database_types.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,17 @@ def current_timestamp(self) -> str:
216216
"Provide SQL for returning the current timestamp, aka now"
217217

218218
@abstractmethod
219-
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
219+
def current_database(self) -> str:
220+
"Provide SQL for returning the current default database."
221+
222+
@abstractmethod
223+
def current_schema(self) -> str:
224+
"Provide SQL for returning the current default schema."
225+
226+
@abstractmethod
227+
def offset_limit(
228+
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
229+
) -> str:
220230
"Provide SQL fragment for limit and offset inside a select"
221231

222232
@abstractmethod

data_diff/sqeleton/databases/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
from .clickhouse import Clickhouse
1515
from .vertica import Vertica
1616
from .duckdb import DuckDB
17+
from .mssql import MsSQL
1718

1819
connect = Connect()

data_diff/sqeleton/databases/_connect.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .clickhouse import Clickhouse
2222
from .vertica import Vertica
2323
from .duckdb import DuckDB
24+
from .mssql import MsSQL
2425

2526

2627
@dataclass
@@ -86,6 +87,7 @@ def match_path(self, dsn):
8687
"trino": Trino,
8788
"clickhouse": Clickhouse,
8889
"vertica": Vertica,
90+
"mssql": MsSQL,
8991
}
9092

9193

data_diff/sqeleton/databases/base.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ class BaseDialect(AbstractDialect):
155155

156156
PLACEHOLDER_TABLE = None # Used for Oracle
157157

158-
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
158+
def offset_limit(
159+
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
160+
) -> str:
159161
if offset:
160162
raise NotImplementedError("No support for OFFSET in query")
161163

@@ -182,6 +184,12 @@ def random(self) -> str:
182184
def current_timestamp(self) -> str:
183185
return "current_timestamp()"
184186

187+
def current_database(self) -> str:
188+
return "current_database()"
189+
190+
def current_schema(self) -> str:
191+
return "current_schema()"
192+
185193
def explain_as_text(self, query: str) -> str:
186194
return f"EXPLAIN {query}"
187195

@@ -518,7 +526,10 @@ def _query_cursor(self, c, sql_code: str) -> QueryResult:
518526
c.execute(sql_code)
519527
if sql_code.lower().startswith(("select", "explain", "show")):
520528
columns = [col[0] for col in c.description]
521-
return QueryResult(c.fetchall(), columns)
529+
530+
fetched = c.fetchall()
531+
result = QueryResult(fetched, columns)
532+
return result
522533
except Exception as _e:
523534
# logger.exception(e)
524535
# logger.error(f'Caused by SQL: {sql_code}')
@@ -590,7 +601,8 @@ def is_autocommit(self) -> bool:
590601
return False
591602

592603

593-
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower, otherwise SUM() overflows
604+
# TODO: MsSQL's md5_as_int currently requires this to be reduced (from 15 to 14).
605+
CHECKSUM_HEXDIGITS = 14 # Must be 15 or lower, otherwise SUM() overflows
594606
MD5_HEXDIGITS = 32
595607

596608
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2

data_diff/sqeleton/databases/mssql.py

+208-19
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,214 @@
1-
# class MsSQL(ThreadedDatabase):
2-
# "AKA sql-server"
1+
from typing import Optional
2+
from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
3+
from .base import (
4+
CHECKSUM_HEXDIGITS,
5+
Mixin_OptimizerHints,
6+
Mixin_RandomSample,
7+
QueryError,
8+
ThreadedDatabase,
9+
import_helper,
10+
ConnectError,
11+
BaseDialect,
12+
)
13+
from .base import Mixin_Schema
14+
from ..abcs.database_types import (
15+
JSON,
16+
Timestamp,
17+
TimestampTZ,
18+
DbPath,
19+
Float,
20+
Decimal,
21+
Integer,
22+
TemporalType,
23+
Native_UUID,
24+
Text,
25+
FractionalType,
26+
Boolean,
27+
)
328

4-
# def __init__(self, host, port, user, password, *, database, thread_count, **kw):
5-
# args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
6-
# self._args = {k: v for k, v in args.items() if v is not None}
729

8-
# super().__init__(thread_count=thread_count)
30+
@import_helper("mssql")
31+
def import_mssql():
32+
import pyodbc
933

10-
# def create_connection(self):
11-
# mssql = import_mssql()
12-
# try:
13-
# return mssql.connect(**self._args)
14-
# except mssql.Error as e:
15-
# raise ConnectError(*e.args) from e
34+
return pyodbc
1635

17-
# def quote(self, s: str):
18-
# return f"[{s}]"
1936

20-
# def md5_as_int(self, s: str) -> str:
21-
# return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))"
22-
# # return f"CONVERT(bigint, (CHECKSUM({s})))"
37+
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
38+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
39+
if coltype.precision > 0:
40+
formatted_value = (
41+
f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + "
42+
f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})"
43+
)
44+
else:
45+
formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')"
2346

24-
# def to_string(self, s: str):
25-
# return f"CONVERT(varchar, {s})"
47+
return formatted_value
48+
49+
def normalize_number(self, value: str, coltype: FractionalType) -> str:
50+
if coltype.precision == 0:
51+
return f"CAST(FLOOR({value}) AS VARCHAR)"
52+
53+
return f"FORMAT({value}, 'N{coltype.precision}')"
54+
55+
56+
class Mixin_MD5(AbstractMixin_MD5):
57+
def md5_as_int(self, s: str) -> str:
58+
return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))"
59+
60+
61+
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
62+
name = "MsSQL"
63+
ROUNDS_ON_PREC_LOSS = True
64+
SUPPORTS_PRIMARY_KEY = True
65+
SUPPORTS_INDEXES = True
66+
TYPE_CLASSES = {
67+
# Timestamps
68+
"datetimeoffset": TimestampTZ,
69+
"datetime": Timestamp,
70+
"datetime2": Timestamp,
71+
"smalldatetime": Timestamp,
72+
"date": Timestamp,
73+
# Numbers
74+
"float": Float,
75+
"real": Float,
76+
"decimal": Decimal,
77+
"money": Decimal,
78+
"smallmoney": Decimal,
79+
# int
80+
"int": Integer,
81+
"bigint": Integer,
82+
"tinyint": Integer,
83+
"smallint": Integer,
84+
# Text
85+
"varchar": Text,
86+
"char": Text,
87+
"text": Text,
88+
"ntext": Text,
89+
"nvarchar": Text,
90+
"nchar": Text,
91+
"binary": Text,
92+
"varbinary": Text,
93+
# UUID
94+
"uniqueidentifier": Native_UUID,
95+
# Bool
96+
"bit": Boolean,
97+
# JSON
98+
"json": JSON,
99+
}
100+
101+
MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample}
102+
103+
def quote(self, s: str):
104+
return f"[{s}]"
105+
106+
def set_timezone_to_utc(self) -> str:
107+
raise NotImplementedError("MsSQL does not support a session timezone setting.")
108+
109+
def current_timestamp(self) -> str:
110+
return "GETDATE()"
111+
112+
def current_database(self) -> str:
113+
return "DB_NAME()"
114+
115+
def current_schema(self) -> str:
116+
return """default_schema_name
117+
FROM sys.database_principals
118+
WHERE name = CURRENT_USER"""
119+
120+
def to_string(self, s: str):
121+
return f"CONVERT(varchar, {s})"
122+
123+
def type_repr(self, t) -> str:
124+
try:
125+
return {bool: "bit"}[t]
126+
except KeyError:
127+
return super().type_repr(t)
128+
129+
def random(self) -> str:
130+
return "rand()"
131+
132+
def is_distinct_from(self, a: str, b: str) -> str:
133+
# IS (NOT) DISTINCT FROM is available only since SQLServer 2022.
134+
# See: https://stackoverflow.com/a/18684859/857383
135+
return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"
136+
137+
def offset_limit(
138+
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
139+
) -> str:
140+
if offset:
141+
raise NotImplementedError("No support for OFFSET in query")
142+
143+
result = ""
144+
if not has_order_by:
145+
result += "ORDER BY 1"
146+
147+
result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
148+
return result
149+
150+
def constant_values(self, rows) -> str:
151+
values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
152+
return f"VALUES {values}"
153+
154+
155+
class MsSQL(ThreadedDatabase):
156+
dialect = Dialect()
157+
#
158+
CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
159+
CONNECT_URI_PARAMS = ["database", "schema"]
160+
161+
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
162+
args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
163+
self._args = {k: v for k, v in args.items() if v is not None}
164+
self._args["driver"] = "{ODBC Driver 18 for SQL Server}"
165+
166+
# TODO: temporary development/debugging setting; TrustServerCertificate should not stay hard-coded.
167+
self._args["TrustServerCertificate"] = "yes"
168+
169+
try:
170+
self.default_database = self._args["database"]
171+
self.default_schema = self._args["schema"]
172+
except KeyError:
173+
raise ValueError("Specify a default database and schema.")
174+
175+
super().__init__(thread_count=thread_count)
176+
177+
def create_connection(self):
178+
self._mssql = import_mssql()
179+
try:
180+
connection = self._mssql.connect(**self._args)
181+
return connection
182+
except self._mssql.Error as error:
183+
raise ConnectError(*error.args) from error
184+
185+
def select_table_schema(self, path: DbPath) -> str:
186+
"""Provide SQL for selecting the table schema as (name, type, date_prec, num_prec, num_scale)"""
187+
database, schema, name = self._normalize_table_path(path)
188+
info_schema_path = ["information_schema", "columns"]
189+
if database:
190+
info_schema_path.insert(0, self.dialect.quote(database))
191+
192+
return (
193+
"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
194+
f"FROM {'.'.join(info_schema_path)} "
195+
f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
196+
)
197+
198+
def _normalize_table_path(self, path: DbPath) -> DbPath:
199+
if len(path) == 1:
200+
return self.default_database, self.default_schema, path[0]
201+
elif len(path) == 2:
202+
return self.default_database, path[0], path[1]
203+
elif len(path) == 3:
204+
return path
205+
206+
raise ValueError(
207+
f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
208+
)
209+
210+
def _query_cursor(self, c, sql_code: str):
211+
try:
212+
return super()._query_cursor(c, sql_code)
213+
except self._mssql.DatabaseError as e:
214+
raise QueryError(e)

0 commit comments

Comments
 (0)