|
1 |
| -# class MsSQL(ThreadedDatabase): |
2 |
| -# "AKA sql-server" |
| 1 | +from typing import Optional |
| 2 | +from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue |
| 3 | +from .base import ( |
| 4 | + CHECKSUM_HEXDIGITS, |
| 5 | + Mixin_OptimizerHints, |
| 6 | + Mixin_RandomSample, |
| 7 | + QueryError, |
| 8 | + ThreadedDatabase, |
| 9 | + import_helper, |
| 10 | + ConnectError, |
| 11 | + BaseDialect, |
| 12 | +) |
| 13 | +from .base import Mixin_Schema |
| 14 | +from ..abcs.database_types import ( |
| 15 | + JSON, |
| 16 | + Timestamp, |
| 17 | + TimestampTZ, |
| 18 | + DbPath, |
| 19 | + Float, |
| 20 | + Decimal, |
| 21 | + Integer, |
| 22 | + TemporalType, |
| 23 | + Native_UUID, |
| 24 | + Text, |
| 25 | + FractionalType, |
| 26 | + Boolean, |
| 27 | +) |
3 | 28 |
|
4 |
| -# def __init__(self, host, port, user, password, *, database, thread_count, **kw): |
5 |
| -# args = dict(server=host, port=port, database=database, user=user, password=password, **kw) |
6 |
| -# self._args = {k: v for k, v in args.items() if v is not None} |
7 | 29 |
|
8 |
| -# super().__init__(thread_count=thread_count) |
| 30 | +@import_helper("mssql") |
| 31 | +def import_mssql(): |
| 32 | + import pyodbc |
9 | 33 |
|
10 |
| -# def create_connection(self): |
11 |
| -# mssql = import_mssql() |
12 |
| -# try: |
13 |
| -# return mssql.connect(**self._args) |
14 |
| -# except mssql.Error as e: |
15 |
| -# raise ConnectError(*e.args) from e |
| 34 | + return pyodbc |
16 | 35 |
|
17 |
| -# def quote(self, s: str): |
18 |
| -# return f"[{s}]" |
19 | 36 |
|
20 |
| -# def md5_as_int(self, s: str) -> str: |
21 |
| -# return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" |
22 |
| -# # return f"CONVERT(bigint, (CHECKSUM({s})))" |
| 37 | +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): |
| 38 | + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: |
| 39 | + if coltype.precision > 0: |
| 40 | + formatted_value = ( |
| 41 | + f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + " |
| 42 | + f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})" |
| 43 | + ) |
| 44 | + else: |
| 45 | + formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')" |
23 | 46 |
|
24 |
| -# def to_string(self, s: str): |
25 |
| -# return f"CONVERT(varchar, {s})" |
| 47 | + return formatted_value |
| 48 | + |
| 49 | + def normalize_number(self, value: str, coltype: FractionalType) -> str: |
| 50 | + if coltype.precision == 0: |
| 51 | + return f"CAST(FLOOR({value}) AS VARCHAR)" |
| 52 | + |
| 53 | + return f"FORMAT({value}, 'N{coltype.precision}')" |
| 54 | + |
| 55 | + |
| 56 | +class Mixin_MD5(AbstractMixin_MD5): |
| 57 | + def md5_as_int(self, s: str) -> str: |
| 58 | + return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" |
| 59 | + |
| 60 | + |
| 61 | +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): |
| 62 | + name = "MsSQL" |
| 63 | + ROUNDS_ON_PREC_LOSS = True |
| 64 | + SUPPORTS_PRIMARY_KEY = True |
| 65 | + SUPPORTS_INDEXES = True |
| 66 | + TYPE_CLASSES = { |
| 67 | + # Timestamps |
| 68 | + "datetimeoffset": TimestampTZ, |
| 69 | + "datetime": Timestamp, |
| 70 | + "datetime2": Timestamp, |
| 71 | + "smalldatetime": Timestamp, |
| 72 | + "date": Timestamp, |
| 73 | + # Numbers |
| 74 | + "float": Float, |
| 75 | + "real": Float, |
| 76 | + "decimal": Decimal, |
| 77 | + "money": Decimal, |
| 78 | + "smallmoney": Decimal, |
| 79 | + # int |
| 80 | + "int": Integer, |
| 81 | + "bigint": Integer, |
| 82 | + "tinyint": Integer, |
| 83 | + "smallint": Integer, |
| 84 | + # Text |
| 85 | + "varchar": Text, |
| 86 | + "char": Text, |
| 87 | + "text": Text, |
| 88 | + "ntext": Text, |
| 89 | + "nvarchar": Text, |
| 90 | + "nchar": Text, |
| 91 | + "binary": Text, |
| 92 | + "varbinary": Text, |
| 93 | + # UUID |
| 94 | + "uniqueidentifier": Native_UUID, |
| 95 | + # Bool |
| 96 | + "bit": Boolean, |
| 97 | + # JSON |
| 98 | + "json": JSON, |
| 99 | + } |
| 100 | + |
| 101 | + MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample} |
| 102 | + |
| 103 | + def quote(self, s: str): |
| 104 | + return f"[{s}]" |
| 105 | + |
| 106 | + def set_timezone_to_utc(self) -> str: |
| 107 | + raise NotImplementedError("MsSQL does not support a session timezone setting.") |
| 108 | + |
| 109 | + def current_timestamp(self) -> str: |
| 110 | + return "GETDATE()" |
| 111 | + |
| 112 | + def current_database(self) -> str: |
| 113 | + return "DB_NAME()" |
| 114 | + |
| 115 | + def current_schema(self) -> str: |
| 116 | + return """default_schema_name |
| 117 | + FROM sys.database_principals |
| 118 | + WHERE name = CURRENT_USER""" |
| 119 | + |
| 120 | + def to_string(self, s: str): |
| 121 | + return f"CONVERT(varchar, {s})" |
| 122 | + |
| 123 | + def type_repr(self, t) -> str: |
| 124 | + try: |
| 125 | + return {bool: "bit"}[t] |
| 126 | + except KeyError: |
| 127 | + return super().type_repr(t) |
| 128 | + |
| 129 | + def random(self) -> str: |
| 130 | + return "rand()" |
| 131 | + |
| 132 | + def is_distinct_from(self, a: str, b: str) -> str: |
| 133 | + # IS (NOT) DISTINCT FROM is available only since SQLServer 2022. |
| 134 | + # See: https://stackoverflow.com/a/18684859/857383 |
| 135 | + return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))" |
| 136 | + |
| 137 | + def offset_limit( |
| 138 | + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None |
| 139 | + ) -> str: |
| 140 | + if offset: |
| 141 | + raise NotImplementedError("No support for OFFSET in query") |
| 142 | + |
| 143 | + result = "" |
| 144 | + if not has_order_by: |
| 145 | + result += "ORDER BY 1" |
| 146 | + |
| 147 | + result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY" |
| 148 | + return result |
| 149 | + |
| 150 | + def constant_values(self, rows) -> str: |
| 151 | + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) |
| 152 | + return f"VALUES {values}" |
| 153 | + |
| 154 | + |
| 155 | +class MsSQL(ThreadedDatabase): |
| 156 | + dialect = Dialect() |
| 157 | + # |
| 158 | + CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>" |
| 159 | + CONNECT_URI_PARAMS = ["database", "schema"] |
| 160 | + |
| 161 | + def __init__(self, host, port, user, password, *, database, thread_count, **kw): |
| 162 | + args = dict(server=host, port=port, database=database, user=user, password=password, **kw) |
| 163 | + self._args = {k: v for k, v in args.items() if v is not None} |
| 164 | + self._args["driver"] = "{ODBC Driver 18 for SQL Server}" |
| 165 | + |
| 166 | + # TODO temp dev debug |
| 167 | + self._args["TrustServerCertificate"] = "yes" |
| 168 | + |
| 169 | + try: |
| 170 | + self.default_database = self._args["database"] |
| 171 | + self.default_schema = self._args["schema"] |
| 172 | + except KeyError: |
| 173 | + raise ValueError("Specify a default database and schema.") |
| 174 | + |
| 175 | + super().__init__(thread_count=thread_count) |
| 176 | + |
| 177 | + def create_connection(self): |
| 178 | + self._mssql = import_mssql() |
| 179 | + try: |
| 180 | + connection = self._mssql.connect(**self._args) |
| 181 | + return connection |
| 182 | + except self._mssql.Error as error: |
| 183 | + raise ConnectError(*error.args) from error |
| 184 | + |
| 185 | + def select_table_schema(self, path: DbPath) -> str: |
| 186 | + """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" |
| 187 | + database, schema, name = self._normalize_table_path(path) |
| 188 | + info_schema_path = ["information_schema", "columns"] |
| 189 | + if database: |
| 190 | + info_schema_path.insert(0, self.dialect.quote(database)) |
| 191 | + |
| 192 | + return ( |
| 193 | + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " |
| 194 | + f"FROM {'.'.join(info_schema_path)} " |
| 195 | + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" |
| 196 | + ) |
| 197 | + |
| 198 | + def _normalize_table_path(self, path: DbPath) -> DbPath: |
| 199 | + if len(path) == 1: |
| 200 | + return self.default_database, self.default_schema, path[0] |
| 201 | + elif len(path) == 2: |
| 202 | + return self.default_database, path[0], path[1] |
| 203 | + elif len(path) == 3: |
| 204 | + return path |
| 205 | + |
| 206 | + raise ValueError( |
| 207 | + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" |
| 208 | + ) |
| 209 | + |
| 210 | + def _query_cursor(self, c, sql_code: str): |
| 211 | + try: |
| 212 | + return super()._query_cursor(c, sql_code) |
| 213 | + except self._mssql.DatabaseError as e: |
| 214 | + raise QueryError(e) |
0 commit comments