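Kratos Middleware Design Pattern

Overview

Kratos is a lightweight microservice framework open-sourced by bilibili. Its middleware follows the onion model and consists of a standard middleware interface plus a set of built-in middlewares. The Ledger module adopts this middleware architecture to implement cross-cutting concerns such as rate limiting, circuit breaking, distributed tracing, and metrics. The concrete implementations later in this document are written as Gin handlers (gin.HandlerFunc) that follow the same pattern.

Core features

  • Onion model: both the request and the response pass through the middleware chain
  • Uniform interface: every middleware shares one function signature; requests and replies are passed as interface{} rather than as generic types
  • Composable: middlewares can be combined and ordered freely
  • High performance: minimal overhead per middleware
  • Extensible: custom middlewares are easy to add

Use cases

  • API rate limiting and circuit-breaker protection
  • Request authentication and authorization
  • Distributed tracing and performance monitoring
  • Request logging and auditing
  • Error handling and recovery
  • Data validation and transformation

Core principles

Middleware interface design

Kratos middleware follows a functional style built on higher-order functions: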

go
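// Middleware wraps a Handler and returns a new Handler.
type Middleware func(Handler) Handler

// Handler processes a request and returns a reply or an error.
type Handler func(ctx context.Context, req interface{}) (interface{}, error)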

Onion-model execution flow

Building the middleware chain

go
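// Chain composes middlewares into a single Middleware.
func Chain(middlewares ...Middleware) Middleware {
    return func(handler Handler) Handler {
        // Wrap from last to first so that the first middleware ends up outermost.
        for i := len(middlewares) - 1; i >= 0; i-- {
            handler = middlewares[i](handler)
        }
        return handler
    }
}

// Usage example
handler := Chain(
    Recovery(),         // outermost: panic recovery
    Logging(logger),    // request logging
    Metrics(),          // metrics
    Auth(authFunc),     // authentication
    RateLimit(limiter), // rate limiting (after authentication, matching the recommended order later in this document)
)(businessHandler)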
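
To make the onion order concrete, here is a minimal, self-contained sketch that reuses the `Handler`, `Middleware`, and `Chain` definitions above. The `trace` middleware and the `runChain` helper are hypothetical names introduced only for this illustration; they record when each layer is entered and left.

```go
package main

import (
	"context"
	"fmt"
)

// Handler and Middleware mirror the signatures shown above.
type Handler func(ctx context.Context, req interface{}) (interface{}, error)
type Middleware func(Handler) Handler

// Chain composes middlewares so that the first argument is the outermost layer.
func Chain(middlewares ...Middleware) Middleware {
	return func(handler Handler) Handler {
		for i := len(middlewares) - 1; i >= 0; i-- {
			handler = middlewares[i](handler)
		}
		return handler
	}
}

// trace is a hypothetical middleware that records when it is entered and left.
func trace(name string, events *[]string) Middleware {
	return func(next Handler) Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			*events = append(*events, name+":in")
			reply, err := next(ctx, req)
			*events = append(*events, name+":out")
			return reply, err
		}
	}
}

// runChain sends one request through a three-layer chain and returns the events.
func runChain() []string {
	var events []string
	business := func(ctx context.Context, req interface{}) (interface{}, error) {
		events = append(events, "handler")
		return "ok", nil
	}
	h := Chain(trace("A", &events), trace("B", &events), trace("C", &events))(business)
	_, _ = h(context.Background(), "request")
	return events
}

func main() {
	fmt.Println(runChain()) // [A:in B:in C:in handler C:out B:out A:out]
}
```

The request travels inward through A, B, C and the reply travels back out in the reverse order, which is why the recovery middleware is placed first in the chain.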

Ledger module middleware architecture

1. Middleware execution order

2. Middleware configuration

go
// MiddlewareConfig groups the configuration of all middlewares.
type MiddlewareConfig struct {
    // Panic recovery
    Recovery RecoveryConfig `yaml:"recovery"`
    
    // Logging
    Logging LoggingConfig `yaml:"logging"`
    
    // Distributed tracing
    Tracing TracingConfig `yaml:"tracing"`
    
    // Metrics collection
    Metrics MetricsConfig `yaml:"metrics"`
    
    // Rate limiting
    RateLimit RateLimitConfig `yaml:"rate_limit"`
    
    // Circuit breaking
    CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"`
    
    // Timeouts
    Timeout TimeoutConfig `yaml:"timeout"`
}

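// RateLimitConfig configures rate limiting.
type RateLimitConfig struct {
    Enable    bool                      `yaml:"enable"`
    Algorithm string                    `yaml:"algorithm"` // token_bucket, sliding_window
    Global    GlobalLimitConfig         `yaml:"global"`
    PerUser   PerUserLimitConfig        `yaml:"per_user"`
    PerAPI    map[string]APILimitConfig `yaml:"per_api"`
}

type GlobalLimitConfig struct {
    QPS   int `yaml:"qps"`   // global QPS limit
    Burst int `yaml:"burst"` // burst capacity
}

type PerUserLimitConfig struct {
    QPS            int           `yaml:"qps"`
    Burst          int           `yaml:"burst"`
    WindowSize     time.Duration `yaml:"window_size"`
    SlidingLogSize int           `yaml:"sliding_log_size"`
}

// APILimitConfig is referenced by RateLimitConfig but was not defined in the
// original; its fields are inferred from how it is used (QPS and Burst).
type APILimitConfig struct {
    QPS   int `yaml:"qps"`
    Burst int `yaml:"burst"`
}

// CircuitBreakerConfig configures the circuit breaker.
type CircuitBreakerConfig struct {
    Enable              bool          `yaml:"enable"`
    FailureThreshold    int           `yaml:"failure_threshold"`     // failures before the breaker opens
    SuccessThreshold    int           `yaml:"success_threshold"`     // successes in half-open state before it closes
    Timeout             time.Duration `yaml:"timeout"`               // how long the breaker stays open before moving to half-open
    MaxRequests         int           `yaml:"max_requests"`          // maximum requests allowed in half-open state
    Interval            time.Duration `yaml:"interval"`              // statistics interval
    MinRequestThreshold int           `yaml:"min_request_threshold"` // minimum number of requests before the breaker may open
}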

Core middleware implementations

1. Rate limiting middleware (Rate Limiter)

go
package middleware

import (
    "net/http"
    "strconv"
    "time"
    
    "github.com/gin-gonic/gin"
    "golang.org/x/time/rate"
    
    "github.com/FixIterate/lz-stash/internal/pkg/ratelimit"
)

// RateLimiterMiddleware is the rate limiting middleware.
type RateLimiterMiddleware struct {
    globalLimiter *rate.Limiter
    userLimiters  *ratelimit.UserLimiterManager
    apiLimiters   map[string]*rate.Limiter
    config        *RateLimitConfig
}

// NewRateLimiterMiddleware creates the rate limiting middleware.
func NewRateLimiterMiddleware(config *RateLimitConfig) *RateLimiterMiddleware {
    middleware := &RateLimiterMiddleware{
        config:       config,
        apiLimiters:  make(map[string]*rate.Limiter),
        userLimiters: ratelimit.NewUserLimiterManager(config.PerUser),
    }
    
    // Initialize the global limiter.
    if config.Global.QPS > 0 {
        middleware.globalLimiter = rate.NewLimiter(
            rate.Limit(config.Global.QPS),
            config.Global.Burst,
        )
    }
    
    // Initialize the per-API limiters.
    for path, apiConfig := range config.PerAPI {
        if apiConfig.QPS > 0 {
            middleware.apiLimiters[path] = rate.NewLimiter(
                rate.Limit(apiConfig.QPS),
                apiConfig.Burst,
            )
        }
    }
    
    return middleware
}

// Middleware returns the Gin middleware function.
func (m *RateLimiterMiddleware) Middleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        if !m.config.Enable {
            c.Next()
            return
        }
        
        // 1. Global limit check
        if m.globalLimiter != nil && !m.globalLimiter.Allow() {
            m.handleRateLimitExceeded(c, "global", m.config.Global.QPS)
            return
        }
        
        // 2. Per-API limit check
        path := c.FullPath()
        if apiLimiter, exists := m.apiLimiters[path]; exists {
            if !apiLimiter.Allow() {
                apiConfig := m.config.PerAPI[path]
                m.handleRateLimitExceeded(c, "api", apiConfig.QPS)
                return
            }
        }
        
        // 3. Per-user limit check
        userID := m.extractUserID(c)
        if userID != "" {
            allowed, retryAfter := m.userLimiters.Allow(userID)
            if !allowed {
                m.handleUserRateLimitExceeded(c, userID, retryAfter)
                return
            }
        }
        
        // 4. Set the rate-limit response headers
        m.setRateLimitHeaders(c, userID)
        
        c.Next()
    }
}

// handleRateLimitExceeded rejects a request that exceeded the global or per-API limit.
func (m *RateLimiterMiddleware) handleRateLimitExceeded(c *gin.Context, limitType string, qps int) {
    c.Header("X-RateLimit-Limit", strconv.Itoa(qps))
    c.Header("X-RateLimit-Remaining", "0")
    c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Second).Unix(), 10))
    c.Header("Retry-After", "1") // matches retry_after in the body
    
    c.JSON(http.StatusTooManyRequests, gin.H{
        "error": "rate limit exceeded",
        "type":  limitType,
        "limit": qps,
        "retry_after": 1,
    })
    c.Abort()
}

// handleUserRateLimitExceeded rejects a request that exceeded the per-user limit.
func (m *RateLimiterMiddleware) handleUserRateLimitExceeded(c *gin.Context, userID string, retryAfter time.Duration) {
    // Round up so that a sub-second wait is not reported as 0 seconds.
    retrySeconds := int((retryAfter + time.Second - 1) / time.Second)
    
    c.Header("X-RateLimit-Limit", strconv.Itoa(m.config.PerUser.QPS))
    c.Header("X-RateLimit-Remaining", "0")
    c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(retryAfter).Unix(), 10))
    c.Header("Retry-After", strconv.Itoa(retrySeconds))
    
    c.JSON(http.StatusTooManyRequests, gin.H{
        "error": "user rate limit exceeded",
        "user_id": userID,
        "retry_after": retrySeconds,
    })
    c.Abort()
}

// setRateLimitHeaders sets the rate-limit response headers.
func (m *RateLimiterMiddleware) setRateLimitHeaders(c *gin.Context, userID string) {
    if userID != "" {
        remaining := m.userLimiters.GetRemaining(userID)
        c.Header("X-RateLimit-Limit", strconv.Itoa(m.config.PerUser.QPS))
        c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
        
        resetTime := m.userLimiters.GetResetTime(userID)
        c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime.Unix(), 10))
    }
}

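// extractUserID returns the key used for per-user rate limiting.
func (m *RateLimiterMiddleware) extractUserID(c *gin.Context) string {
    // Prefer the user ID stored in the Gin context (for example by an auth
    // middleware that parsed a JWT).
    if userID, exists := c.Get("user_id"); exists {
        if id, ok := userID.(string); ok {
            return id
        }
    }
    
    // Fallback: read it from a request header. This header is client-supplied
    // and can be spoofed unless a trusted gateway sets it.
    if userID := c.GetHeader("X-User-ID"); userID != "" {
        return userID
    }
    
    // Last resort: use the client IP address.
    return c.ClientIP()
}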
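
The middleware above relies on `ratelimit.UserLimiterManager`, whose implementation is not shown. As a rough idea of what such a per-user limiter could look like, here is a standard-library sketch of one token bucket per user. `UserLimiter`, `NewUserLimiter`, and the injected clock are assumptions made for this example, not the actual internal package.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// bucket is a minimal token bucket that refills continuously.
type bucket struct {
	tokens float64
	last   time.Time
}

// UserLimiter is a hypothetical per-user limiter (not the real ratelimit package).
type UserLimiter struct {
	mu      sync.Mutex
	rate    float64 // tokens added per second (QPS)
	burst   float64 // bucket capacity
	buckets map[string]*bucket
	now     func() time.Time // injectable clock, so behaviour is deterministic in tests
}

func NewUserLimiter(qps, burst int, now func() time.Time) *UserLimiter {
	return &UserLimiter{rate: float64(qps), burst: float64(burst),
		buckets: make(map[string]*bucket), now: now}
}

// Allow reports whether the user may proceed; when it may not, the second
// result is how long to wait until one token is available again.
func (l *UserLimiter) Allow(userID string) (bool, time.Duration) {
	l.mu.Lock()
	defer l.mu.Unlock()

	t := l.now()
	b, ok := l.buckets[userID]
	if !ok {
		b = &bucket{tokens: l.burst, last: t} // new users start with a full bucket
		l.buckets[userID] = b
	}
	// Refill according to elapsed time, capped at the burst size.
	b.tokens += t.Sub(b.last).Seconds() * l.rate
	if b.tokens > l.burst {
		b.tokens = l.burst
	}
	b.last = t

	if b.tokens >= 1 {
		b.tokens--
		return true, 0
	}
	return false, time.Duration((1 - b.tokens) / l.rate * float64(time.Second))
}

// demo drives the limiter with a fake clock: 2 QPS, burst of 2.
func demo() ([]bool, time.Duration) {
	now := time.Unix(0, 0)
	l := NewUserLimiter(2, 2, func() time.Time { return now })
	var results []bool
	var wait time.Duration
	for i := 0; i < 3; i++ { // the third call finds the bucket empty
		ok, w := l.Allow("alice")
		results = append(results, ok)
		wait = w
	}
	now = now.Add(500 * time.Millisecond) // half a second refills one token at 2 QPS
	ok, _ := l.Allow("alice")
	results = append(results, ok)
	okBob, _ := l.Allow("bob") // a different user has its own bucket
	results = append(results, okBob)
	return results, wait
}

func main() {
	fmt.Println(demo()) // [true true false true true] 500ms
}
```

A production version would also need to evict idle users, since the map in this sketch grows without bound.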

2. Circuit breaker middleware (Circuit Breaker)

go
package middleware

import (
    "fmt"
    "net/http"
    "sync"
    "time"
    
    "github.com/gin-gonic/gin"
)

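// CircuitBreakerState is the state of a circuit breaker.
type CircuitBreakerState int

const (
    StateClosed CircuitBreakerState = iota // closed: requests pass through
    StateOpen                              // open: requests are rejected
    StateHalfOpen                          // half-open: a limited number of probe requests pass
)

// CircuitBreakerMiddleware is the circuit breaker middleware.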
type CircuitBreakerMiddleware struct {
    breakers map[string]*CircuitBreaker
    config   *CircuitBreakerConfig
    mutex    sync.RWMutex
}

// CircuitBreaker is the per-path breaker implementation.
type CircuitBreaker struct {
    name                string
    state               CircuitBreakerState
    failureCount        int
    successCount        int
    requestCount        int
    lastFailureTime     time.Time
    lastSuccessTime     time.Time
    config              *CircuitBreakerConfig
    mutex               sync.RWMutex
}

// NewCircuitBreakerMiddleware creates the circuit breaker middleware.
func NewCircuitBreakerMiddleware(config *CircuitBreakerConfig) *CircuitBreakerMiddleware {
    return &CircuitBreakerMiddleware{
        breakers: make(map[string]*CircuitBreaker),
        config:   config,
    }
}

// Middleware returns the Gin middleware function.
func (m *CircuitBreakerMiddleware) Middleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        if !m.config.Enable {
            c.Next()
            return
        }
        
        // Get or create the breaker for this route.
        breaker := m.getOrCreateBreaker(c.FullPath())
        
        // Reject the request if the breaker does not allow it.
        if !breaker.AllowRequest() {
            m.handleCircuitBreakerOpen(c, breaker)
            return
        }
        
        // Record the start time.
        startTime := time.Now()
        
        // Run the next middleware.
        c.Next()
        
        // Update the breaker based on the response status and duration.
        duration := time.Since(startTime)
        statusCode := c.Writer.Status()
        
        if m.isSuccess(statusCode, duration) {
            breaker.RecordSuccess()
        } else {
            breaker.RecordFailure()
        }
    }
}

// getOrCreateBreaker returns the breaker for a path, creating it if needed.
func (m *CircuitBreakerMiddleware) getOrCreateBreaker(path string) *CircuitBreaker {
    m.mutex.RLock()
    breaker, exists := m.breakers[path]
    m.mutex.RUnlock()
    
    if exists {
        return breaker
    }
    
    m.mutex.Lock()
    defer m.mutex.Unlock()
    
    // Double-check after acquiring the write lock.
    if breaker, exists := m.breakers[path]; exists {
        return breaker
    }
    
    // Create a new breaker.
    breaker = &CircuitBreaker{
        name:   path,
        state:  StateClosed,
        config: m.config,
    }
    
    m.breakers[path] = breaker
    return breaker
}

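// AllowRequest reports whether a request may pass through the breaker.
func (cb *CircuitBreaker) AllowRequest() bool {
    // A write lock is required: the Open -> HalfOpen transition below mutates
    // state, which would be a data race under a read lock.
    cb.mutex.Lock()
    defer cb.mutex.Unlock()
    
    switch cb.state {
    case StateClosed:
        return true
    case StateOpen:
        // Move to half-open once the open timeout has elapsed.
        if time.Since(cb.lastFailureTime) > cb.config.Timeout {
            cb.state = StateHalfOpen
            cb.requestCount = 0
            return true
        }
        return false
    case StateHalfOpen:
        // Limit the number of requests in half-open state. requestCount is only
        // incremented when a request finishes, so under concurrency slightly
        // more than MaxRequests requests may be admitted.
        return cb.requestCount < cb.config.MaxRequests
    }
    
    return false
}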

// RecordSuccess records a successful request.
func (cb *CircuitBreaker) RecordSuccess() {
    cb.mutex.Lock()
    defer cb.mutex.Unlock()
    
    cb.lastSuccessTime = time.Now()
    cb.requestCount++
    
    switch cb.state {
    case StateClosed:
        cb.failureCount = 0
    case StateHalfOpen:
        cb.successCount++
        // Success threshold reached: close the breaker.
        if cb.successCount >= cb.config.SuccessThreshold {
            cb.state = StateClosed
            cb.failureCount = 0
            cb.successCount = 0
            cb.requestCount = 0
        }
    }
}

// RecordFailure records a failed request.
func (cb *CircuitBreaker) RecordFailure() {
    cb.mutex.Lock()
    defer cb.mutex.Unlock()
    
    cb.lastFailureTime = time.Now()
    cb.requestCount++
    cb.failureCount++
    
    switch cb.state {
    case StateClosed:
        // Check whether the breaker should open.
        if cb.failureCount >= cb.config.FailureThreshold &&
           cb.requestCount >= cb.config.MinRequestThreshold {
            cb.state = StateOpen
        }
    case StateHalfOpen:
        // Any failure in half-open state reopens the breaker immediately.
        cb.state = StateOpen
        cb.successCount = 0
        cb.requestCount = 0
    }
}

// handleCircuitBreakerOpen rejects a request while the breaker is open.
func (m *CircuitBreakerMiddleware) handleCircuitBreakerOpen(c *gin.Context, breaker *CircuitBreaker) {
    c.Header("X-Circuit-Breaker-State", "OPEN")
    
    retryAfter := int(m.config.Timeout.Seconds())
    c.Header("Retry-After", fmt.Sprintf("%d", retryAfter))
    
    c.JSON(http.StatusServiceUnavailable, gin.H{
        "error": "service temporarily unavailable",
        "reason": "circuit breaker is open",
        "retry_after": retryAfter,
        "path": breaker.name,
    })
    c.Abort()
}

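// isSuccess reports whether a request counts as a success for the breaker.
func (m *CircuitBreakerMiddleware) isSuccess(statusCode int, duration time.Duration) bool {
    // Treat 5xx responses as failures.
    if statusCode >= 500 {
        return false
    }
    
    // Treat requests slower than the configured timeout as failures. Note that
    // this reuses the Timeout value that also controls how long the breaker stays open.
    if duration > m.config.Timeout {
        return false
    }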
    
    return true
}

// GetBreakerState returns the breaker state for a path (used for monitoring).
func (m *CircuitBreakerMiddleware) GetBreakerState(path string) map[string]interface{} {
    m.mutex.RLock()
    breaker, exists := m.breakers[path]
    m.mutex.RUnlock()
    
    if !exists {
        return map[string]interface{}{
            "exists": false,
        }
    }
    
    breaker.mutex.RLock()
    defer breaker.mutex.RUnlock()
    
    return map[string]interface{}{
        "exists":           true,
        "name":            breaker.name,
        "state":           m.stateToString(breaker.state),
        "failure_count":   breaker.failureCount,
        "success_count":   breaker.successCount,
        "request_count":   breaker.requestCount,
        "last_failure":    breaker.lastFailureTime,
        "last_success":    breaker.lastSuccessTime,
    }
}

func (m *CircuitBreakerMiddleware) stateToString(state CircuitBreakerState) string {
    switch state {
    case StateClosed:
        return "CLOSED"
    case StateOpen:
        return "OPEN"
    case StateHalfOpen:
        return "HALF_OPEN"
    default:
        return "UNKNOWN"
    }
}
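
The state transitions above are easier to follow in isolation. The sketch below is a simplified, standard-library-only breaker, not the Ledger implementation: the `Breaker` type, its field names, and the injected clock are assumptions for this example. Unlike the code above, it counts half-open probes at admission time, so the probe limit holds even under concurrency. It assumes `maxProbes` is at least `successThreshold`; otherwise the breaker could never collect enough successes to close.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type State int

const (
	Closed State = iota
	Open
	HalfOpen
)

// Breaker is a simplified, hypothetical circuit breaker.
type Breaker struct {
	mu               sync.Mutex
	state            State
	failures         int // consecutive failures while closed
	successes        int // successes while half-open
	probes           int // requests admitted since entering half-open
	openedAt         time.Time
	failureThreshold int
	successThreshold int
	maxProbes        int
	openTimeout      time.Duration
	now              func() time.Time // injectable clock for deterministic tests
}

// Allow decides whether a request may pass. It takes the full lock because it can change state.
func (b *Breaker) Allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	switch b.state {
	case Closed:
		return true
	case Open:
		if b.now().Sub(b.openedAt) < b.openTimeout {
			return false
		}
		b.state, b.successes, b.probes = HalfOpen, 0, 0
		fallthrough // the request that triggers the transition is the first probe
	case HalfOpen:
		if b.probes >= b.maxProbes {
			return false
		}
		b.probes++
		return true
	}
	return false
}

// Report records the outcome of a request that was allowed through.
func (b *Breaker) Report(success bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	switch b.state {
	case Closed:
		if success {
			b.failures = 0
			return
		}
		b.failures++
		if b.failures >= b.failureThreshold {
			b.state, b.openedAt = Open, b.now()
		}
	case HalfOpen:
		if !success {
			b.state, b.openedAt = Open, b.now() // any failed probe reopens the breaker
			return
		}
		b.successes++
		if b.successes >= b.successThreshold {
			b.state, b.failures = Closed, 0
		}
	}
}

func (b *Breaker) State() State {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.state
}

// demo walks through Closed -> Open -> HalfOpen -> Closed with a fake clock.
func demo() ([]bool, State) {
	now := time.Unix(0, 0)
	b := &Breaker{failureThreshold: 2, successThreshold: 2, maxProbes: 2,
		openTimeout: 30 * time.Second, now: func() time.Time { return now }}
	var allowed []bool
	b.Report(false)
	b.Report(false)                      // second consecutive failure opens the breaker
	allowed = append(allowed, b.Allow()) // false: open
	now = now.Add(31 * time.Second)
	allowed = append(allowed, b.Allow()) // true: first half-open probe
	allowed = append(allowed, b.Allow()) // true: second probe
	allowed = append(allowed, b.Allow()) // false: probe budget used up
	b.Report(true)
	b.Report(true)                       // enough successes: closed again
	allowed = append(allowed, b.Allow()) // true
	return allowed, b.State()
}

func main() {
	fmt.Println(demo()) // [false true true false true] 0
}
```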

3. Distributed tracing middleware (Tracing)

go
package middleware

import (
    "context"
    "fmt"
    "io"
    
    "github.com/gin-gonic/gin"
    "github.com/opentracing/opentracing-go"
    "github.com/opentracing/opentracing-go/ext"
    "github.com/opentracing/opentracing-go/log"
    "github.com/uber/jaeger-client-go"
    jaegerconfig "github.com/uber/jaeger-client-go/config"
)

// TracingMiddleware is the distributed tracing middleware.
type TracingMiddleware struct {
    tracer opentracing.Tracer
    closer io.Closer // flushes and shuts down the tracer; nil when tracing is disabled
    config *TracingConfig
}

// TracingConfig configures tracing.
type TracingConfig struct {
    Enable         bool    `yaml:"enable"`
    ServiceName    string  `yaml:"service_name"`
    SamplerType    string  `yaml:"sampler_type"` // const, probabilistic, ratelimiting
    SamplerParam   float64 `yaml:"sampler_param"`
    JaegerEndpoint string  `yaml:"jaeger_endpoint"`
}

// NewTracingMiddleware creates the tracing middleware.
func NewTracingMiddleware(config *TracingConfig) (*TracingMiddleware, error) {
    if !config.Enable {
        return &TracingMiddleware{config: config}, nil
    }
    
    // Create the Jaeger tracer.
    tracer, closer, err := createJaegerTracer(config)
    if err != nil {
        return nil, fmt.Errorf("failed to create tracer: %w", err)
    }
    
    // Register it as the global tracer.
    opentracing.SetGlobalTracer(tracer)
    
    return &TracingMiddleware{
        tracer: tracer,
        closer: closer, // the original discarded closer, which does not compile
        config: config,
    }, nil
}

// Close flushes buffered spans; call it during shutdown.
func (m *TracingMiddleware) Close() error {
    if m.closer == nil {
        return nil
    }
    return m.closer.Close()
}

// Middleware returns the Gin middleware function.
func (m *TracingMiddleware) Middleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        if !m.config.Enable {
            c.Next()
            return
        }
        
        // Build the span name.
        spanName := fmt.Sprintf("%s %s", c.Request.Method, c.FullPath())
        
        // Try to extract the parent span context from the request headers.
        spanCtx, err := m.tracer.Extract(
            opentracing.HTTPHeaders,
            opentracing.HTTPHeadersCarrier(c.Request.Header),
        )
        
        var span opentracing.Span
        if err != nil {
            // No usable parent: start a root span.
            span = m.tracer.StartSpan(spanName)
        } else {
            // Start a child span of the extracted context.
            span = m.tracer.StartSpan(spanName, opentracing.ChildOf(spanCtx))
        }
        defer span.Finish()
        
        // Set the standard span tags.
        ext.HTTPMethod.Set(span, c.Request.Method)
        ext.HTTPUrl.Set(span, c.Request.URL.String())
        ext.Component.Set(span, "gin-http")
        
        // Set custom tags.
        span.SetTag("service.name", m.config.ServiceName)
        span.SetTag("http.path", c.FullPath())
        span.SetTag("http.route", c.FullPath())
        
        // Record client information.
        if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
            span.SetTag("http.user_agent", userAgent)
        }
        if clientIP := c.ClientIP(); clientIP != "" {
            span.SetTag("http.client_ip", clientIP)
        }
        
        // Attach the span to the request context.
        ctx := opentracing.ContextWithSpan(c.Request.Context(), span)
        c.Request = c.Request.WithContext(ctx)
        
        // Expose the trace ID and span ID in response headers (for log correlation).
        if spanCtx := span.Context(); spanCtx != nil {
            if jaegerCtx, ok := spanCtx.(jaeger.SpanContext); ok {
                c.Header("X-Trace-ID", jaegerCtx.TraceID().String())
                c.Header("X-Span-ID", jaegerCtx.SpanID().String())
                
                // Store the IDs in the Gin context for other middlewares.
                c.Set("trace_id", jaegerCtx.TraceID().String())
                c.Set("span_id", jaegerCtx.SpanID().String())
            }
        }
        
        // Run the next middleware.
        c.Next()
        
        // Set response-related tags and logs.
        statusCode := c.Writer.Status()
        ext.HTTPStatusCode.Set(span, uint16(statusCode))
        
        if statusCode >= 400 {
            ext.Error.Set(span, true)
            if statusCode >= 500 {
                span.LogFields(
                    log.String("event", "error"),
                    log.String("level", "error"),
                    log.Int("status_code", statusCode),
                )
            }
        }
        
        // Record the response size.
        responseSize := c.Writer.Size()
        if responseSize > 0 {
            span.SetTag("http.response_size", responseSize)
        }
        
        // Record errors collected on the Gin context.
        if len(c.Errors) > 0 {
            span.LogFields(
                log.String("event", "error"),
                log.String("message", c.Errors.String()),
            )
        }
    }
}

// createJaegerTracer creates the Jaeger tracer.
func createJaegerTracer(config *TracingConfig) (opentracing.Tracer, io.Closer, error) {
    cfg := jaegerconfig.Configuration{
        ServiceName: config.ServiceName,
        Sampler: &jaegerconfig.SamplerConfig{
            Type:  config.SamplerType,
            Param: config.SamplerParam,
        },
        Reporter: &jaegerconfig.ReporterConfig{
            LogSpans:           true,
            LocalAgentHostPort: config.JaegerEndpoint,
        },
    }
    
    tracer, closer, err := cfg.NewTracer(
        jaegerconfig.Logger(jaeger.StdLogger),
    )
    if err != nil {
        return nil, nil, err
    }
    
    return tracer, closer, nil
}

// StartSpanFromContext starts a new span from the span stored in ctx, if any.
func StartSpanFromContext(ctx context.Context, operationName string) (opentracing.Span, context.Context) {
    span, ctx := opentracing.StartSpanFromContext(ctx, operationName)
    return span, ctx
}

// LogError records an error on the span.
func LogError(span opentracing.Span, err error) {
    if span == nil || err == nil {
        return
    }
    
    ext.Error.Set(span, true)
    span.LogFields(
        log.String("event", "error"),
        log.String("message", err.Error()),
        log.String("error.object", fmt.Sprintf("%+v", err)),
    )
}

// AddSpanTags adds multiple tags to the span.
func AddSpanTags(span opentracing.Span, tags map[string]interface{}) {
    if span == nil {
        return
    }
    
    for key, value := range tags {
        span.SetTag(key, value)
    }
}
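
The core idea of the tracing middleware (reuse an incoming trace context, otherwise start a new one, then make it available downstream) can be shown without any tracing library. The sketch below uses only `net/http`; `WithTrace`, `TraceIDFromContext`, and `serve` are hypothetical names, and using `X-Trace-ID` as the propagation header is a simplification for this example, since a real Jaeger tracer uses its own propagation format.

```go
package main

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"net/http"
	"net/http/httptest"
)

type traceKey struct{} // private context key, avoids collisions with other packages

const traceHeader = "X-Trace-ID"

// newTraceID returns 8 random bytes as 16 hex characters.
func newTraceID() string {
	b := make([]byte, 8)
	_, _ = rand.Read(b)
	return hex.EncodeToString(b)
}

// WithTrace reuses an incoming trace ID or creates one, stores it in the
// request context, and echoes it in the response headers.
func WithTrace(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get(traceHeader)
		if id == "" {
			id = newTraceID()
		}
		w.Header().Set(traceHeader, id)
		ctx := context.WithValue(r.Context(), traceKey{}, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// TraceIDFromContext returns the trace ID stored by WithTrace, or "".
func TraceIDFromContext(ctx context.Context) string {
	id, _ := ctx.Value(traceKey{}).(string)
	return id
}

// serve runs one request through the middleware and returns the response
// header value and the ID the inner handler saw.
func serve(incoming string) (header, seen string) {
	h := WithTrace(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen = TraceIDFromContext(r.Context())
	}))
	req := httptest.NewRequest(http.MethodGet, "/api/v1/transactions", nil)
	if incoming != "" {
		req.Header.Set(traceHeader, incoming)
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Header().Get(traceHeader), seen
}

func main() {
	fmt.Println(serve("abc123")) // abc123 abc123
	header, seen := serve("")
	fmt.Println(header == seen, len(header)) // true 16
}
```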

4. Performance monitoring middleware (Metrics)

go
package middleware

import (
    "net/http"
    "strconv"
    "time"
    
    "github.com/gin-gonic/gin"
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

// MetricsMiddleware is the performance monitoring middleware.
type MetricsMiddleware struct {
    // Total HTTP requests
    httpRequestsTotal *prometheus.CounterVec
    
    // HTTP request latency
    httpRequestDuration *prometheus.HistogramVec
    
    // HTTP request size
    httpRequestSize *prometheus.HistogramVec
    
    // HTTP response size
    httpResponseSize *prometheus.HistogramVec
    
    // Requests currently being processed
    httpRequestsInFlight prometheus.Gauge
    
    config *MetricsConfig
}

// MetricsConfig configures metrics collection.
type MetricsConfig struct {
    Enable    bool      `yaml:"enable"`
    Namespace string    `yaml:"namespace"`
    Subsystem string    `yaml:"subsystem"`
    Buckets   []float64 `yaml:"buckets"`
}

// NewMetricsMiddleware creates the performance monitoring middleware.
func NewMetricsMiddleware(config *MetricsConfig) *MetricsMiddleware {
    if !config.Enable {
        return &MetricsMiddleware{config: config}
    }
    
    // Default latency buckets (seconds)
    buckets := config.Buckets
    if len(buckets) == 0 {
        buckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
    }
    
    return &MetricsMiddleware{
        // Total number of HTTP requests
        httpRequestsTotal: promauto.NewCounterVec(
            prometheus.CounterOpts{
                Namespace: config.Namespace,
                Subsystem: config.Subsystem,
                Name:      "http_requests_total",
                Help:      "Total number of HTTP requests",
            },
            []string{"method", "path", "status_code"},
        ),
        
        // HTTP request latency distribution
        httpRequestDuration: promauto.NewHistogramVec(
            prometheus.HistogramOpts{
                Namespace: config.Namespace,
                Subsystem: config.Subsystem,
                Name:      "http_request_duration_seconds",
                Help:      "HTTP request duration in seconds",
                Buckets:   buckets,
            },
            []string{"method", "path", "status_code"},
        ),
        
        // HTTP request size distribution
        httpRequestSize: promauto.NewHistogramVec(
            prometheus.HistogramOpts{
                Namespace: config.Namespace,
                Subsystem: config.Subsystem,
                Name:      "http_request_size_bytes",
                Help:      "HTTP request size in bytes",
                Buckets:   prometheus.ExponentialBuckets(1024, 2, 10), // 1KB, 2KB, 4KB, ...
            },
            []string{"method", "path"},
        ),
        
        // HTTP response size distribution
        httpResponseSize: promauto.NewHistogramVec(
            prometheus.HistogramOpts{
                Namespace: config.Namespace,
                Subsystem: config.Subsystem,
                Name:      "http_response_size_bytes",
                Help:      "HTTP response size in bytes",
                Buckets:   prometheus.ExponentialBuckets(1024, 2, 10),
            },
            []string{"method", "path", "status_code"},
        ),
        
        // Requests currently being processed
        httpRequestsInFlight: promauto.NewGauge(
            prometheus.GaugeOpts{
                Namespace: config.Namespace,
                Subsystem: config.Subsystem,
                Name:      "http_requests_in_flight",
                Help:      "Current number of HTTP requests being processed",
            },
        ),
        
        config: config,
    }
}

// Middleware returns the Gin middleware function.
func (m *MetricsMiddleware) Middleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        if !m.config.Enable {
            c.Next()
            return
        }
        
        // Record the start time.
        startTime := time.Now()
        
        // Track the number of in-flight requests.
        m.httpRequestsInFlight.Inc()
        defer m.httpRequestsInFlight.Dec()
        
        // Measure the request size.
        requestSize := computeRequestSize(c.Request)
        
        // Run the next middleware.
        c.Next()
        
        // Compute the processing time.
        duration := time.Since(startTime)
        
        // Collect the label values.
        method := c.Request.Method
        path := c.FullPath()
        statusCode := strconv.Itoa(c.Writer.Status())
        
        // Record the metrics.
        m.httpRequestsTotal.WithLabelValues(method, path, statusCode).Inc()
        m.httpRequestDuration.WithLabelValues(method, path, statusCode).Observe(duration.Seconds())
        m.httpRequestSize.WithLabelValues(method, path).Observe(float64(requestSize))
        
        // Record the response size.
        responseSize := c.Writer.Size()
        if responseSize > 0 {
            m.httpResponseSize.WithLabelValues(method, path, statusCode).Observe(float64(responseSize))
        }
    }
}

// computeRequestSize estimates the size of a request in bytes.
func computeRequestSize(r *http.Request) int64 {
    size := int64(0)
    
    // URL length
    if r.URL != nil {
        size += int64(len(r.URL.String()))
    }
    
    // Header size
    for name, values := range r.Header {
        size += int64(len(name))
        for _, value := range values {
            size += int64(len(value))
        }
    }
    
    // Body size (only when Content-Length is known)
    if r.ContentLength > 0 {
        size += r.ContentLength
    }
    
    return size
}

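// RegisterCustomMetrics registers custom business metrics.
// Note: the collectors returned by promauto are discarded here, so callers
// cannot update them; in practice they would need to be stored somewhere.
// promauto also panics on duplicate registration, so call this only once.
func (m *MetricsMiddleware) RegisterCustomMetrics() {
    // Business-level metrics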
    promauto.NewCounterVec(
        prometheus.CounterOpts{
            Namespace: m.config.Namespace,
            Subsystem: "business",
            Name:      "transactions_total",
            Help:      "Total number of transactions created",
        },
        []string{"ledger_id", "transaction_type"},
    )
    
    promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Namespace: m.config.Namespace,
            Subsystem: "business",
            Name:      "active_users",
            Help:      "Current number of active users",
        },
        []string{"time_window"},
    )
    
    promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Namespace: m.config.Namespace,
            Subsystem: "business",
            Name:      "transaction_amount",
            Help:      "Transaction amount distribution",
            Buckets:   []float64{1, 10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000},
        },
        []string{"currency", "transaction_type"},
    )
}
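
The bucket settings above are easier to reason about with a small model of how a histogram counts observations. The following standard-library sketch is not the Prometheus client; `exponentialBuckets`, `histogram`, and `demo` are hypothetical names that only mimic the cumulative "less than or equal" bucket semantics that Prometheus exposes.

```go
package main

import "fmt"

// exponentialBuckets returns count upper bounds: start, start*factor, start*factor^2, ...
func exponentialBuckets(start, factor float64, count int) []float64 {
	bounds := make([]float64, count)
	for i := range bounds {
		bounds[i] = start
		start *= factor
	}
	return bounds
}

// histogram keeps cumulative counts: counts[i] is the number of observations
// that are <= bounds[i]; the final entry is the +Inf bucket.
type histogram struct {
	bounds []float64
	counts []uint64
}

func newHistogram(bounds []float64) *histogram {
	return &histogram{bounds: bounds, counts: make([]uint64, len(bounds)+1)}
}

func (h *histogram) observe(v float64) {
	for i, b := range h.bounds {
		if v <= b {
			h.counts[i]++
		}
	}
	h.counts[len(h.bounds)]++ // +Inf counts every observation
}

// demo observes three response sizes against 1 KiB, 2 KiB and 4 KiB buckets.
func demo() []uint64 {
	h := newHistogram(exponentialBuckets(1024, 2, 3))
	for _, size := range []float64{300, 1500, 1e6} {
		h.observe(size)
	}
	return h.counts
}

func main() {
	fmt.Println(exponentialBuckets(1024, 2, 10)) // 1024 up to 524288
	fmt.Println(demo())                          // [1 2 2 3]
}
```

With start 1024, factor 2, and 10 buckets, the largest finite bound is 512 KiB, so anything larger lands only in the +Inf bucket. If large payloads are common, the bucket range would need to be widened.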

Implementation

1. Middleware registration and configuration

go
package server

import (
    "fmt"
    
    "github.com/gin-gonic/gin"
    
    "github.com/FixIterate/lz-stash/internal/middleware"
    "github.com/FixIterate/lz-stash/internal/pkg/config"
)

// HTTPServer is the HTTP server.
type HTTPServer struct {
    engine *gin.Engine
    config *config.HTTPConfig
}

// NewHTTPServer creates the HTTP server.
func NewHTTPServer(cfg *config.Config) *HTTPServer {
    // Set the Gin mode.
    if cfg.App.Debug {
        gin.SetMode(gin.DebugMode)
    } else {
        gin.SetMode(gin.ReleaseMode)
    }
    
    engine := gin.New()
    
    server := &HTTPServer{
        engine: engine,
        config: &cfg.HTTP,
    }
    
    // Register the middlewares.
    server.setupMiddlewares(cfg)
    
    return server
}

// setupMiddlewares registers the middlewares in order (outermost first).
func (s *HTTPServer) setupMiddlewares(cfg *config.Config) {
    // 1. Panic recovery (outermost)
    if cfg.Middleware.Recovery.Enable {
        recoveryMiddleware := middleware.NewRecoveryMiddleware(&cfg.Middleware.Recovery)
        s.engine.Use(recoveryMiddleware.Middleware())
    }
    
    // 2. Request logging
    if cfg.Middleware.Logging.Enable {
        loggingMiddleware := middleware.NewLoggingMiddleware(&cfg.Middleware.Logging)
        s.engine.Use(loggingMiddleware.Middleware())
    }
    
    // 3. Distributed tracing
    if cfg.Middleware.Tracing.Enable {
        tracingMiddleware, err := middleware.NewTracingMiddleware(&cfg.Middleware.Tracing)
        if err != nil {
            panic(fmt.Sprintf("failed to create tracing middleware: %v", err))
        }
        s.engine.Use(tracingMiddleware.Middleware())
    }
    
    // 4. Performance monitoring
    if cfg.Middleware.Metrics.Enable {
        metricsMiddleware := middleware.NewMetricsMiddleware(&cfg.Middleware.Metrics)
        s.engine.Use(metricsMiddleware.Middleware())
        
        // Register the business metrics.
        metricsMiddleware.RegisterCustomMetrics()
    }
    
    // 5. CORS
    if cfg.Middleware.CORS.Enable {
        corsMiddleware := middleware.NewCORSMiddleware(&cfg.Middleware.CORS)
        s.engine.Use(corsMiddleware.Middleware())
    }
    
    // 6. Request validation
    if cfg.Middleware.Validation.Enable {
        validationMiddleware := middleware.NewValidationMiddleware(&cfg.Middleware.Validation)
        s.engine.Use(validationMiddleware.Middleware())
    }
    
    // 7. Authentication
    if cfg.Middleware.Auth.Enable {
        authMiddleware := middleware.NewAuthMiddleware(&cfg.Middleware.Auth)
        s.engine.Use(authMiddleware.Middleware())
    }
    
    // 8. Rate limiting
    if cfg.Middleware.RateLimit.Enable {
        rateLimitMiddleware := middleware.NewRateLimiterMiddleware(&cfg.Middleware.RateLimit)
        s.engine.Use(rateLimitMiddleware.Middleware())
    }
    
    // 9. Circuit breaking
    if cfg.Middleware.CircuitBreaker.Enable {
        circuitBreakerMiddleware := middleware.NewCircuitBreakerMiddleware(&cfg.Middleware.CircuitBreaker)
        s.engine.Use(circuitBreakerMiddleware.Middleware())
    }
    
    // 10. Timeout control
    if cfg.Middleware.Timeout.Enable {
        timeoutMiddleware := middleware.NewTimeoutMiddleware(&cfg.Middleware.Timeout)
        s.engine.Use(timeoutMiddleware.Middleware())
    }
}
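
The registration function above repeats the same "if enabled, build, use" steps ten times. One possible alternative, sketched here with `net/http` instead of Gin so that it runs on its own, is a table of registrations that is assembled in order and returns an error instead of panicking. `registration`, `Assemble`, `tag`, and `demo` are hypothetical names for this illustration.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type Middleware func(http.Handler) http.Handler

// registration describes one slot in the chain.
type registration struct {
	Name    string
	Enabled bool
	Build   func() (Middleware, error)
}

// Assemble builds the enabled middlewares in table order and wraps final so
// that the first enabled registration becomes the outermost layer.
func Assemble(final http.Handler, regs []registration) (http.Handler, []string, error) {
	var mws []Middleware
	var applied []string
	for _, r := range regs {
		if !r.Enabled {
			continue
		}
		mw, err := r.Build()
		if err != nil {
			return nil, nil, fmt.Errorf("build %s middleware: %w", r.Name, err)
		}
		mws = append(mws, mw)
		applied = append(applied, r.Name)
	}
	h := final
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h, applied, nil
}

// tag builds a middleware that appends its name to the X-Order response
// header, which makes the execution order observable.
func tag(name string) func() (Middleware, error) {
	return func() (Middleware, error) {
		return func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Add("X-Order", name)
				next.ServeHTTP(w, r)
			})
		}, nil
	}
}

// demo assembles a chain with one disabled entry and serves a single request.
func demo() (applied, order []string, err error) {
	regs := []registration{
		{Name: "recovery", Enabled: true, Build: tag("recovery")},
		{Name: "logging", Enabled: false, Build: tag("logging")},
		{Name: "rate_limit", Enabled: true, Build: tag("rate_limit")},
	}
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h, applied, err := Assemble(final, regs)
	if err != nil {
		return nil, nil, err
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return applied, rec.Header().Values("X-Order"), nil
}

func main() {
	fmt.Println(demo()) // [recovery rate_limit] [recovery rate_limit] <nil>
}
```

Keeping the order in a single table also makes it harder to reorder middlewares by accident when a new one is added.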

2. Example configuration file

yaml
# config/config.yaml
app:
  name: "ledger-service"
  version: "v1.0.0"
  debug: false

http:
  addr: ":8080"
  read_timeout: "30s"
  write_timeout: "30s"
  idle_timeout: "60s"

middleware:
  # Panic recovery
  recovery:
    enable: true
    stack_trace: true
    
  # Request logging
  logging:
    enable: true
    level: "info"
    format: "json"
    skip_paths:
      - "/health"
      - "/metrics"
      
  # Distributed tracing
  tracing:
    enable: true
    service_name: "ledger-service"
    sampler_type: "probabilistic"
    sampler_param: 0.1
    jaeger_endpoint: "localhost:6831"
    
  # Performance monitoring
  metrics:
    enable: true
    namespace: "ledger"
    subsystem: "api"
    buckets: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]
    
  # Rate limiting
  rate_limit:
    enable: true
    algorithm: "token_bucket"
    global:
      qps: 1000
      burst: 2000
    per_user:
      qps: 100
      burst: 200
      window_size: "1m"
    per_api:
      "/api/v1/transactions":
        qps: 500
        burst: 1000
      "/api/v1/analytics":
        qps: 100
        burst: 200
        
  # Circuit breaking
  circuit_breaker:
    enable: true
    failure_threshold: 5
    success_threshold: 3
    timeout: "30s"
    max_requests: 10
    interval: "10s"
    min_request_threshold: 10
    
  # Timeout control
  timeout:
    enable: true
    read_timeout: "30s"
    write_timeout: "30s"
    handler_timeout: "25s"
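
Some combinations of these values would make the middlewares misbehave, so it may be worth validating the configuration at startup. The sketch below checks three rules that follow from the code in this document: a zero burst with a positive QPS makes a `rate.Limiter` reject everything, a `success_threshold` larger than `max_requests` means the breaker shown earlier can never close from the half-open state, and the handler timeout should be shorter than the write timeout so a timeout response can still be written. The `Settings` type and `Validate` function are hypothetical and do not mirror the real config structs.

```go
package main

import (
	"fmt"
	"time"
)

// Settings holds only the values checked below; field names are illustrative.
type Settings struct {
	GlobalQPS        int
	GlobalBurst      int
	SuccessThreshold int
	MaxRequests      int
	WriteTimeout     string // duration string such as "30s"
	HandlerTimeout   string
}

// Validate returns one message per violated rule; an empty result means the
// settings passed all checks.
func Validate(s Settings) []string {
	var problems []string
	if s.GlobalQPS > 0 && s.GlobalBurst < 1 {
		problems = append(problems, "rate_limit.global.burst must be at least 1 when qps is set")
	}
	if s.SuccessThreshold > s.MaxRequests {
		problems = append(problems, "circuit_breaker.success_threshold must not exceed max_requests")
	}
	write, errWrite := time.ParseDuration(s.WriteTimeout)
	handler, errHandler := time.ParseDuration(s.HandlerTimeout)
	switch {
	case errWrite != nil || errHandler != nil:
		problems = append(problems, "timeouts must be valid durations such as \"30s\"")
	case handler >= write:
		problems = append(problems, "timeout.handler_timeout should be shorter than write_timeout")
	}
	return problems
}

// sample returns the values from the example configuration above.
func sample() Settings {
	return Settings{GlobalQPS: 1000, GlobalBurst: 2000, SuccessThreshold: 3,
		MaxRequests: 10, WriteTimeout: "30s", HandlerTimeout: "25s"}
}

func main() {
	fmt.Println(len(Validate(sample()))) // 0

	bad := sample()
	bad.GlobalBurst = 0
	bad.SuccessThreshold = 11
	for _, p := range Validate(bad) {
		fmt.Println(p)
	}
}
```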

Best practices

1. Middleware ordering

go
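// Recommended middleware order (outermost to innermost)
var middlewareOrder = []string{
    "Recovery",       // 1. Panic recovery: must be outermost
    "Logging",        // 2. Request logging: records every request
    "Tracing",        // 3. Distributed tracing: generates the trace ID
    "Metrics",        // 4. Performance monitoring: collects metrics
    "CORS",           // 5. Cross-origin handling: answers preflight requests
    "Security",       // 6. Security headers
    "Validation",     // 7. Request validation: fail fast
    "Authentication", // 8. Authentication
    "Authorization",  // 9. Authorization
    "RateLimit",      // 10. Rate limiting
    "CircuitBreaker", // 11. Circuit-breaker protection
    "Timeout",        // 12. Timeout control: innermost
}

// ❌ Example of a wrong order:
// if RateLimit runs before Authentication, a malicious client can use up
// the rate-limit quota with a flood of unauthenticated requests.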

2. Performance optimization tips

go
// Use an object pool to reduce allocations
var spanPool = sync.Pool{
    New: func() interface{} {
        return &SpanInfo{}
    },
}

func (m *TracingMiddleware) Middleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        // Take an object from the pool
        spanInfo := spanPool.Get().(*SpanInfo)
        defer func() {
            *spanInfo = SpanInfo{} // reset before reuse so no data leaks between requests
            spanPool.Put(spanInfo)
        }()
        
        // Use spanInfo...
    }
}

// Cache middleware instances
var middlewareCache sync.Map

func getCachedMiddleware(key string, factory func() gin.HandlerFunc) gin.HandlerFunc {
    if cached, ok := middlewareCache.Load(key); ok {
        return cached.(gin.HandlerFunc)
    }
    
    // LoadOrStore keeps the first stored instance if two goroutines race here.
    actual, _ := middlewareCache.LoadOrStore(key, factory())
    return actual.(gin.HandlerFunc)
}
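
The pooling tip only pays off if pooled objects are reset before they are returned; otherwise data from one request can leak into the next. Here is a runnable sketch of that pattern. The original does not define `SpanInfo`, so the fields below, the `reset` method, and the `describe` helper are assumptions made for illustration.

```go
package main

import (
	"fmt"
	"sync"
)

// SpanInfo is a hypothetical per-request scratch object.
type SpanInfo struct {
	TraceID string
	Tags    map[string]string
}

// reset clears the object but keeps the allocated map for reuse.
func (s *SpanInfo) reset() {
	s.TraceID = ""
	for k := range s.Tags {
		delete(s.Tags, k)
	}
}

var spanPool = sync.Pool{
	New: func() interface{} {
		return &SpanInfo{Tags: make(map[string]string, 8)}
	},
}

// describe borrows a SpanInfo, fills it in, and returns it to the pool in a
// clean state.
func describe(traceID, path string) string {
	info := spanPool.Get().(*SpanInfo)
	defer func() {
		info.reset()
		spanPool.Put(info)
	}()

	info.TraceID = traceID
	info.Tags["http.path"] = path
	return fmt.Sprintf("%s %s tags=%d", info.TraceID, info.Tags["http.path"], len(info.Tags))
}

func main() {
	fmt.Println(describe("t1", "/a")) // t1 /a tags=1
	fmt.Println(describe("t2", "/b")) // t2 /b tags=1
}
```

Because of the reset, the tag count stays at 1 whether or not the second call happens to reuse the object from the first. Pooling is generally only worth it for objects that are allocated on hot paths, so it is best confirmed with a benchmark.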

3. Monitoring and alerting

go
// Per-middleware performance metrics
type MiddlewareMetrics struct {
    executionTime *prometheus.HistogramVec
    errorCount    *prometheus.CounterVec
}

func (m *MiddlewareMetrics) recordExecution(name string, duration time.Duration, err error) {
    m.executionTime.WithLabelValues(name).Observe(duration.Seconds())
    
    if err != nil {
        m.errorCount.WithLabelValues(name, "error").Inc()
    } else {
        m.errorCount.WithLabelValues(name, "success").Inc()
    }
}

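// WithMonitoring wraps a middleware so that all middlewares are measured the same way.
// metrics is assumed to be a package-level *MiddlewareMetrics; it is not defined in this snippet.
// The measured duration includes everything that runs after the wrapped middleware calls c.Next().
func WithMonitoring(name string, middleware gin.HandlerFunc) gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        
        defer func() {
            duration := time.Since(start)
            if r := recover(); r != nil {
                metrics.recordExecution(name, duration, fmt.Errorf("panic: %v", r))
                panic(r) // re-panic so the recovery middleware can handle it
            } else {
                metrics.recordExecution(name, duration, nil)
            }
        }()
        
        middleware(c)
    }
}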

4. Error handling strategy

go
// Unified error-handling middleware
func ErrorHandlerMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Next()

        // Handle errors recorded by downstream middleware.
        // Every case returns, so only the first recorded error decides the response.
        for _, err := range c.Errors {
            switch e := err.Err.(type) {
            case *RateLimitError:
                c.JSON(http.StatusTooManyRequests, gin.H{
                    "error": "rate limit exceeded",
                    "retry_after": e.RetryAfter,
                })
                return
                
            case *CircuitBreakerError:
                c.JSON(http.StatusServiceUnavailable, gin.H{
                    "error": "service temporarily unavailable",
                    "retry_after": e.RetryAfter,
                })
                return
                
            case *AuthenticationError:
                c.JSON(http.StatusUnauthorized, gin.H{
                    "error": "authentication required",
                })
                return
                
            default:
                c.JSON(http.StatusInternalServerError, gin.H{
                    "error": "internal server error",
                })
                return
            }
        }
    }
}

Performance Considerations

1. Memory optimization

go
// Use strings.Builder to reduce allocations from string concatenation
type LogEntry struct {
    builder strings.Builder
}

func (l *LogEntry) AddField(key, value string) {
    if l.builder.Len() > 0 {
        l.builder.WriteString(", ")
    }
    l.builder.WriteString(key)
    l.builder.WriteString("=")
    l.builder.WriteString(value)
}

// Preallocate slice capacity (fragment: runs inside a handler where c is the *gin.Context)
headers := make([]string, 0, len(c.Request.Header))
for key := range c.Request.Header {
    headers = append(headers, key)
}

2. Concurrency optimization

go
// Use a read-write lock for configuration that is read far more often than it is written
type ConfigCache struct {
    config *MiddlewareConfig
    mutex  sync.RWMutex
}

func (cc *ConfigCache) GetConfig() *MiddlewareConfig {
    cc.mutex.RLock()
    defer cc.mutex.RUnlock()
    return cc.config
}

func (cc *ConfigCache) UpdateConfig(config *MiddlewareConfig) {
    cc.mutex.Lock()
    defer cc.mutex.Unlock()
    cc.config = config
}
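
When the configuration is only ever replaced as a whole, a lock-free alternative is to publish immutable snapshots through `atomic.Pointer` (available since Go 1.19). The sketch below is one way to do that, not the module's actual implementation; `AtomicConfigCache` is a hypothetical name and `MiddlewareConfig` is trimmed to a single illustrative field.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// MiddlewareConfig is trimmed to one illustrative field for this sketch.
type MiddlewareConfig struct {
	RateLimitQPS int
}

// AtomicConfigCache publishes immutable config snapshots without locking.
type AtomicConfigCache struct {
	current atomic.Pointer[MiddlewareConfig]
}

// GetConfig returns the latest snapshot, or nil before the first update.
// Callers must treat the returned value as read-only.
func (c *AtomicConfigCache) GetConfig() *MiddlewareConfig {
	return c.current.Load()
}

// UpdateConfig stores a copy, so later changes made by the caller
// to cfg are not visible to readers.
func (c *AtomicConfigCache) UpdateConfig(cfg *MiddlewareConfig) {
	snapshot := *cfg
	c.current.Store(&snapshot)
}

func main() {
	var cache AtomicConfigCache
	cache.UpdateConfig(&MiddlewareConfig{RateLimitQPS: 100})
	fmt.Println(cache.GetConfig().RateLimitQPS)
}
```

The read path is a single atomic load, which avoids reader contention on the lock word under heavy traffic. The trade-off is that every update allocates a new snapshot, which is usually acceptable for rarely changing configuration.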

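3. Caching strategy

go
// LRU cache for per-user rate limiters
type UserLimiterCache struct {
    cache *lru.Cache
    mutex sync.Mutex
}

func (c *UserLimiterCache) GetLimiter(userID string) *rate.Limiter {
    // The mutex makes the lookup-then-insert sequence atomic, so two
    // concurrent requests for the same user share one limiter
    c.mutex.Lock()
    defer c.mutex.Unlock()
    
    if limiter, ok := c.cache.Get(userID); ok {
        return limiter.(*rate.Limiter)
    }
    
    // Create a new limiter: 100 requests per second with a burst of 200
    limiter := rate.NewLimiter(rate.Limit(100), 200)
    c.cache.Add(userID, limiter)
    return limiter
}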

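Common Issues

1. Middleware execution order

Problem: middleware registered in the wrong order causes features to misbehave.

Solution:

go
// ✅ Correct order
app.Use(Recovery())        // Outermost: catches every panic
app.Use(Logging())         // Logs every request
app.Use(Authentication())  // Authentication
app.Use(RateLimit())       // Rate limiting (after authentication, so the quota cannot be drained maliciously)

// ❌ Wrong order
app.Use(RateLimit())       // Rate limiting before authentication can be abused
app.Use(Authentication())  // Authentication
app.Use(Recovery())        // Recovery on the inside misses panics from the layers outside it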

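2. Memory leaks

Problem: a middleware allocates a large number of objects per request, which drives up memory usage and GC pressure and is often reported as a memory leak.

Solution:

go
// ✅ Use an object pool
var bufferPool = sync.Pool{
    New: func() interface{} {
        return bytes.NewBuffer(make([]byte, 0, 1024))
    },
}

func processRequest(c *gin.Context) {
    buffer := bufferPool.Get().(*bytes.Buffer)
    buffer.Reset() // discard any content left by the previous user
    defer bufferPool.Put(buffer)
    
    // Use buffer...
}

// ❌ Creating a new object on every call (shown for contrast)
func processRequest(c *gin.Context) {
    buffer := bytes.NewBuffer(nil) // allocates fresh memory each time
    // Use buffer...
}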

3. Concurrency safety

Problem: shared state inside a middleware leads to data races.

Solution:

go
// ✅ Use atomic operations
type RequestCounter struct {
    count int64
}

func (rc *RequestCounter) Increment() {
    atomic.AddInt64(&rc.count, 1)
}

func (rc *RequestCounter) Get() int64 {
    return atomic.LoadInt64(&rc.count)
}

// ❌ Mutex-based counter: correct, but heavier than necessary
type RequestCounter struct {
    count int64
    mutex sync.Mutex
}

func (rc *RequestCounter) Increment() {
    rc.mutex.Lock()
    rc.count++ // a mutex is too much overhead for a simple counter
    rc.mutex.Unlock()
}
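
As a quick check of the atomic version, the sketch below increments one counter from many goroutines and reads the total once they have all finished. `countConcurrently` is a hypothetical helper added for illustration; running the program with `go run -race` should report no data race.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// RequestCounter is the atomic counter from the example above.
type RequestCounter struct {
	count int64
}

func (rc *RequestCounter) Increment() { atomic.AddInt64(&rc.count, 1) }

func (rc *RequestCounter) Get() int64 { return atomic.LoadInt64(&rc.count) }

// countConcurrently increments a single counter from `workers` goroutines,
// each performing `perWorker` increments, and returns the final value.
func countConcurrently(workers, perWorker int) int64 {
	var rc RequestCounter
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < perWorker; i++ {
				rc.Increment()
			}
		}()
	}
	wg.Wait() // all increments happen before the final read
	return rc.Get()
}

func main() {
	fmt.Println(countConcurrently(50, 200))
}
```

With a plain `rc.count++` and no synchronization, the same program would typically lose updates and the race detector would flag it.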

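References

Official documentation

  1. Kratos official documentation - Complete guide to the Kratos framework
  2. Gin middleware documentation - Developing Gin middleware
  3. Prometheus monitoring - Metric design
  4. Jaeger tracing - Distributed tracing

Design patterns

  1. Middleware Pattern
  2. Chain of Responsibility
  3. Decorator Pattern

Best practices

  1. Go concurrency patterns - Concurrent programming in Go
  2. Microservice monitoring - Microservice observability
  3. API rate-limiting strategies - Rate-limiting algorithm design

Tools and libraries

  1. golang.org/x/time/rate - Rate limiter maintained by the Go project
  2. hystrix-go - Circuit breaker implementation
  3. prometheus/client_golang - Prometheus client for Go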

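Hands-on Exercises

Exercise 1: Implement a custom rate-limiting middleware

Implement a rate-limiting middleware that supports several algorithms:

  • Token bucket
  • Sliding window
  • Fixed window
  • Distributed rate limiting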

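Exercise 2: Build a monitoring dashboard

Build a monitoring dashboard for the middleware:

  • Real-time request volume
  • Response time distribution
  • Error rate statistics
  • Rate-limiting and circuit-breaker status

Exercise 3: Performance benchmarking

Benchmark the middleware:

  • Overhead of the middleware chain
  • Memory usage analysis
  • Concurrency safety tests
  • Stress test validation

These exercises will help you understand the design principles and implementation details of middleware in depth, and master optimization techniques for high-concurrency scenarios.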

Released under the MIT License