// Package paginate provides HTTP middleware that resolves pagination
// parameters for a request and stores them in the request context.
package paginate

import (
	"context"
	"errors"
	"net/http"
	"strconv"

	"go.pkg.cx/middleware"
)

// Errors
var (
	// ErrPaginationDefaults is passed to the response handler when the
	// configured defaults function returns nil.
	ErrPaginationDefaults = errors.New("pagination defaults are nil")
)

// Context keys
var (
	// PaginationCtxKey is the context key under which Middleware stores the
	// resolved *Pagination. Use PaginationFromContext to read it back.
	PaginationCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/paginate", Name: "Pagination"}
)

// DefaultOptions represents default paginate middleware options.
//
// Middleware always applies DefaultOptions before any caller-supplied
// options, so callers can override or extend it. It reads the "page" and
// "pageSize" query parameters, falls back to page 1 with a page size of 50,
// performs no validation, and answers failures with 400 Bad Request.
var DefaultOptions = Options(
	WithFindPaginationFn(PaginationFromQuery("page", "pageSize")),
	SetPaginationDefaults(1, 50),
	SetValidatePaginationFn(allowAll),
	SetResponseHandler(RespondWithBadRequest),
)

// Pagination represents pagination info.
type Pagination struct {
	Page     int
	PageSize int
}

// Options turns a list of option instances into a single option that
// applies them in the given order.
func Options(opts ...Option) Option {
	return func(p *paginate) {
		for _, opt := range opts {
			opt(p)
		}
	}
}

// Option configures paginate middleware.
type Option func(p *paginate)

// WithFindPaginationFn appends a pagination find function to the list.
//
// Find functions run in order for every request. Each receives the current
// pagination and may mutate it and/or return a replacement; returning nil
// keeps the current value.
func WithFindPaginationFn(fn func(r *http.Request, p *Pagination) *Pagination) Option {
	return func(p *paginate) {
		p.findPaginationFns = append(p.findPaginationFns, fn)
	}
}

// SetFindPaginationFns replaces the whole list of pagination find functions,
// including the query-parameter finder installed by DefaultOptions.
func SetFindPaginationFns(fns ...func(r *http.Request, p *Pagination) *Pagination) Option {
	return func(p *paginate) {
		// Copy instead of aliasing the caller's slice: a later
		// WithFindPaginationFn append would otherwise write into the caller's
		// backing array (when it has spare capacity), letting several
		// Middleware instances built from this option clobber each other.
		p.findPaginationFns = append([]func(r *http.Request, p *Pagination) *Pagination(nil), fns...)
	}
}

// SetPaginationDefaults sets the pagination defaults function. Every request
// starts from a freshly allocated Pagination with the given values, so find
// functions may mutate it safely.
func SetPaginationDefaults(page int, pageSize int) Option {
	return func(p *paginate) {
		p.paginationDefaultsFn = func() *Pagination {
			return &Pagination{Page: page, PageSize: pageSize}
		}
	}
}

// SetValidatePaginationFn sets the pagination validation function. A non-nil
// error makes Middleware call the response handler instead of the next
// handler.
func SetValidatePaginationFn(fn func(p *Pagination) error) Option {
	return func(p *paginate) {
		p.validatePaginationFn = fn
	}
}

// SetResponseHandler sets the handler used to respond when pagination cannot
// be resolved or fails validation.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	return func(p *paginate) {
		p.responseHandler = fn
	}
}

// paginate holds the resolved middleware configuration.
type paginate struct {
	findPaginationFns    []func(r *http.Request, p *Pagination) *Pagination
	paginationDefaultsFn func() *Pagination
	validatePaginationFn func(p *Pagination) error
	responseHandler      middleware.ResponseHandle
}

// Middleware returns paginate middleware.
//
// For each request it builds the default pagination, runs every find
// function in order, validates the result, and stores it in the request
// context under PaginationCtxKey before calling next. Failures are reported
// through the configured response handler and next is not called.
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
	p := &paginate{}

	// Defaults first, so caller options override or extend them.
	opts = append([]Option{DefaultOptions}, opts...)
	for _, opt := range opts {
		opt(p)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			pagination := p.paginationDefaultsFn()
			if pagination == nil {
				p.responseHandler(w, r, ErrPaginationDefaults)
				return
			}

			for _, fn := range p.findPaginationFns {
				// A nil result means "keep the current pagination".
				if nextPagination := fn(r, pagination); nextPagination != nil {
					pagination = nextPagination
				}
			}

			if err := p.validatePaginationFn(pagination); err != nil {
				p.responseHandler(w, r, err)
				return
			}

			ctx := r.Context()
			ctx = context.WithValue(ctx, PaginationCtxKey, pagination)

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// PaginationFromContext returns the pagination stored by Middleware, or nil
// if the context carries none.
func PaginationFromContext(ctx context.Context) *Pagination {
	if pagination, ok := ctx.Value(PaginationCtxKey).(*Pagination); ok {
		return pagination
	}

	return nil
}

// PaginationFromQuery returns a find function that reads pagination from the
// named query parameters. A parameter that is missing or not a valid integer
// leaves the corresponding field unchanged. Values are not range-checked
// here; use SetValidatePaginationFn for that.
func PaginationFromQuery(pageParam string, pageSizeParam string) func(r *http.Request, p *Pagination) *Pagination {
	return func(r *http.Request, p *Pagination) *Pagination {
		// URL.Query re-parses RawQuery on every call, so parse it only once.
		query := r.URL.Query()

		if page, err := strconv.Atoi(query.Get(pageParam)); err == nil {
			p.Page = page
		}
		if pageSize, err := strconv.Atoi(query.Get(pageSizeParam)); err == nil {
			p.PageSize = pageSize
		}

		return p
	}
}

// RespondWithBadRequest is the default response handler. It ignores the error
// and writes a plain 400 Bad Request response.
func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, _ error) {
	http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
}

// allowAll is the default validation function; it accepts any pagination.
func allowAll(_ *Pagination) error {
	return nil
}