Kratos 中间件设计模式
概述
Kratos 是 bilibili 开源的一套轻量级微服务框架,其中间件设计采用洋葱模型(Onion Model),提供了一套标准化的中间件接口和丰富的内置中间件。在 Ledger 模块中,我们借鉴 Kratos 的中间件设计思想(洋葱模型与可组合的中间件链),并基于 Gin 的 `gin.HandlerFunc` 实现限流、熔断、链路追踪、监控等横切关注点;下文的具体实现均为 Gin 中间件,而非 Kratos 原生中间件。
核心特性
- 洋葱模型: 请求和响应都会经过中间件链
- 统一签名: 所有中间件共享同一个 Handler 函数签名(请求与响应均为 `interface{}`,并非基于泛型),因此可以自由组合
- 可组合: 支持中间件的灵活组合和排序
- 高性能: 最小化的性能开销
- 可扩展: 易于扩展自定义中间件
应用场景
- API 限流和熔断保护
- 请求认证和授权
- 链路追踪和性能监控
- 请求日志和审计
- 错误处理和恢复
- 数据验证和转换
核心原理
中间件接口设计
Kratos 中间件基于函数式编程理念,采用高阶函数模式:
go
// Middleware is the middleware signature: it receives the next Handler and
// returns a new Handler that wraps it (a higher-order function).
type Middleware func(Handler) Handler
// Handler is the signature shared by business handlers and every wrapped
// layer: a context plus an untyped request in, an untyped reply or an error out.
type Handler func(ctx context.Context, req interface{}) (interface{}, error)洋葱模型执行流程
中间件链构建
go
// Chain composes middlewares into a single Middleware.
//
// The first middleware passed becomes the outermost layer: a request flows
// through middlewares[0], then middlewares[1], ..., reaches the wrapped
// handler, and the response unwinds in reverse order (onion model). With no
// middlewares the handler is returned unchanged.
func Chain(middlewares ...Middleware) Middleware {
	return func(next Handler) Handler {
		wrapped := next
		// Wrap from the innermost (last) middleware outwards so that the
		// first argument ends up on the outside.
		for idx := len(middlewares); idx > 0; idx-- {
			wrapped = middlewares[idx-1](wrapped)
		}
		return wrapped
	}
}
// Usage example: the first argument to Chain becomes the outermost layer.
// NOTE(review): here RateLimit wraps Auth (rate limiting runs before
// authentication), which is the opposite of the order recommended in the
// "best practices" section of this document — confirm which is intended.
handler := Chain(
Recovery(), // outermost: panic recovery
Logging(logger), // request logging
Metrics(), // performance metrics
RateLimit(limiter), // rate limiting
Auth(authFunc), // authentication
)(businessHandler)Ledger 模块中间件架构
1. 中间件执行顺序设计
2. 中间件配置管理
go
// MiddlewareConfig aggregates the configuration sections of all middlewares.
// NOTE(review): setupMiddlewares also reads cfg.Middleware.CORS, .Validation
// and .Auth, which have no field here — presumably config.Config uses an
// extended type; confirm which struct is authoritative.
type MiddlewareConfig struct {
// Panic recovery
Recovery RecoveryConfig `yaml:"recovery"`
// Request logging
Logging LoggingConfig `yaml:"logging"`
// Distributed tracing
Tracing TracingConfig `yaml:"tracing"`
// Metrics collection
Metrics MetricsConfig `yaml:"metrics"`
// Rate limiting
RateLimit RateLimitConfig `yaml:"rate_limit"`
// Circuit breaking
CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"`
// Request timeouts
Timeout TimeoutConfig `yaml:"timeout"`
}
// RateLimitConfig configures RateLimiterMiddleware.
type RateLimitConfig struct {
Enable bool `yaml:"enable"`
Algorithm string `yaml:"algorithm"` // intended values: token_bucket, sliding_window. NOTE(review): no code in this document reads Algorithm; global and per-API limiters are always golang.org/x/time/rate token buckets
Global GlobalLimitConfig `yaml:"global"`
PerUser PerUserLimitConfig `yaml:"per_user"`
PerAPI map[string]APILimitConfig `yaml:"per_api"` // keyed by Gin route pattern (matched against c.FullPath()); APILimitConfig is not shown here but must expose QPS and Burst
}
// GlobalLimitConfig configures the single limiter shared by every request.
type GlobalLimitConfig struct {
QPS int `yaml:"qps"` // global requests per second; <= 0 disables the global limiter
Burst int `yaml:"burst"` // bucket capacity (maximum burst size)
}
// PerUserLimitConfig configures per-caller limits. It is handed as a whole to
// ratelimit.NewUserLimiterManager, whose use of these fields is not shown here.
type PerUserLimitConfig struct {
QPS int `yaml:"qps"` // also reported in the X-RateLimit-Limit response header
Burst int `yaml:"burst"`
WindowSize time.Duration `yaml:"window_size"` // presumably the sliding-window length — confirm in the ratelimit package
SlidingLogSize int `yaml:"sliding_log_size"` // presumably the max entries kept per sliding log — confirm
}
// CircuitBreakerConfig configures the per-route circuit breakers.
// NOTE(review): if MaxRequests < SuccessThreshold a half-open breaker can never
// collect enough successes and stays half-open (rejecting everything) forever.
type CircuitBreakerConfig struct {
Enable bool `yaml:"enable"`
FailureThreshold int `yaml:"failure_threshold"` // failures needed to open; any success while closed resets the count, so effectively consecutive failures
SuccessThreshold int `yaml:"success_threshold"` // half-open successes needed to close again
Timeout time.Duration `yaml:"timeout"` // dual use: open-state cool-down before a half-open probe AND the latency above which a request counts as failed; also sent as Retry-After
MaxRequests int `yaml:"max_requests"` // half-open cap, compared against completed (recorded) requests, not in-flight ones
Interval time.Duration `yaml:"interval"` // statistics interval; NOTE(review): not read by any code in this document, counters are never reset on a timer
MinRequestThreshold int `yaml:"min_request_threshold"` // minimum recorded requests before the breaker may open; accumulates over the whole closed period
}核心中间件实现
1. 限流中间件 (Rate Limiter)
go
package middleware
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"github.com/FixIterate/lz-stash/internal/pkg/ratelimit"
)
// RateLimiterMiddleware applies global, per-route and per-user rate limits.
type RateLimiterMiddleware struct {
globalLimiter *rate.Limiter // shared by all requests; nil when Global.QPS <= 0
userLimiters *ratelimit.UserLimiterManager // per-caller limiters keyed by the ID from extractUserID
apiLimiters map[string]*rate.Limiter // keyed by Gin route pattern; only written by the constructor, read-only afterwards
config *RateLimitConfig
}
// NewRateLimiterMiddleware builds a RateLimiterMiddleware from config.
//
// A global limiter is created when config.Global.QPS > 0, and one per-route
// limiter for every config.PerAPI entry with QPS > 0. Per-user limiting is
// delegated to ratelimit.NewUserLimiterManager(config.PerUser).
//
// A configured burst <= 0 falls back to the QPS value. A rate.Limiter with
// a finite rate and burst 0 rejects every Allow() call, so a missing "burst"
// key would otherwise silently block all traffic.
func NewRateLimiterMiddleware(config *RateLimitConfig) *RateLimiterMiddleware {
	m := &RateLimiterMiddleware{
		config:       config,
		apiLimiters:  make(map[string]*rate.Limiter, len(config.PerAPI)),
		userLimiters: ratelimit.NewUserLimiterManager(config.PerUser),
	}

	// Global limiter.
	if config.Global.QPS > 0 {
		m.globalLimiter = newTokenBucket(config.Global.QPS, config.Global.Burst)
	}

	// Per-route limiters.
	for path, apiConfig := range config.PerAPI {
		if apiConfig.QPS > 0 {
			m.apiLimiters[path] = newTokenBucket(apiConfig.QPS, apiConfig.Burst)
		}
	}
	return m
}

