-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.go
120 lines (102 loc) · 3.02 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package main
import (
"fmt"
"io"
"net"
"os"
"os/signal"
"strings"
"syscall"
pg_query "github.com/pganalyze/pg_query_go/v2"
"github.com/rueian/pgbroker/backend"
"github.com/rueian/pgbroker/message"
"github.com/rueian/pgbroker/proxy"
)
// sanitizeQueryStr flattens a query string onto a single line so it can be
// written to the log as one line. Every "\n" becomes a space and every "\r" is
// removed, so a CRLF line break collapses to exactly one space. All other
// characters, including other control characters, are passed through unchanged.
func sanitizeQueryStr(queryString string) string {
	flattener := strings.NewReplacer(
		"\n", " ",
		"\r", "",
	)
	return flattener.Replace(queryString)
}
// generateErrorQuery returns a query that the PostgreSQL server will reject
// with an error whose text contains msg. The query is a lone quoted string
// literal followed by ';'. A bare literal is not a valid statement, so the
// server reports a syntax error quoting it.
//
// msg is escaped so it can never close the literal early and smuggle in
// further SQL. Single quotes are doubled. Backslashes are doubled as well:
// with standard_conforming_strings off, a lone backslash would escape the
// following quote and end the literal. With the setting on (the PostgreSQL
// default), the doubled backslashes show up verbatim in the error text. That
// cosmetic cost is accepted in exchange for safety under either setting.
func generateErrorQuery(msg string) string {
	escaper := strings.NewReplacer(
		"'", "''",
		`\`, `\\`,
	)
	return "'" + escaper.Replace(msg) + "';"
}
// main runs pgfilterproxy, a pgbroker-based PostgreSQL proxy that filters
// client traffic against an allow-list loaded from a YAML config file.
//
// Usage: pgfilterproxy [config-path]. If exactly one argument is given, it is
// the config path. Otherwise "pgfilterproxy.yaml" is used.
//
// Filtering:
//   - Query messages are fingerprinted with pg_query. A query that fails to
//     parse, or whose fingerprint is not a key of AllowedFingerprints, is not
//     forwarded as-is. It is rewritten with generateErrorQuery, so the server
//     itself answers with an error carrying our message.
//   - Messages delivered to the "client other" handler are forwarded only if
//     their type is a key of AllowedCommands. Otherwise the handler returns an
//     error (presumably this aborts the connection; confirm against pgbroker).
//
// Signals: SIGHUP reloads the config file and prints the outcome. SIGINT and
// SIGTERM shut the server down. Only values fetched through GetConfig() while
// handling traffic can follow a reload. Listen and TargetServer are read once
// at startup.
//
// Failing to load the config, failing to listen, or Serve failing on its own
// (not because of our shutdown) all panic.
func main() {
	var configPath string
	if len(os.Args) == 2 {
		configPath = os.Args[1]
	} else {
		configPath = "pgfilterproxy.yaml"
	}
	if err := loadConfig(configPath); err != nil {
		panic(err)
	}

	ln, err := net.Listen("tcp", GetConfig().Listen)
	if err != nil {
		panic(err)
	}

	clientMessageHandlers := proxy.NewClientMessageHandlers()
	clientMessageHandlers.AddHandleQuery(func(ctx *proxy.Ctx, msg *message.Query) (*message.Query, error) {
		fingerprint, err := pg_query.Fingerprint(msg.QueryString)
		if err != nil {
			fmt.Printf("failed to parse query: %v: %s\n", err, sanitizeQueryStr(msg.QueryString))
			msg.QueryString = generateErrorQuery(fmt.Sprintf("failed to parse query: %v", err))
			return msg, nil
		}
		if _, ok := GetConfig().AllowedFingerprints[fingerprint]; !ok {
			fmt.Printf("query with fingerprint %s not allowed: %s\n", fingerprint, sanitizeQueryStr(msg.QueryString))
			// Rewrite rather than fail the handler, so the client gets an
			// ordinary error response from the server for this query.
			msg.QueryString = generateErrorQuery("query is not allowed")
		}
		return msg, nil
	})
	clientMessageHandlers.AddHandleClientOther(func(ctx *proxy.Ctx, msg *message.Raw) (*message.Raw, error) {
		if _, ok := GetConfig().AllowedCommands[msg.Type]; !ok {
			return nil, fmt.Errorf("disallowed client command %c", msg.Type)
		}
		return msg, nil
	})

	server := &proxy.Server{
		PGResolver:                    backend.NewStaticPGResolver(GetConfig().TargetServer),
		ConnInfoStore:                 backend.NewInMemoryConnInfoStore(),
		ClientMessageHandlers:         clientMessageHandlers,
		ServerStreamCallbackFactories: proxy.NewStreamCallbackFactories(),
		OnHandleConnError: func(err error, ctx *proxy.Ctx, conn net.Conn) {
			// io.EOF is treated as a normal end of connection, not worth logging.
			if err == io.EOF {
				return
			}
			fmt.Println("OnHandleConnError", err)
		},
	}

	// Serve's error is handed to the main goroutine instead of being checked
	// against a shared "in shutdown" flag. Only main knows whether a shutdown
	// was requested, and a plain bool shared between goroutines is a data
	// race. The buffer of 1 lets the goroutine finish even after main has
	// stopped receiving (i.e. errors caused by our own Shutdown are dropped).
	serveErrs := make(chan error, 1)
	go func() {
		if err := server.Serve(ln); err != nil {
			serveErrs <- err
		}
	}()

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
waitLoop:
	for {
		select {
		case err := <-serveErrs:
			// Serve stopped before any shutdown was requested.
			panic(err)
		case sig := <-sigs:
			if sig != syscall.SIGHUP {
				fmt.Println("shutting down")
				break waitLoop
			}
			if err := loadConfig(configPath); err != nil {
				fmt.Println(err)
			} else {
				fmt.Println("reloaded config")
			}
		}
	}

	// Restore default signal handling so a second SIGINT/SIGTERM can still
	// terminate the process if Shutdown hangs.
	signal.Stop(sigs)
	server.Shutdown()
}