diff --git a/conn.go b/conn.go index b131e4c..0bb28ef 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ type conn struct { OutputLocation string pollFrequency time.Duration + workgroup string } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -65,6 +66,7 @@ func (c *conn) startQuery(ctx context.Context, query string) (string, error) { ResultConfiguration: &types.ResultConfiguration{ OutputLocation: aws.String(c.OutputLocation), }, + WorkGroup: aws.String(c.workgroup), }) if err != nil { return "", err diff --git a/driver.go b/driver.go index 03dceca..2119ebd 100644 --- a/driver.go +++ b/driver.go @@ -59,6 +59,9 @@ func init() { // - `region` (optional) // Override AWS region. Useful if it is not set with environment variable. // +// - `workgroup` (optional) +// Athena's workgroup. This defaults to "primary". +// // Credentials must be accessible via the SDK's Default Credential Provider Chain. // For more advanced AWS credentials/session/config management, please supply // a custom AWS session directly via `athena.Open()`. @@ -82,6 +85,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { db: cfg.Database, OutputLocation: cfg.OutputLocation, pollFrequency: cfg.PollFrequency, + workgroup: cfg.WorkGroup, }, nil } @@ -101,6 +105,10 @@ func Open(cfg DriverConfig) (*sql.DB, error) { return nil, errors.New("AWS config is required") } + if cfg.WorkGroup == "" { + cfg.WorkGroup = "primary" + } + // This hack was copied from jackc/pgx. Sorry :( // https://github.com/jackc/pgx/blob/70a284f4f33a9cc28fd1223f6b83fb00deecfe33/stdlib/sql.go#L130-L136 openFromSessionMutex.Lock() @@ -119,6 +127,7 @@ type DriverConfig struct { OutputLocation string PollFrequency time.Duration + WorkGroup string } func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) { @@ -140,6 +149,10 @@ func configFromConnectionString(ctx context.Context, connStr string) (*DriverCon cfg.Database = args.Get("db") cfg.OutputLocation = args.Get("output_location") + cfg.WorkGroup = args.Get("workgroup") + if cfg.WorkGroup == "" { + cfg.WorkGroup = "primary" + } frequencyStr := args.Get("poll_frequency") if frequencyStr != "" {