// newTokenBucket returns a limiter refilling at qps tokens per second. A
// non-positive burst falls back to qps so the bucket can hold at least one token.
func newTokenBucket(qps, burst int) *rate.Limiter {
	if burst <= 0 {
		burst = qps
	}
	return rate.NewLimiter(rate.Limit(qps), burst)
}
// Middleware returns the Gin handler enforcing the configured limits.
//
// Checks run in order: global, per-route (by c.FullPath()), then per-user.
// The first rejecting check aborts with HTTP 429. A token taken by an earlier
// check is not refunded when a later one rejects. When all checks pass, the
// per-user rate-limit headers are set and the chain continues. With Enable
// false the handler is a pass-through.
func (m *RateLimiterMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := m.config
		if !cfg.Enable {
			c.Next()
			return
		}

		// Global bucket shared by every request.
		if gl := m.globalLimiter; gl != nil && !gl.Allow() {
			m.handleRateLimitExceeded(c, "global", cfg.Global.QPS)
			return
		}

		// Per-route bucket keyed by the registered route pattern.
		route := c.FullPath()
		if al, ok := m.apiLimiters[route]; ok && !al.Allow() {
			m.handleRateLimitExceeded(c, "api", cfg.PerAPI[route].QPS)
			return
		}

		// Per-caller bucket.
		uid := m.extractUserID(c)
		if uid != "" {
			if ok, wait := m.userLimiters.Allow(uid); !ok {
				m.handleUserRateLimitExceeded(c, uid, wait)
				return
			}
		}

		m.setRateLimitHeaders(c, uid)
		c.Next()
	}
}
// handleRateLimitExceeded aborts the request with 429 Too Many Requests after
// the global or per-route limiter rejected it. limitType ("global" or "api")
// and the configured qps are echoed in the headers and the JSON body.
//
// A standard Retry-After header is now sent too. It matches the retry_after
// field in the body and the per-user handler, which already sets Retry-After.
// Previously HTTP clients and proxies saw the hint only in the JSON.
func (m *RateLimiterMiddleware) handleRateLimitExceeded(c *gin.Context, limitType string, qps int) {
	const retryAfter = 1 // seconds

	c.Header("X-RateLimit-Limit", strconv.Itoa(qps))
	c.Header("X-RateLimit-Remaining", "0")
	c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(retryAfter*time.Second).Unix(), 10))
	c.Header("Retry-After", strconv.Itoa(retryAfter))
	c.JSON(http.StatusTooManyRequests, gin.H{
		"error":       "rate limit exceeded",
		"type":        limitType,
		"limit":       qps,
		"retry_after": retryAfter,
	})
	c.Abort()
}
// handleUserRateLimitExceeded aborts the request with 429 Too Many Requests
// after the per-user limiter rejected userID.
//
// retryAfter is rounded UP to whole seconds (minimum 1) before it appears in
// Retry-After, X-RateLimit-Reset and the JSON body. Truncation would advertise
// any sub-second wait as "Retry-After: 0", which invites clients to retry
// immediately and be rejected again.
func (m *RateLimiterMiddleware) handleUserRateLimitExceeded(c *gin.Context, userID string, retryAfter time.Duration) {
	wait := retryAfterSeconds(retryAfter)

	c.Header("X-RateLimit-Limit", strconv.Itoa(m.config.PerUser.QPS))
	c.Header("X-RateLimit-Remaining", "0")
	c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Duration(wait)*time.Second).Unix(), 10))
	c.Header("Retry-After", strconv.Itoa(wait))
	c.JSON(http.StatusTooManyRequests, gin.H{
		"error":       "user rate limit exceeded",
		"user_id":     userID,
		"retry_after": wait,
	})
	c.Abort()
}

