summaryrefslogtreecommitdiff
path: root/vendor/github.com/aws/smithy-go/private/requestcompression
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/aws/smithy-go/private/requestcompression')
-rw-r--r--vendor/github.com/aws/smithy-go/private/requestcompression/gzip.go30
-rw-r--r--vendor/github.com/aws/smithy-go/private/requestcompression/middleware_capture_request_compression.go52
-rw-r--r--vendor/github.com/aws/smithy-go/private/requestcompression/request_compression.go103
3 files changed, 185 insertions, 0 deletions
diff --git a/vendor/github.com/aws/smithy-go/private/requestcompression/gzip.go b/vendor/github.com/aws/smithy-go/private/requestcompression/gzip.go
new file mode 100644
index 0000000..004d78f
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/private/requestcompression/gzip.go
@@ -0,0 +1,30 @@
+package requestcompression
+
+import (
+ "bytes"
+ "compress/gzip"
+ "fmt"
+ "io"
+)
+
+func gzipCompress(input io.Reader) ([]byte, error) {
+ var b bytes.Buffer
+ w, err := gzip.NewWriterLevel(&b, gzip.DefaultCompression)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create gzip writer, %v", err)
+ }
+
+ inBytes, err := io.ReadAll(input)
+ if err != nil {
+ return nil, fmt.Errorf("failed read payload to compress, %v", err)
+ }
+
+ if _, err = w.Write(inBytes); err != nil {
+ return nil, fmt.Errorf("failed to write payload to be compressed, %v", err)
+ }
+ if err = w.Close(); err != nil {
+ return nil, fmt.Errorf("failed to flush payload being compressed, %v", err)
+ }
+
+ return b.Bytes(), nil
+}
diff --git a/vendor/github.com/aws/smithy-go/private/requestcompression/middleware_capture_request_compression.go b/vendor/github.com/aws/smithy-go/private/requestcompression/middleware_capture_request_compression.go
new file mode 100644
index 0000000..06c16af
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/private/requestcompression/middleware_capture_request_compression.go
@@ -0,0 +1,52 @@
+package requestcompression
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+ "io"
+ "net/http"
+)
+
+const captureUncompressedRequestID = "CaptureUncompressedRequest"
+
+// AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check
+func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
+ return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{
+ buf: buf,
+ }, "RequestCompression", middleware.Before)
+}
+
+type captureUncompressedRequestMiddleware struct {
+ req *http.Request
+ buf *bytes.Buffer
+ bytes []byte
+}
+
+// ID returns id of the captureUncompressedRequestMiddleware
+func (*captureUncompressedRequestMiddleware) ID() string {
+ return captureUncompressedRequestID
+}
+
+// HandleSerialize captures request payload before it is compressed by request compression middleware
+func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
+) (
+ output middleware.SerializeOutput, metadata middleware.Metadata, err error,
+) {
+ request, ok := input.Request.(*smithyhttp.Request)
+ if !ok {
+ return output, metadata, fmt.Errorf("error when retrieving http request")
+ }
+
+ _, err = io.Copy(m.buf, request.GetStream())
+ if err != nil {
+ return output, metadata, fmt.Errorf("error when copying http request stream: %q", err)
+ }
+ if err = request.RewindStream(); err != nil {
+ return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err)
+ }
+
+ return next.HandleSerialize(ctx, input)
+}
diff --git a/vendor/github.com/aws/smithy-go/private/requestcompression/request_compression.go b/vendor/github.com/aws/smithy-go/private/requestcompression/request_compression.go
new file mode 100644
index 0000000..7c41476
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/private/requestcompression/request_compression.go
@@ -0,0 +1,103 @@
+// Package requestcompression implements runtime support for smithy-modeled
+// request compression.
+//
+// This package is designated as private and is intended for use only by the
+// smithy client runtime. The exported API therein is not considered stable and
+// is subject to breaking changes without notice.
+package requestcompression
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "github.com/aws/smithy-go/middleware"
+ "github.com/aws/smithy-go/transport/http"
+ "io"
+)
+
+const MaxRequestMinCompressSizeBytes = 10485760
+
+// Enumeration values for supported compress Algorithms.
+const (
+ GZIP = "gzip"
+)
+
+type compressFunc func(io.Reader) ([]byte, error)
+
+var allowedAlgorithms = map[string]compressFunc{
+ GZIP: gzipCompress,
+}
+
+// AddRequestCompression add requestCompression middleware to op stack
+func AddRequestCompression(stack *middleware.Stack, disabled bool, minBytes int64, algorithms []string) error {
+ return stack.Serialize.Add(&requestCompression{
+ disableRequestCompression: disabled,
+ requestMinCompressSizeBytes: minBytes,
+ compressAlgorithms: algorithms,
+ }, middleware.After)
+}
+
+type requestCompression struct {
+ disableRequestCompression bool
+ requestMinCompressSizeBytes int64
+ compressAlgorithms []string
+}
+
+// ID returns the ID of the middleware
+func (m requestCompression) ID() string {
+ return "RequestCompression"
+}
+
+// HandleSerialize gzip compress the request's stream/body if enabled by config fields
+func (m requestCompression) HandleSerialize(
+ ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
+) (
+ out middleware.SerializeOutput, metadata middleware.Metadata, err error,
+) {
+ if m.disableRequestCompression {
+ return next.HandleSerialize(ctx, in)
+ }
+ // still need to check requestMinCompressSizeBytes in case it is out of range after service client config
+ if m.requestMinCompressSizeBytes < 0 || m.requestMinCompressSizeBytes > MaxRequestMinCompressSizeBytes {
+ return out, metadata, fmt.Errorf("invalid range for min request compression size bytes %d, must be within 0 and 10485760 inclusively", m.requestMinCompressSizeBytes)
+ }
+
+ req, ok := in.Request.(*http.Request)
+ if !ok {
+ return out, metadata, fmt.Errorf("unknown request type %T", req)
+ }
+
+ for _, algorithm := range m.compressAlgorithms {
+ compressFunc := allowedAlgorithms[algorithm]
+ if compressFunc != nil {
+ if stream := req.GetStream(); stream != nil {
+ size, found, err := req.StreamLength()
+ if err != nil {
+ return out, metadata, fmt.Errorf("error while finding request stream length, %v", err)
+ } else if !found || size < m.requestMinCompressSizeBytes {
+ return next.HandleSerialize(ctx, in)
+ }
+
+ compressedBytes, err := compressFunc(stream)
+ if err != nil {
+ return out, metadata, fmt.Errorf("failed to compress request stream, %v", err)
+ }
+
+ var newReq *http.Request
+ if newReq, err = req.SetStream(bytes.NewReader(compressedBytes)); err != nil {
+ return out, metadata, fmt.Errorf("failed to set request stream, %v", err)
+ }
+ *req = *newReq
+
+ if val := req.Header.Get("Content-Encoding"); val != "" {
+ req.Header.Set("Content-Encoding", fmt.Sprintf("%s, %s", val, algorithm))
+ } else {
+ req.Header.Set("Content-Encoding", algorithm)
+ }
+ }
+ break
+ }
+ }
+
+ return next.HandleSerialize(ctx, in)
+}