Skip to content

Implement NTLM-SSM/SPNEGO authentication #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ package main

import (
"bufio"
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
Expand All @@ -52,6 +53,8 @@ import (
"strings"
"time"

"github.com/Azure/go-ntlmssp"
"github.com/dpotapov/go-spnego"
"github.com/moriyoshi/mimetypes"
"github.com/pkg/errors"
"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -88,9 +91,24 @@ type MITMConfig struct {
DisableCache bool
}

type DirectSetter func(*http.Request) error

type UserPasswordPair struct {
User string
Password string
}

type RoundTripperWrapper func(http.RoundTripper, *http.Request) (*http.Response, error)

type AuthConfig struct {
RoundTripperWrapperFactory func() (RoundTripperWrapper, error)
CredentialsProvider func(context.Context) (interface{}, error)
}

type ProxyConfig struct {
HTTPProxy *url.URL
HTTPSProxy *url.URL
Auth AuthConfig
IncludedHosts []HostPortPair
ExcludedHosts []HostPortPair
TLSConfig *tls.Config
Expand Down Expand Up @@ -257,6 +275,75 @@ func parseUrlOrHostPortPair(urlOrHostPortPair string) (retval *url.URL, err erro
return
}

func (ctx *ConfigReaderContext) lookupRoundTripperWrapperFactory(typ string) (func() (RoundTripperWrapper, error), error) {
switch typ {
case "ntlm", "ntlm-ssp":
return func() (RoundTripperWrapper, error) {
return func(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
return ntlmssp.Negotiator{rt}.RoundTrip(req)
}, nil
}, nil
case "gssapi", "spnego":
return func() (RoundTripperWrapper, error) {
p := spnego.New()
return func(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
err := p.SetSPNEGOHeader(req)
if err != nil {
return nil, err
}
return rt.RoundTrip(req)
}, nil
}, nil
default:
return nil, errors.Errorf("unknown roundtripper wrapper: %s", typ)
}
}

func (ctx *ConfigReaderContext) extractAuthConfig(deref dereference) (retval AuthConfig, err error) {
err = deref.multi(
"type", func(typ string) error {
var err error
retval.RoundTripperWrapperFactory, err = ctx.lookupRoundTripperWrapperFactory(typ)
return err
},
"credentials", func(deref dereference) error {
var upp *UserPasswordPair
err := deref.multi(
"user", func(v string) error {
if upp == nil {
upp = &UserPasswordPair{}
}
upp.User = v
return nil
},
"password", func(v string) error {
if upp == nil {
upp = &UserPasswordPair{}
}
upp.Password = v
return nil
},
)
if err != nil {
return err
}
if upp != nil {
retval.CredentialsProvider = func(_ context.Context) (interface{}, error) {
return upp, nil
}
} else {
retval.CredentialsProvider = func(_ context.Context) (interface{}, error) {
return func(_ *http.Request) error {
return nil
}, nil
}
}
return nil
},
)
return
}

func (ctx *ConfigReaderContext) extractProxyConfig(deref dereference) (retval ProxyConfig, err error) {
err = deref.multi(
"proxy", func(deref dereference) error {
Expand All @@ -279,6 +366,11 @@ func (ctx *ConfigReaderContext) extractProxyConfig(deref dereference) (retval Pr
retval.HTTPSProxy = httpsProxyUrl
return nil
},
"auth", func(deref dereference) error {
var err error
retval.Auth, err = ctx.extractAuthConfig(deref)
return err
},
"included", func(includedHosts []string) error {
retval.IncludedHosts, err = convertUnparsedHostsIntoPairs(includedHosts)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
proxy:
http: http://127.0.0.1:9080/
https: http://127.0.0.1:9080/
auth:
type: ntlm
credentials:
user: DOMAIN\\FOO
password: PASS
excluded:
- localhost:8081
- localhost:8082
Expand Down
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
module github.com/moriyoshi/devproxy

require (
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c
github.com/dpotapov/go-spnego v0.0.0-20190506202455-c2c609116ad0
github.com/moriyoshi/mimetypes v1.0.0
github.com/moriyoshi/simplefiletx v1.0.0
github.com/pkg/errors v0.9.1
github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0
github.com/sirupsen/logrus v1.3.0
github.com/stretchr/testify v1.2.2
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3
golang.org/x/text v0.3.0 // indirect
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
gopkg.in/yaml.v2 v2.3.0
)

Expand Down
24 changes: 24 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28=
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dpotapov/go-spnego v0.0.0-20190506202455-c2c609116ad0 h1:Hhh7nu7CfFVlnBJqmDDUh+j1H5fqjLMzM4czZzNNJGM=
github.com/dpotapov/go-spnego v0.0.0-20190506202455-c2c609116ad0/go.mod h1:P4f4MSk7h52F2PK0lCapn5+fu47Uf8aRdxDSqgezxZE=
github.com/hashicorp/go-uuid v1.0.1 h1:fv1ep09latC32wFoVwnqcnKJGnMSdBanPczbHAYm1BE=
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03 h1:FUwcHNlEqkqLjLBdCp5PRlCFijNjvcYANOZXzCfXwCM=
github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/moriyoshi/mimetypes v1.0.0 h1:nESQmdWurua/+7QzWnxMpbHYUd0mMsi6zKAkq3ZbU50=
Expand All @@ -19,13 +27,29 @@ github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo=
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3 h1:ulvT7fqt0yHWzpJwI57MezWnYDVpCAYBVuYst/L+fAY=
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw=
gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo=
gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM=
gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q=
gopkg.in/jcmturner/gokrb5.v5 v5.3.0 h1:RS1MYApX27Hx1Xw7NECs7XxGxxrm69/4OmaRuX9kwec=
gopkg.in/jcmturner/gokrb5.v5 v5.3.0/go.mod h1:oQz8Wc5GsctOTgCVyKad1Vw4TCWz5G6gfIQr88RPv4k=
gopkg.in/jcmturner/rpc.v0 v0.0.2 h1:wBTgrbL1qmLBUPsYVCqdJiI5aJgQhexmK+JkTHPUNJI=
gopkg.in/jcmturner/rpc.v0 v0.0.2/go.mod h1:NzMq6cRzR9lipgw7WxRBHNx5N8SifBuaCQsOT1kWY/E=
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
8 changes: 6 additions & 2 deletions httpx/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1406,8 +1406,12 @@ func addTLS(t *Transport, tlsConfig *tls.Config, name string, conn net.Conn, tra

func (t *Transport) doDialFirstHop(ctx context.Context, cm ConnectMethod, trace *httptrace.ClientTrace) (conn net.Conn, tlsState *tls.ConnectionState, err error) {
firstHopScheme := cm.Scheme()
if firstHopScheme == "https" && t.DialTLS != nil {
conn, err = t.DialTLS(ctx, "tcp", cm.Addr())
if firstHopScheme == "https" && (t.DialTLS2 != nil || t.DialTLS != nil) {
if t.DialTLS2 != nil {
conn, err = t.DialTLS2(ctx, "tcp", cm.Addr(), t.TLSClientConfig)
} else {
conn, err = t.DialTLS(ctx, "tcp", cm.Addr())
}
if err != nil {
return
}
Expand Down
119 changes: 108 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
package main

import (
"context"
crand "crypto/rand"
"crypto/sha1"
"crypto/tls"
Expand Down Expand Up @@ -64,6 +65,81 @@ import (

type TLSConfigFactory func(hostPortPairStr string, proxyCtx *OurProxyCtx) (*tls.Config, error)

type HttpxTransport interface {
http.RoundTripper
RegisterProtocol(string, http.RoundTripper)
CloseIdleConnections()
CancelRequest(*http.Request, error)
ConnectMethodForRequest(*httpx.TransportRequest) (httpx.ConnectMethod, error)
DoDial(context.Context, httpx.ConnectMethod) (net.Conn, *tls.ConnectionState, bool, func(http.Header), error)
DialContext(context.Context, string, string) (net.Conn, error)
DialTLS2(context.Context, string, string, *tls.Config) (net.Conn, error)
}

type httpxTransportWrapper struct {
*httpx.Transport
}

func (htw *httpxTransportWrapper) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if htw.Transport.DialContext != nil {
return htw.Transport.DialContext(ctx, network, addr)
}
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
}

func (htw *httpxTransportWrapper) DialTLS2(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
if htw.Transport.DialTLS2 != nil {
return htw.Transport.DialTLS2(ctx, network, addr, config)
}
if config == nil {
if htw.Transport.DialTLS != nil {
return htw.Transport.DialTLS(ctx, "tcp", addr)
} else {
config = htw.Transport.TLSClientConfig
}
}
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, errors.Wrapf(err, "failed to connect to %v", addr)
}
tlsConfig := config.Clone()
tlsConfig.ServerName = splitHostPort(addr).Host
return tls.Client(conn, tlsConfig), nil
}

func trAsInterface(tr *httpx.Transport) HttpxTransport {
return &httpxTransportWrapper{tr}
}

type transportRoundTripperWrapper struct {
HttpxTransport
wrapper RoundTripperWrapper
}

func (ttw *transportRoundTripperWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
return ttw.wrapper(ttw, req)
}

type roundTripperWrapper struct {
http.RoundTripper
wrapper RoundTripperWrapper
}

func (rtw *roundTripperWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
return rtw.wrapper(rtw, req)
}

func wrapRoundTripper(tr http.RoundTripper, rtw RoundTripperWrapper) (http.RoundTripper, error) {
switch tr := tr.(type) {
case HttpxTransport:
return &transportRoundTripperWrapper{tr, rtw}, nil
case http.RoundTripper:
return &roundTripperWrapper{tr, rtw}, nil
default:
return nil, errors.Errorf("invalid round tripper: %T", tr)
}
}

type DevProxy struct {
Logger *logrus.Logger
LogWriter io.WriteCloser
Expand Down Expand Up @@ -180,15 +256,28 @@ func (ctx *DevProxy) newProxyURLBuilder() func(*http.Request) (*url.URL, *tls.Co
}
}

func (ctx *DevProxy) newHttpTransport() *httpx.Transport {
transport := &httpx.Transport{
func (ctx *DevProxy) newHttpTransport() (tr HttpxTransport, err error) {
tr = trAsInterface(&httpx.Transport{
TLSClientConfig: ctx.Config.MITM.ClientTLSConfigTemplate,
Proxy2: ctx.newProxyURLBuilder(),
})
tr.RegisterProtocol("fastcgi", &fastCGIRoundTripper{Logger: ctx.Logger})
tr.RegisterProtocol("file", NewFileTransport(ctx.Config.FileTransport))
tr.RegisterProtocol("x-http-redirect", &redirector{Logger: ctx.Logger})
if rtwf := ctx.Config.Proxy.Auth.RoundTripperWrapperFactory; rtwf != nil {
var rt RoundTripperWrapper
rt, err = rtwf()
if err != nil {
return
}
var wrt http.RoundTripper
wrt, err = wrapRoundTripper(tr, rt)
if err != nil {
return
}
tr = wrt.(HttpxTransport)
}
transport.RegisterProtocol("fastcgi", &fastCGIRoundTripper{Logger: ctx.Logger})
transport.RegisterProtocol("file", NewFileTransport(ctx.Config.FileTransport))
transport.RegisterProtocol("x-http-redirect", &redirector{Logger: ctx.Logger})
return transport
return
}

var domainNameRegex = regexp.MustCompile("^[A-Za-z](?:[0-9A-Za-z-_]*[0-9A-Za-z])?$")
Expand Down Expand Up @@ -366,15 +455,19 @@ func (ctx *DevProxy) checkIfTunnelRequestMatchesToUrl(url_ *url.URL, req *http.R
return false
}

func (ctx *DevProxy) newProxyHttpServer() *OurProxyHttpServer {
func (ctx *DevProxy) newProxyHttpServer() (*OurProxyHttpServer, error) {
tr, err := ctx.newHttpTransport()
if err != nil {
return nil, err
}
return &OurProxyHttpServer{
Ctx: ctx,
Logger: ctx.Logger,
Tr: ctx.newHttpTransport(),
Tr: tr,
TLSConfigFactory: ctx.newTLSConfigFactory(),
ResponseFilters: ctx.Config.ResponseFilters,
SessionSerial: 0,
}
}, nil
}

func (ctx *DevProxy) Dispose() {
Expand Down Expand Up @@ -445,7 +538,11 @@ func main() {
),
}
defer ctx.Dispose()
proxy := ctx.newProxyHttpServer()
logger.Infof("Listening on %s...", listenOn)
proxy, err := ctx.newProxyHttpServer()
if err != nil {
logger.Fatalf("could not initialize the proxy server: %s", err.Error())
os.Exit(1)
}
logger.Infof("listening on %s...", listenOn)
logger.Fatal(http.ListenAndServe(listenOn, proxy))
}
Loading