// retryAfterSeconds converts d to whole seconds, rounding up, and never returns
// less than 1 (the smallest Retry-After that does not mean "retry now").
func retryAfterSeconds(d time.Duration) int {
	if d <= 0 {
		return 1
	}
	return int((d + time.Second - 1) / time.Second)
}
// setRateLimitHeaders publishes the caller's per-user quota in the
// X-RateLimit-Limit / -Remaining / -Reset response headers. Nothing is written
// when userID is empty.
func (m *RateLimiterMiddleware) setRateLimitHeaders(c *gin.Context, userID string) {
	if userID == "" {
		return
	}

	limit := strconv.Itoa(m.config.PerUser.QPS)
	remaining := strconv.Itoa(m.userLimiters.GetRemaining(userID))
	reset := strconv.FormatInt(m.userLimiters.GetResetTime(userID).Unix(), 10)

	c.Header("X-RateLimit-Limit", limit)
	c.Header("X-RateLimit-Remaining", remaining)
	c.Header("X-RateLimit-Reset", reset)
}
// 提取用户 ID
func (m *RateLimiterMiddleware) extractUserID(c *gin.Context) string {
// 优先从 JWT token 中提取
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
// 备选方案:从 header 中提取
if userID := c.GetHeader("X-User-ID"); userID != "" {
return userID
}
// 最后方案:使用 IP 地址
return c.ClientIP()
}2. 熔断中间件 (Circuit Breaker)
go
package middleware
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/FixIterate/lz-stash/internal/pkg/circuitbreaker"
)
// CircuitBreakerState is the state of a CircuitBreaker.
type CircuitBreakerState int
const (
StateClosed CircuitBreakerState = iota // closed: requests flow normally
StateOpen // open: requests are rejected until the cool-down elapses
StateHalfOpen // half-open: a limited number of probe requests are allowed
)
// CircuitBreakerMiddleware keeps one CircuitBreaker per route and rejects
// requests whose route breaker does not allow them.
type CircuitBreakerMiddleware struct {
breakers map[string]*CircuitBreaker // keyed by Gin route pattern (c.FullPath()); created lazily
config *CircuitBreakerConfig // shared by every breaker
mutex sync.RWMutex // guards breakers
}
// CircuitBreaker is a closed/open/half-open breaker for a single route.
type CircuitBreaker struct {
name string // route this breaker belongs to
state CircuitBreakerState
failureCount int // failures since the last reset; any success while closed sets it back to 0
successCount int // successes while half-open
requestCount int // completed requests, incremented by RecordSuccess/RecordFailure (not by AllowRequest)
lastFailureTime time.Time // start of the open-state cool-down
lastSuccessTime time.Time
config *CircuitBreakerConfig
mutex sync.RWMutex // guards all fields above
}
// NewCircuitBreakerMiddleware returns a middleware with no breakers yet. A
// breaker is created lazily for each route the first time it is requested.
func NewCircuitBreakerMiddleware(config *CircuitBreakerConfig) *CircuitBreakerMiddleware {
	m := new(CircuitBreakerMiddleware)
	m.config = config
	m.breakers = make(map[string]*CircuitBreaker)
	return m
}
// Middleware returns the Gin handler that guards each route with its breaker.
//
// A request refused by the breaker gets a 503 via handleCircuitBreakerOpen.
// Otherwise the chain runs, and the outcome (see isSuccess: status code and
// latency) is recorded on the breaker. With Enable false the handler is a
// pass-through.
//
// NOTE(review): outcome recording is not deferred. If a downstream handler
// panics, nothing is recorded for that request.
func (m *CircuitBreakerMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !m.config.Enable {
			c.Next()
			return
		}

		cb := m.getOrCreateBreaker(c.FullPath())
		if !cb.AllowRequest() {
			m.handleCircuitBreakerOpen(c, cb)
			return
		}

		began := time.Now()
		c.Next()
		elapsed := time.Since(began)

		record := cb.RecordFailure
		if m.isSuccess(c.Writer.Status(), elapsed) {
			record = cb.RecordSuccess
		}
		record()
	}
}
// getOrCreateBreaker returns the breaker for path, creating a closed one on
// first use. Lookups take the shared read lock. Creation re-checks the map
// under the write lock, so concurrent first requests for the same path all
// receive the same breaker.
func (m *CircuitBreakerMiddleware) getOrCreateBreaker(path string) *CircuitBreaker {
	m.mutex.RLock()
	existing, found := m.breakers[path]
	m.mutex.RUnlock()
	if found {
		return existing
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Another goroutine may have created it between RUnlock and Lock.
	if existing, found = m.breakers[path]; !found {
		existing = &CircuitBreaker{
			name:   path,
			state:  StateClosed,
			config: m.config,
		}
		m.breakers[path] = existing
	}
	return existing
}
// AllowRequest reports whether a request may pass through this breaker.
//
//   - Closed: always allowed.
//   - Open: rejected until more than config.Timeout has elapsed since the last
//     recorded failure. The breaker then moves to half-open, its counters
//     reset, and this request is admitted as a probe.
//   - Half-open: allowed while fewer than config.MaxRequests requests have
//     completed. requestCount is incremented by RecordSuccess/RecordFailure,
//     not here, so concurrent in-flight probes are not counted.
//
// The method may change state, so it takes the write lock. Mutating state and
// requestCount under a read lock was a data race whenever several goroutines
// observed the open->half-open transition at the same time.
//
// NOTE(review): if config.MaxRequests < config.SuccessThreshold, a half-open
// breaker can never reach the success threshold and stays half-open forever.
func (cb *CircuitBreaker) AllowRequest() bool {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	switch cb.state {
	case StateClosed:
		return true
	case StateOpen:
		// Cool-down over: let a probe through in the half-open state.
		if time.Since(cb.lastFailureTime) > cb.config.Timeout {
			cb.state = StateHalfOpen
			cb.requestCount = 0
			cb.successCount = 0
			return true
		}
		return false
	case StateHalfOpen:
		// Limit how many requests are let through while probing.
		return cb.requestCount < cb.config.MaxRequests
	}
	return false
}
// RecordSuccess records a successful request.
//
// Closed: the failure streak is cleared. Half-open: the success counts toward
// SuccessThreshold; reaching it closes the breaker and resets every counter.
// Open: only the timestamp and request count change.
func (cb *CircuitBreaker) RecordSuccess() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	cb.lastSuccessTime = time.Now()
	cb.requestCount++

	if cb.state == StateClosed {
		cb.failureCount = 0
		return
	}
	if cb.state != StateHalfOpen {
		return
	}

	cb.successCount++
	if cb.successCount < cb.config.SuccessThreshold {
		return
	}
	// Enough probes succeeded: close the breaker with a clean slate.
	cb.state = StateClosed
	cb.failureCount, cb.successCount, cb.requestCount = 0, 0, 0
}
// RecordFailure records a failed request.
//
// Half-open: any failure reopens the breaker immediately and resets the probe
// counters. Closed: the breaker opens once failureCount reaches
// FailureThreshold and at least MinRequestThreshold requests have been
// recorded. Open: only the counters and lastFailureTime change, which extends
// the cool-down.
func (cb *CircuitBreaker) RecordFailure() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	cb.lastFailureTime = time.Now()
	cb.requestCount++
	cb.failureCount++

	if cb.state == StateHalfOpen {
		cb.state = StateOpen
		cb.successCount, cb.requestCount = 0, 0
		return
	}

	if cb.state == StateClosed &&
		cb.failureCount >= cb.config.FailureThreshold &&
		cb.requestCount >= cb.config.MinRequestThreshold {
		cb.state = StateOpen
	}
}
// handleCircuitBreakerOpen rejects the request with 503 Service Unavailable
// because its route breaker refused it. Retry-After is the configured Timeout
// in whole seconds (truncated).
//
// NOTE(review): AllowRequest also refuses half-open requests once MaxRequests
// is reached, yet the header and body here always report "OPEN".
func (m *CircuitBreakerMiddleware) handleCircuitBreakerOpen(c *gin.Context, breaker *CircuitBreaker) {
	waitSeconds := int(m.config.Timeout.Seconds())
	payload := gin.H{
		"error":       "service temporarily unavailable",
		"reason":      "circuit breaker is open",
		"retry_after": waitSeconds,
		"path":        breaker.name,
	}

	c.Header("X-Circuit-Breaker-State", "OPEN")
	c.Header("Retry-After", fmt.Sprint(waitSeconds))
	c.JSON(http.StatusServiceUnavailable, payload)
	c.Abort()
}
// isSuccess classifies a finished request for the breaker. It succeeded only
// if the status is below 500 and it took no longer than config.Timeout (the
// same value used as the open-state cool-down).
func (m *CircuitBreakerMiddleware) isSuccess(statusCode int, duration time.Duration) bool {
	serverOK := statusCode < 500
	fastEnough := duration <= m.config.Timeout
	return serverOK && fastEnough
}
// GetBreakerState returns a monitoring snapshot of the breaker for path.
// The map holds {"exists": false} when no breaker has been created for that
// route yet. Otherwise it holds exists, name, state, the three counters and the
// last failure/success timestamps, all read under the breaker's read lock.
func (m *CircuitBreakerMiddleware) GetBreakerState(path string) map[string]interface{} {
	m.mutex.RLock()
	cb, found := m.breakers[path]
	m.mutex.RUnlock()

	if !found {
		return map[string]interface{}{"exists": false}
	}

	cb.mutex.RLock()
	defer cb.mutex.RUnlock()

	snapshot := make(map[string]interface{}, 8)
	snapshot["exists"] = true
	snapshot["name"] = cb.name
	snapshot["state"] = m.stateToString(cb.state)
	snapshot["failure_count"] = cb.failureCount
	snapshot["success_count"] = cb.successCount
	snapshot["request_count"] = cb.requestCount
	snapshot["last_failure"] = cb.lastFailureTime
	snapshot["last_success"] = cb.lastSuccessTime
	return snapshot
}
func (m *CircuitBreakerMiddleware) stateToString(state CircuitBreakerState) string {
switch state {
case StateClosed:
return "CLOSED"
case StateOpen:
return "OPEN"
case StateHalfOpen:
return "HALF_OPEN"
default:
return "UNKNOWN"
}
}3. 链路追踪中间件 (Tracing)
go
package middleware
import (
	"context"
	"fmt"
	"io"

	"github.com/gin-gonic/gin"
	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/ext"
	"github.com/opentracing/opentracing-go/log"
	"github.com/uber/jaeger-client-go"
)
// TracingMiddleware creates an OpenTracing span for every HTTP request.
type TracingMiddleware struct {
	tracer opentracing.Tracer
	// closer flushes buffered spans and releases the reporter; nil when
	// tracing is disabled.
	closer io.Closer
	config *TracingConfig
}

