diff --git a/paginate/paginate.go b/paginate/paginate.go new file mode 100644 index 0000000..8250e58 --- /dev/null +++ b/paginate/paginate.go @@ -0,0 +1,159 @@ +package paginate + +import ( + "context" + "errors" + "net/http" + "strconv" + + "go.pkg.cx/middleware" +) + +// Errors +var ( + ErrPaginationDefaults = errors.New("pagination defaults are nil") +) + +// Context keys +var ( + PaginationCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/paginate", Name: "Pagination"} +) + +// DefaultOptions represents default paginate middleware options +var DefaultOptions = Options( + WithFindPaginationFn(PaginationFromQuery("page", "pageSize")), + SetPaginationDefaults(1, 50), + SetValidatePaginationFn(allowAll), + SetResponseHandler(RespondWithBadRequest), +) + +// Pagination represents pagination info +type Pagination struct { + Page int + PageSize int +} + +// Options turns a list of option instances into an option +func Options(opts ...Option) Option { + return func(p *paginate) { + for _, opt := range opts { + opt(p) + } + } +} + +// Option configures paginate middleware +type Option func(p *paginate) + +// WithFindPaginationFn adds pagination find function to the list +func WithFindPaginationFn(fn func(r *http.Request, p *Pagination) *Pagination) Option { + return func(p *paginate) { + p.findPaginationFns = append(p.findPaginationFns, fn) + } +} + +// SetFindPaginationFns sets pagination find functions list +func SetFindPaginationFns(fns ...func(r *http.Request, p *Pagination) *Pagination) Option { + return func(p *paginate) { + p.findPaginationFns = fns + } +} + +// SetPaginationDefaults sets pagination defaults function +func SetPaginationDefaults(page int, pageSize int) Option { + return func(p *paginate) { + p.paginationDefaultsFn = func() *Pagination { + return &Pagination{Page: page, PageSize: pageSize} + } + } +} + +// SetValidatePaginationFn sets pagination validation function +func SetValidatePaginationFn(fn func(p *Pagination) error) Option { + return func(p *paginate) { + p.validatePaginationFn = fn + } +} + +// SetResponseHandler sets response handler +func SetResponseHandler(fn middleware.ResponseHandle) Option { + return func(p *paginate) { + p.responseHandler = fn + } +} + +type paginate struct { + findPaginationFns []func(r *http.Request, p *Pagination) *Pagination + paginationDefaultsFn func() *Pagination + validatePaginationFn func(p *Pagination) error + responseHandler middleware.ResponseHandle +} + +// Middleware returns paginate middleware +func Middleware(opts ...Option) func(next http.Handler) http.Handler { + p := &paginate{} + opts = append([]Option{DefaultOptions}, opts...) + + for _, opt := range opts { + opt(p) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pagination := p.paginationDefaultsFn() + if pagination == nil { + p.responseHandler(w, r, ErrPaginationDefaults) + return + } + + for _, fn := range p.findPaginationFns { + if nextPagination := fn(r, pagination); nextPagination != nil { + pagination = nextPagination + } + } + + if err := p.validatePaginationFn(pagination); err != nil { + p.responseHandler(w, r, err) + return + } + + ctx := r.Context() + ctx = context.WithValue(ctx, PaginationCtxKey, pagination) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// PaginationFromContext returns pagination from context +func PaginationFromContext(ctx context.Context) *Pagination { + if pagination, ok := ctx.Value(PaginationCtxKey).(*Pagination); ok { + return pagination + } + + return nil +} + +// PaginationFromQuery returns pagination from query params +func PaginationFromQuery(pageParam string, pageSizeParam string) func(r *http.Request, p *Pagination) *Pagination { + return func(r *http.Request, p *Pagination) *Pagination { + if page, err := strconv.Atoi(r.URL.Query().Get(pageParam)); err == nil { + p.Page = page + } + + if pageSize, err := strconv.Atoi(r.URL.Query().Get(pageSizeParam)); err == nil { + p.PageSize = pageSize + } + + return p + } +} + +// RespondWithBadRequest is a default response handler +func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, _ error) { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) +} + +func allowAll(_ *Pagination) error { + return nil +} diff --git a/paginate/paginate_test.go b/paginate/paginate_test.go new file mode 100644 index 0000000..32cd43b --- /dev/null +++ b/paginate/paginate_test.go @@ -0,0 +1,68 @@ +package paginate + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPaginationFromQuery(t *testing.T) { + req, err := http.NewRequest("GET", "/?page=2&pageSize=10", http.NoBody) + assert.NoError(t, err) + + pagination := PaginationFromQuery("page", "pageSize")(req, &Pagination{}) + assert.Equal(t, &Pagination{Page: 2, PageSize: 10}, pagination) +} + +func TestValidPagination(t *testing.T) { + opts := []Option{ + SetPaginationDefaults(1, 10), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pagination := PaginationFromContext(r.Context()) + assert.NotNil(t, pagination) + assert.Equal(t, 2, pagination.Page) + assert.Equal(t, 10, pagination.PageSize) + }) + + middleware := Middleware(opts...)(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/?page=2&pageSize=10", http.NoBody) + middleware.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + +} + +func TestInvalidPagination(t *testing.T) { + opts := []Option{ + SetPaginationDefaults(1, 10), + SetValidatePaginationFn(func(p *Pagination) error { + if p.Page < 1 || p.PageSize < 1 { + return errors.New("invalid pagination") + } + + return nil + }), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pagination := PaginationFromContext(r.Context()) + assert.NotNil(t, pagination) + assert.Equal(t, -1, pagination.Page) + assert.Equal(t, 10, pagination.PageSize) + }) + + middleware := Middleware(opts...)(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/?page=-1&pageSize=10", http.NoBody) + middleware.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +}