23
23
from data_diff .queries .api import Expr , table , Select , SKIP , Explain , Code , this
24
24
from data_diff .queries .ast_classes import (
25
25
Alias ,
26
+ BinBoolOp ,
26
27
BinOp ,
27
28
CaseWhen ,
28
29
Cast ,
64
65
Float ,
65
66
Native_UUID ,
66
67
String_UUID ,
68
+ Binary_UUID ,
67
69
String_Alphanum ,
68
70
String_VaryingAlphanum ,
69
71
TemporalType ,
@@ -482,6 +484,22 @@ def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
482
484
def render__resolvecolumn (self , c : Compiler , elem : _ResolveColumn ) -> str :
483
485
return self .compile (c , elem ._get_resolved ())
484
486
487
+ def modify_string_where_clause (self , col , where_clause ):
488
+ # NOTE: snowflake specific issue with Binary columns
489
+ return where_clause .replace (f'"{ col } "' , f"TO_VARCHAR(\" { col } \" , 'UTF-8')" )
490
+
491
+ def check_for_binary_cols (self , where_exprs ):
492
+ binary_uuid_columns = set ()
493
+ for expr in where_exprs :
494
+ if isinstance (expr , BinBoolOp ):
495
+ for arg in expr .args :
496
+ if isinstance (arg , _ResolveColumn ):
497
+ resolved_column = arg .resolved
498
+ if isinstance (resolved_column , Column ) and resolved_column .source_table .schema :
499
+ if isinstance (resolved_column .type , Binary_UUID ):
500
+ binary_uuid_columns .add (resolved_column .name )
501
+ return binary_uuid_columns
502
+
485
503
def render_select (self , parent_c : Compiler , elem : Select ) -> str :
486
504
c : Compiler = attrs .evolve (parent_c , in_select = True ) # .add_table_context(self.table)
487
505
compile_fn = functools .partial (self .compile , c )
@@ -497,7 +515,13 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
497
515
select += f" FROM { self .PLACEHOLDER_TABLE } "
498
516
499
517
if elem .where_exprs :
500
- select += " WHERE " + " AND " .join (map (compile_fn , elem .where_exprs ))
518
+ where_clause = " WHERE " + " AND " .join (map (compile_fn , elem .where_exprs ))
519
+ # post processing step for snowfake BINARAY_UUID columns
520
+ if parent_c .dialect .name == "Snowflake" :
521
+ binary_uuids = self .check_for_binary_cols (elem .where_exprs )
522
+ for binary_uuid in binary_uuids :
523
+ where_clause = self .modify_string_where_clause (binary_uuid , where_clause )
524
+ select += where_clause
501
525
502
526
if elem .group_by_exprs :
503
527
select += " GROUP BY " + ", " .join (map (compile_fn , elem .group_by_exprs ))
@@ -836,6 +860,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
836
860
"""Creates an SQL expression, that strips uuids of artifacts like whitespace."""
837
861
if isinstance (coltype , String_UUID ):
838
862
return f"TRIM({ value } )"
863
+ # converts Binary to VARCHAR for Snowflake
864
+ elif isinstance (coltype , Binary_UUID ):
865
+ return f"TRIM(TO_VARCHAR({ value } , 'UTF-8'))"
839
866
return self .to_string (value )
840
867
841
868
def normalize_json (self , value : str , _coltype : JSON ) -> str :
0 commit comments