// TracingConfig configures the Jaeger-backed tracer.
type TracingConfig struct {
	Enable         bool    `yaml:"enable"`
	ServiceName    string  `yaml:"service_name"`
	SamplerType    string  `yaml:"sampler_type"` // const, probabilistic, ratelimiting
	SamplerParam   float64 `yaml:"sampler_param"`
	JaegerEndpoint string  `yaml:"jaeger_endpoint"`
}

// NewTracingMiddleware creates the tracing middleware. When config.Enable is
// false, no tracer is created and the middleware is a pass-through. Otherwise a
// Jaeger tracer is built and installed as the opentracing global tracer.
//
// The tracer's closer is kept so Close can flush pending spans on shutdown.
// Previously it was assigned and never used, which does not compile in Go and
// left no way to flush the reporter.
func NewTracingMiddleware(config *TracingConfig) (*TracingMiddleware, error) {
	if !config.Enable {
		return &TracingMiddleware{config: config}, nil
	}

	tracer, closer, err := createJaegerTracer(config)
	if err != nil {
		return nil, fmt.Errorf("failed to create tracer: %w", err)
	}

	// Make the tracer available to opentracing.StartSpanFromContext & co.
	opentracing.SetGlobalTracer(tracer)

	return &TracingMiddleware{
		tracer: tracer,
		closer: closer,
		config: config,
	}, nil
}

// Close flushes pending spans and releases the tracer's resources. It is a
// no-op (returning nil) when tracing is disabled or m is nil.
func (m *TracingMiddleware) Close() error {
	if m == nil || m.closer == nil {
		return nil
	}
	return m.closer.Close()
}
// Middleware returns a Gin handler that wraps each request in a span.
//
// The span is named "<METHOD> <route pattern>". It continues an upstream trace
// when a span context can be extracted from the HTTP headers, and is stored in
// the request context. For Jaeger spans, the trace and span IDs are exposed as
// the X-Trace-ID / X-Span-ID response headers and as the "trace_id" / "span_id"
// Gin keys. After the chain runs, the status code, response size and any
// c.Errors are recorded. 4xx and 5xx set the error tag; 5xx also logs an error
// event. With Enable false the handler is a pass-through.
func (m *TracingMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !m.config.Enable {
			c.Next()
			return
		}

		req := c.Request
		route := c.FullPath()
		operation := fmt.Sprintf("%s %s", req.Method, route)

		// Continue the caller's trace when one was propagated in the headers.
		var startOpts []opentracing.StartSpanOption
		carrier := opentracing.HTTPHeadersCarrier(req.Header)
		if parent, err := m.tracer.Extract(opentracing.HTTPHeaders, carrier); err == nil {
			startOpts = append(startOpts, opentracing.ChildOf(parent))
		}
		span := m.tracer.StartSpan(operation, startOpts...)
		defer span.Finish()

		// Request tags.
		ext.HTTPMethod.Set(span, req.Method)
		ext.HTTPUrl.Set(span, req.URL.String())
		ext.Component.Set(span, "gin-http")
		span.SetTag("service.name", m.config.ServiceName)
		span.SetTag("http.path", route)
		span.SetTag("http.route", route)
		if ua := c.GetHeader("User-Agent"); ua != "" {
			span.SetTag("http.user_agent", ua)
		}
		if ip := c.ClientIP(); ip != "" {
			span.SetTag("http.client_ip", ip)
		}

		// Make the span reachable from downstream handlers via the request context.
		c.Request = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))

		// Expose IDs for log correlation (Jaeger spans only).
		if jc, ok := span.Context().(jaeger.SpanContext); ok {
			traceID, spanID := jc.TraceID().String(), jc.SpanID().String()
			c.Header("X-Trace-ID", traceID)
			c.Header("X-Span-ID", spanID)
			c.Set("trace_id", traceID)
			c.Set("span_id", spanID)
		}

		c.Next()

		// Response tags.
		status := c.Writer.Status()
		ext.HTTPStatusCode.Set(span, uint16(status))
		switch {
		case status >= 500:
			ext.Error.Set(span, true)
			span.LogFields(
				log.String("event", "error"),
				log.String("level", "error"),
				log.Int("status_code", status),
			)
		case status >= 400:
			ext.Error.Set(span, true)
		}

		if size := c.Writer.Size(); size > 0 {
			span.SetTag("http.response_size", size)
		}
		if len(c.Errors) > 0 {
			span.LogFields(
				log.String("event", "error"),
				log.String("message", c.Errors.String()),
			)
		}
	}
}
// createJaegerTracer builds a Jaeger tracer from config. It returns the tracer
// and its closer, or (nil, nil, err) if construction fails.
//
// NOTE(review): LogSpans is hard-coded to true, so every reported span is also
// written through jaeger.StdLogger. This is likely too noisy for production.
func createJaegerTracer(config *TracingConfig) (opentracing.Tracer, io.Closer, error) {
	sampler := &jaegerconfig.SamplerConfig{
		Type:  config.SamplerType,
		Param: config.SamplerParam,
	}
	reporter := &jaegerconfig.ReporterConfig{
		LogSpans:           true,
		LocalAgentHostPort: config.JaegerEndpoint,
	}
	cfg := jaegerconfig.Configuration{
		ServiceName: config.ServiceName,
		Sampler:     sampler,
		Reporter:    reporter,
	}

	tracer, closer, err := cfg.NewTracer(jaegerconfig.Logger(jaeger.StdLogger))
	if err != nil {
		return nil, nil, err
	}
	return tracer, closer, nil
}
// StartSpanFromContext starts a span named operationName, as a child of the
// span carried by ctx if there is one. It returns the span together with a
// derived context holding it. This is a thin wrapper over
// opentracing.StartSpanFromContext, which (per the opentracing-go docs) uses
// the global tracer.
func StartSpanFromContext(ctx context.Context, operationName string) (opentracing.Span, context.Context) {
	return opentracing.StartSpanFromContext(ctx, operationName)
}
// LogError marks span as failed and attaches err as a structured log event:
// the plain message plus a "%+v" rendering for errors that carry detail. It
// does nothing when span or err is nil.
func LogError(span opentracing.Span, err error) {
	if span != nil && err != nil {
		ext.Error.Set(span, true)
		span.LogFields(
			log.String("event", "error"),
			log.String("message", err.Error()),
			log.String("error.object", fmt.Sprintf("%+v", err)),
		)
	}
}
// AddSpanTags 添加多个标签到 span
func AddSpanTags(span opentracing.Span, tags map[string]interface{}) {
if span == nil {
return
}
for key, value := range tags {
span.SetTag(key, value)
}
}4. 性能监控中间件 (Metrics)
go
package middleware
import (
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// MetricsMiddleware records Prometheus metrics for every HTTP request.
// All collector fields are nil when it was constructed with Enable=false.
type MetricsMiddleware struct {
// Total HTTP requests, labelled by method, path and status_code
httpRequestsTotal *prometheus.CounterVec
// HTTP request latency in seconds
httpRequestDuration *prometheus.HistogramVec
// Estimated HTTP request size in bytes (URL + headers + declared body length)
httpRequestSize *prometheus.HistogramVec
// HTTP response body size in bytes
httpResponseSize *prometheus.HistogramVec
// Number of requests currently being processed
httpRequestsInFlight prometheus.Gauge
config *MetricsConfig
}
// MetricsConfig configures MetricsMiddleware.
type MetricsConfig struct {
Enable bool `yaml:"enable"`
Namespace string `yaml:"namespace"` // Prometheus metric namespace
Subsystem string `yaml:"subsystem"` // Prometheus metric subsystem
Buckets []float64 `yaml:"buckets"` // latency histogram buckets in seconds; empty selects the built-in defaults
}
// NewMetricsMiddleware creates the HTTP metrics middleware.
//
// When config.Enable is false, the result has no collectors and its handler is
// a pass-through. Otherwise five collectors are created under
// config.Namespace / config.Subsystem and registered with Prometheus' default
// registerer.
//
// Registration goes through registerOrReuseCollector. If an identical
// collector is already registered (the constructor ran before in the same
// process, e.g. in tests or when two servers are built), that collector is
// reused instead of panicking as promauto's MustRegister did. Any other
// registration error still panics, as before.
func NewMetricsMiddleware(config *MetricsConfig) *MetricsMiddleware {
	if !config.Enable {
		return &MetricsMiddleware{config: config}
	}

	// Latency buckets in seconds; defaults apply when none are configured.
	buckets := config.Buckets
	if len(buckets) == 0 {
		buckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
	}

	requestsTotal := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: config.Namespace,
			Subsystem: config.Subsystem,
			Name:      "http_requests_total",
			Help:      "Total number of HTTP requests",
		},
		[]string{"method", "path", "status_code"},
	)
	requestDuration := prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: config.Namespace,
			Subsystem: config.Subsystem,
			Name:      "http_request_duration_seconds",
			Help:      "HTTP request duration in seconds",
			Buckets:   buckets,
		},
		[]string{"method", "path", "status_code"},
	)
	requestSize := prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: config.Namespace,
			Subsystem: config.Subsystem,
			Name:      "http_request_size_bytes",
			Help:      "HTTP request size in bytes",
			Buckets:   prometheus.ExponentialBuckets(1024, 2, 10), // 1KiB, 2KiB, 4KiB, ... 512KiB
		},
		[]string{"method", "path"},
	)
	responseSize := prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: config.Namespace,
			Subsystem: config.Subsystem,
			Name:      "http_response_size_bytes",
			Help:      "HTTP response size in bytes",
			Buckets:   prometheus.ExponentialBuckets(1024, 2, 10),
		},
		[]string{"method", "path", "status_code"},
	)
	inFlight := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Namespace: config.Namespace,
			Subsystem: config.Subsystem,
			Name:      "http_requests_in_flight",
			Help:      "Current number of HTTP requests being processed",
		},
	)

	return &MetricsMiddleware{
		httpRequestsTotal:    registerOrReuseCollector(requestsTotal).(*prometheus.CounterVec),
		httpRequestDuration:  registerOrReuseCollector(requestDuration).(*prometheus.HistogramVec),
		httpRequestSize:      registerOrReuseCollector(requestSize).(*prometheus.HistogramVec),
		httpResponseSize:     registerOrReuseCollector(responseSize).(*prometheus.HistogramVec),
		httpRequestsInFlight: registerOrReuseCollector(inFlight).(prometheus.Gauge),
		config:               config,
	}
}

// registerOrReuseCollector registers c with the default Prometheus registerer
// and returns it. If an equal collector is already registered, the existing one
// is returned instead; any other error panics, matching MustRegister.
func registerOrReuseCollector(c prometheus.Collector) prometheus.Collector {
	if err := prometheus.Register(c); err != nil {
		if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
			return are.ExistingCollector
		}
		panic(err)
	}
	return c
}
// Middleware returns the Gin handler that records request metrics.
//
// Per request: the in-flight gauge is raised for its duration. After the chain
// runs, the request counter, latency histogram and request-size histogram are
// updated, plus the response-size histogram when a body was written. The
// "path" label is c.FullPath(), i.e. the route pattern (empty for unmatched
// routes), which keeps label cardinality bounded. With Enable false the handler
// is a pass-through.
func (m *MetricsMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !m.config.Enable {
			c.Next()
			return
		}

		began := time.Now()
		inFlight := m.httpRequestsInFlight
		inFlight.Inc()
		defer inFlight.Dec()

		reqBytes := computeRequestSize(c.Request)

		c.Next()

		elapsed := time.Since(began)
		method, route := c.Request.Method, c.FullPath()
		code := strconv.Itoa(c.Writer.Status())

		m.httpRequestsTotal.WithLabelValues(method, route, code).Inc()
		m.httpRequestDuration.WithLabelValues(method, route, code).Observe(elapsed.Seconds())
		m.httpRequestSize.WithLabelValues(method, route).Observe(float64(reqBytes))

		if respBytes := c.Writer.Size(); respBytes > 0 {
			m.httpResponseSize.WithLabelValues(method, route, code).Observe(float64(respBytes))
		}
	}
}
// computeRequestSize estimates the size of r in bytes. It adds the length of
// the serialized URL, every header name and value, and the declared
// ContentLength when positive. The request line, protocol and header framing
// are not counted, so this is an approximation.
func computeRequestSize(r *http.Request) int64 {
	var total int64

	if u := r.URL; u != nil {
		total += int64(len(u.String()))
	}
	for key, vals := range r.Header {
		total += int64(len(key))
		for i := range vals {
			total += int64(len(vals[i]))
		}
	}
	if n := r.ContentLength; n > 0 {
		total += n
	}
	return total
}
// RegisterCustomMetrics 注册自定义业务指标
func (m *MetricsMiddleware) RegisterCustomMetrics() {
// 注册业务相关的指标
promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: m.config.Namespace,
Subsystem: "business",
Name: "transactions_total",
Help: "Total number of transactions created",
},
[]string{"ledger_id", "transaction_type"},
)
promauto.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: m.config.Namespace,
Subsystem: "business",
Name: "active_users",
Help: "Current number of active users",
},
[]string{"time_window"},
)
promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: m.config.Namespace,
Subsystem: "business",
Name: "transaction_amount",
Help: "Transaction amount distribution",
Buckets: []float64{1, 10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000},
},
[]string{"currency", "transaction_type"},
)
}实现方案
1. 中间件注册和配置
go
package server
import (
"github.com/gin-gonic/gin"
"github.com/FixIterate/lz-stash/internal/middleware"
"github.com/FixIterate/lz-stash/internal/pkg/config"
)
// HTTPServer bundles the Gin engine with the HTTP section of the configuration.
type HTTPServer struct {
engine *gin.Engine
config *config.HTTPConfig
}
// NewHTTPServer builds an HTTPServer from cfg. It:
//   - sets Gin's process-wide mode (debug when cfg.App.Debug, release otherwise);
//   - creates a bare engine with gin.New();
//   - installs the middlewares enabled in cfg.Middleware.
func NewHTTPServer(cfg *config.Config) *HTTPServer {
	mode := gin.ReleaseMode
	if cfg.App.Debug {
		mode = gin.DebugMode
	}
	gin.SetMode(mode)

	s := &HTTPServer{
		engine: gin.New(),
		config: &cfg.HTTP,
	}
	s.setupMiddlewares(cfg)
	return s
}
// setupMiddlewares 设置中间件
func (s *HTTPServer) setupMiddlewares(cfg *config.Config) {
// 1. 异常恢复中间件 (最外层)
if cfg.Middleware.Recovery.Enable {
recoveryMiddleware := middleware.NewRecoveryMiddleware(&cfg.Middleware.Recovery)
s.engine.Use(recoveryMiddleware.Middleware())
}
// 2. 请求日志中间件
if cfg.Middleware.Logging.Enable {
loggingMiddleware := middleware.NewLoggingMiddleware(&cfg.Middleware.Logging)
s.engine.Use(loggingMiddleware.Middleware())
}
// 3. 链路追踪中间件
if cfg.Middleware.Tracing.Enable {
tracingMiddleware, err := middleware.NewTracingMiddleware(&cfg.Middleware.Tracing)
if err != nil {
panic(fmt.Sprintf("failed to create tracing middleware: %v", err))
}
s.engine.Use(tracingMiddleware.Middleware())
}
// 4. 性能监控中间件
if cfg.Middleware.Metrics.Enable {
metricsMiddleware := middleware.NewMetricsMiddleware(&cfg.Middleware.Metrics)
s.engine.Use(metricsMiddleware.Middleware())
// 注册业务指标
metricsMiddleware.RegisterCustomMetrics()
}
// 5. CORS 中间件
if cfg.Middleware.CORS.Enable {
corsMiddleware := middleware.NewCORSMiddleware(&cfg.Middleware.CORS)
s.engine.Use(corsMiddleware.Middleware())
}
// 6. 请求验证中间件
if cfg.Middleware.Validation.Enable {
validationMiddleware := middleware.NewValidationMiddleware(&cfg.Middleware.Validation)
s.engine.Use(validationMiddleware.Middleware())
}
// 7. 身份认证中间件
if cfg.Middleware.Auth.Enable {
authMiddleware := middleware.NewAuthMiddleware(&cfg.Middleware.Auth)
s.engine.Use(authMiddleware.Middleware())
}
// 8. 限流中间件
if cfg.Middleware.RateLimit.Enable {
rateLimitMiddleware := middleware.NewRateLimiterMiddleware(&cfg.Middleware.RateLimit)
s.engine.Use(rateLimitMiddleware.Middleware())
}
// 9. 熔断中间件
if cfg.Middleware.CircuitBreaker.Enable {
circuitBreakerMiddleware := middleware.NewCircuitBreakerMiddleware(&cfg.Middleware.CircuitBreaker)
s.engine.Use(circuitBreakerMiddleware.Middleware())
}
// 10. 超时控制中间件
if cfg.Middleware.Timeout.Enable {
timeoutMiddleware := middleware.NewTimeoutMiddleware(&cfg.Middleware.Timeout)
s.engine.Use(timeoutMiddleware.Middleware())
}
}2. 配置文件示例
yaml
# config/config.yaml
app:
name: "ledger-service"
version: "v1.0.0"
debug: false
http:
addr: ":8080"
read_timeout: "30s"
write_timeout: "30s"
idle_timeout: "60s"
middleware:
# 异常恢复
recovery:
enable: true
stack_trace: true
# 请求日志
logging:
enable: true
level: "info"
format: "json"
skip_paths:
- "/health"
- "/metrics"
# 链路追踪
tracing:
enable: true
service_name: "ledger-service"
sampler_type: "probabilistic"
sampler_param: 0.1
jaeger_endpoint: "localhost:6831"
# 性能监控
metrics:
enable: true
namespace: "ledger"
subsystem: "api"
buckets: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]
# 限流配置
rate_limit:
enable: true
algorithm: "token_bucket"
global:
qps: 1000
burst: 2000
per_user:
qps: 100
burst: 200
window_size: "1m"
per_api:
"/api/v1/transactions":
qps: 500
burst: 1000
"/api/v1/analytics":
qps: 100
burst: 200
# 熔断配置
circuit_breaker:
enable: true
failure_threshold: 5
success_threshold: 3
timeout: "30s"
max_requests: 10
interval: "10s"
min_request_threshold: 10
# 超时控制
timeout:
enable: true
read_timeout: "30s"
write_timeout: "30s"
handler_timeout: "25s"
最佳实践
1. 中间件顺序优化
go
// Recommended middleware order (outermost first).
// NOTE(review): setupMiddlewares in this document registers no "Security" or
// "Authorization" middleware; the remaining entries match its order.
var middlewareOrder = []string{
"Recovery", // 1. panic recovery - must be the outermost layer
"Logging", // 2. request logging - sees every request
"Tracing", // 3. distributed tracing - generates the trace ID
"Metrics", // 4. metrics collection
"CORS", // 5. CORS - answers preflight requests early
"Security", // 6. security headers
"Validation", // 7. request validation - fail fast
"Authentication", // 8. authentication
"Authorization", // 9. authorization
"RateLimit", // 10. rate limiting
"CircuitBreaker", // 11. circuit breaking
"Timeout", // 12. timeout control - innermost layer
}
// ❌ Wrong-order example:
// NOTE(review): the recommended order has a trade-off — if Authentication aborts
// rejected requests, they never reach RateLimit, so floods of unauthenticated
// requests are not throttled by this limiter.
// If RateLimit is placed before Authentication, malicious users can exhaust
// the rate-limit quota with large numbers of unauthenticated requests2. 性能优化技巧
go
// spanPool recycles *SpanInfo values so that a new one is not allocated for
// every request.
var spanPool = sync.Pool{
	New: func() interface{} { return new(SpanInfo) },
}
// Middleware (pooling example) borrows a *SpanInfo from spanPool for one
// request.
//
// The object is zeroed before it goes back to the pool. Returning it as-is
// (as the earlier version did) lets the next request that receives it observe
// data left over from a previous request.
func (m *TracingMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		spanInfo := spanPool.Get().(*SpanInfo)
		defer func() {
			*spanInfo = SpanInfo{}
			spanPool.Put(spanInfo)
		}()
		// use spanInfo...
	}
}
// 缓存中间件实例
var middlewareCache sync.Map
func getCachedMiddleware(key string, factory func() gin.HandlerFunc) gin.HandlerFunc {
if cached, ok := middlewareCache.Load(key); ok {
return cached.(gin.HandlerFunc)
}
middleware := factory()
middlewareCache.Store(key, middleware)
return middleware
}3. 监控和告警
go
// MiddlewareMetrics records per-middleware execution time and outcome.
type MiddlewareMetrics struct {
	executionTime *prometheus.HistogramVec // labels: name
	errorCount    *prometheus.CounterVec   // labels: name, outcome ("error" or "success")
}

// recordExecution observes duration for the named middleware and counts the
// outcome: "error" when err is non-nil, "success" otherwise.
func (m *MiddlewareMetrics) recordExecution(name string, duration time.Duration, err error) {
	m.executionTime.WithLabelValues(name).Observe(duration.Seconds())

	outcome := "success"
	if err != nil {
		outcome = "error"
	}
	m.errorCount.WithLabelValues(name, outcome).Inc()
}
// 中间件包装器,用于统一监控
func WithMonitoring(name string, middleware gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
defer func() {
duration := time.Since(start)
if r := recover(); r != nil {
metrics.recordExecution(name, duration, fmt.Errorf("panic: %v", r))
panic(r) // 重新抛出
} else {
metrics.recordExecution(name, duration, nil)
}
}()
middleware(c)
}
}4. 错误处理策略
go
// 统一错误处理中间件
func ErrorHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// 处理中间件中的错误
for _, err := range c.Errors {
switch e := err.Err.(type) {
case *RateLimitError:
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"retry_after": e.RetryAfter,
})
return
case *CircuitBreakerError:
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": "service temporarily unavailable",
"retry_after": e.RetryAfter,
})
return
case *AuthenticationError:
c.JSON(http.StatusUnauthorized, gin.H{
"error": "authentication required",
})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal server error",
})
return
}
}
}
}性能考虑
1. 内存优化
go
// LogEntry accumulates "key=value" pairs separated by ", " in a
// strings.Builder, avoiding the intermediate strings of repeated concatenation.
// A LogEntry must not be copied after its first use (strings.Builder forbids it).
type LogEntry struct {
	builder strings.Builder
}

// AddField appends key=value, preceded by ", " unless this is the first field.
func (l *LogEntry) AddField(key, value string) {
	if l.builder.Len() != 0 {
		l.builder.WriteString(", ")
	}
	for _, part := range [...]string{key, "=", value} {
		l.builder.WriteString(part)
	}
}
// Pre-allocate the slice capacity, since the final length (number of header names) is known.
// NOTE(review): map iteration order is random, so the order of names in headers varies between runs.
headers := make([]string, 0, len(c.Request.Header))
for key := range c.Request.Header {
headers = append(headers, key)
}2. 并发优化
go
// ConfigCache holds the current *MiddlewareConfig behind a read/write lock,
// so frequent readers can proceed concurrently while updates are rare.
type ConfigCache struct {
	config *MiddlewareConfig
	mutex  sync.RWMutex
}

// GetConfig returns the stored configuration pointer (nil for a zero-value
// cache that has never been updated). Callers share the pointed-to value and
// should treat it as read-only.
func (cc *ConfigCache) GetConfig() *MiddlewareConfig {
	cc.mutex.RLock()
	current := cc.config
	cc.mutex.RUnlock()
	return current
}
func (cc *ConfigCache) UpdateConfig(config *MiddlewareConfig) {
cc.mutex.Lock()
defer cc.mutex.Unlock()
cc.config = config
}3. 缓存策略
go
// LRU 缓存用于用户限流器
type UserLimiterCache struct {
cache *lru.Cache
mutex sync.Mutex
}
func (c *UserLimiterCache) GetLimiter(userID string) *rate.Limiter {
c.mutex.Lock()
defer c.mutex.Unlock()
if limiter, ok := c.cache.Get(userID); ok {
return limiter.(*rate.Limiter)
}
// 创建新的限流器
limiter := rate.NewLimiter(rate.Limit(100), 200)
c.cache.Add(userID, limiter)
return limiter
}常见问题
1. 中间件执行顺序问题
问题: 中间件执行顺序不当导致功能异常
解决方案:
go
// ✅ Correct order
app.Use(Recovery()) // outermost, so it catches panics from every layer below
app.Use(Logging()) // logs every request
app.Use(Authentication()) // authentication
app.Use(RateLimit()) // rate limiting after auth, so unauthenticated requests cannot consume the quota
// NOTE(review): trade-off — if Authentication aborts rejected requests, they never reach RateLimit, so unauthenticated floods are not throttled by it.
// ❌ Wrong order
app.Use(RateLimit()) // limiting before auth: the quota can be exhausted by unauthenticated requests
app.Use(Authentication()) // authentication
app.Use(Recovery()) // recovery in an inner layer misses panics raised by the middlewares registered before it2. 内存泄漏问题
问题: 中间件在每个请求中频繁创建临时对象,导致分配开销和 GC 压力上升、内存占用居高不下(严格来说这不是内存泄漏,这些对象最终会被回收)
解决方案:
go
// bufferPool recycles *bytes.Buffer values (initial capacity 1 KiB) across requests.
var bufferPool = sync.Pool{
	New: func() interface{} {
		return bytes.NewBuffer(make([]byte, 0, 1024))
	},
}

// maxPooledBufferCap is the largest buffer capacity that is returned to
// bufferPool. A buffer that grew past it while handling one unusually large
// request is dropped, so the pool does not keep that memory alive for every
// later user.
const maxPooledBufferCap = 64 << 10

// processRequest shows the pooled-buffer pattern: Get, Reset, use, and Put back
// (unless the buffer became oversized).
func processRequest(c *gin.Context) {
	buffer := bufferPool.Get().(*bytes.Buffer)
	buffer.Reset()
	defer func() {
		if buffer.Cap() <= maxPooledBufferCap {
			bufferPool.Put(buffer)
		}
	}()
	// use buffer...
}
// ❌ Anti-example: a new buffer is allocated on every call
func processRequest(c *gin.Context) {
buffer := bytes.NewBuffer(nil) // fresh allocation per request, and it grows from zero capacity
// use buffer...
}3. 并发安全问题
问题: 中间件中的共享状态导致并发问题
解决方案:
go
// ✅ Lock-free counter built on sync/atomic; the zero value is ready to use.
type RequestCounter struct {
	count int64 // only ever accessed through sync/atomic
}

// Increment adds one to the counter; safe for concurrent use.
func (rc *RequestCounter) Increment() {
	counter := &rc.count
	atomic.AddInt64(counter, 1)
}

// Get returns the current count using an atomic load.
func (rc *RequestCounter) Get() int64 {
	value := atomic.LoadInt64(&rc.count)
	return value
}
// ❌ Mutex-based counter: still correct (the mutex makes the increment safe),
// but heavier than an atomic add for a plain counter.
type RequestCounter struct {
count int64
mutex sync.Mutex
}
func (rc *RequestCounter) Increment() {
rc.mutex.Lock()
rc.count++ // for a simple counter, taking a mutex is unnecessary overhead
rc.mutex.Unlock()
}参考资料
官方文档
- Kratos 官方文档 - Kratos 框架完整指南
- Gin 中间件文档 - Gin 中间件开发
- Prometheus 监控 - 监控指标设计
- Jaeger 追踪 - 分布式追踪
设计模式
- Middleware Pattern - 中间件模式
- Chain of Responsibility - 责任链模式
- Decorator Pattern - 装饰器模式
最佳实践
工具和库
- golang.org/x/time/rate - Go 官方限流库
- hystrix-go - 熔断器实现
- prometheus/client_golang - Prometheus Go 客户端
实战练习
练习 1: 实现自定义限流中间件
实现一个支持多种算法的限流中间件:
- 令牌桶算法
- 滑动窗口算法
- 固定窗口算法
- 分布式限流支持
练习 2: 开发监控仪表板
为中间件开发监控仪表板:
- 实时请求量监控
- 响应时间分布
- 错误率统计
- 限流和熔断状态
练习 3: 性能基准测试
对中间件进行性能测试:
- 中间件链的性能开销
- 内存使用分析
- 并发安全测试
- 压力测试验证
这些练习将帮助你深入理解中间件的设计原理和实现细节,掌握在高并发场景下的优化技巧。