Skip to content

🔗 依赖注入

依赖注入是 lzt 项目架构设计的核心模式,通过 Wire 框架实现自动化依赖管理,提高代码的可测试性和可维护性。

🎯 依赖注入原理

控制反转 (IoC)

组件不再自行创建所依赖的对象,而是通过构造函数参数声明所需依赖,由外部(手写的组装代码或 Wire 生成的注入器)负责创建并传入。

依赖注入的好处

  • 松耦合 - 减少组件间的直接依赖
  • 可测试性 - 便于模拟和单元测试
  • 可配置性 - 运行时选择不同实现
  • 可维护性 - 易于修改和扩展

🛠️ Wire 自动注入

Wire 基础配置

go
// wire.go
//go:build wireinject
// +build wireinject

package main

import (
    "github.com/google/wire"
    "lzt/internal/app/ledger"
    "lzt/internal/pkg/config"
    "lzt/internal/pkg/database"
)

// 提供者函数
// provideConfig is the Wire provider for *config.Config; it delegates to
// config.Load.
func provideConfig() *config.Config {
    loaded := config.Load()
    return loaded
}

// provideDatabase is the Wire provider for *gorm.DB. It opens a connection
// from the Database section of cfg and passes any error from
// database.NewConnection straight through to the injector.
func provideDatabase(cfg *config.Config) (*gorm.DB, error) {
    db, err := database.NewConnection(cfg.Database)
    return db, err
}

// provideTransactionRepo exposes database.NewTransactionRepository as a
// ledger.TransactionRepository so Wire can inject it by interface type.
func provideTransactionRepo(db *gorm.DB) ledger.TransactionRepository {
    var repo ledger.TransactionRepository = database.NewTransactionRepository(db)
    return repo
}

// provideLedgerRepo exposes database.NewLedgerRepository as a
// ledger.LedgerRepository so Wire can inject it by interface type.
func provideLedgerRepo(db *gorm.DB) ledger.LedgerRepository {
    var repo ledger.LedgerRepository = database.NewLedgerRepository(db)
    return repo
}

// provideTransactionService builds the ledger transaction service from the
// transaction and ledger repositories.
func provideTransactionService(
    repo ledger.TransactionRepository,
    ledgerRepo ledger.LedgerRepository,
) *ledger.TransactionService {
    svc := ledger.NewTransactionService(repo, ledgerRepo)
    return svc
}

// 提供者集合
var (
    // DatabaseSet groups the provider for the gorm connection with the
    // repository providers built on top of it.
    DatabaseSet = wire.NewSet(
        provideDatabase,
        provideTransactionRepo,
        provideLedgerRepo,
    )

    // ServiceSet groups the business-service providers.
    ServiceSet = wire.NewSet(
        provideTransactionService,
        // Register further service providers here.
    )

    // ApplicationSet is the top-level set: configuration plus the database
    // and service sets.
    ApplicationSet = wire.NewSet(
        provideConfig,
        DatabaseSet,
        ServiceSet,
    )
)

// Wire 构建函数
// InitializeApplication is the Wire injector declaration. Wire reads the
// wire.Build call and writes the real implementation to wire_gen.go. This
// body is only compiled under the wireinject build tag, and the placeholder
// return exists solely to satisfy the compiler.
func InitializeApplication() (*Application, error) {
    wire.Build(
        ApplicationSet,
        NewApplication, // application constructor
    )
    return nil, nil
}

// 应用结构
// Application is the root object produced by the injector; it holds the
// top-level dependencies the program uses.
type Application struct {
    // Config is the loaded application configuration.
    Config *config.Config
    // TransactionService is the ledger transaction service.
    TransactionService *ledger.TransactionService
    // Add further services here.
}

// NewApplication assembles an Application from its injected dependencies.
func NewApplication(
    cfg *config.Config,
    txService *ledger.TransactionService,
) *Application {
    app := new(Application)
    app.Config = cfg
    app.TransactionService = txService
    return app
}

生成依赖注入代码

bash
# 安装 Wire
go install github.com/google/wire/cmd/wire@latest

# 生成注入代码
wire gen ./...

# 生成的文件: wire_gen.go

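生成的代码示例

注意:wire.go 带有 `wireinject` 构建标签,而生成的 wire_gen.go 带有 `!wireinject` 标签,二者永远不会一起编译。因此 provideXxx 提供者函数、提供者集合以及 Application 类型必须放在不带构建标签的普通源文件中(例如 providers.go),wire.go 中只保留 InitializeApplication 这样的注入器声明;否则 wire_gen.go 在正常构建时会找不到这些函数。另外,上面的 wire.go 示例使用了 `*gorm.DB`,实际代码中还需要导入 `gorm.io/gorm`。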

go
// wire_gen.go
// Code generated by Wire. DO NOT EDIT.

//go:generate go run github.com/google/wire/cmd/wire
//go:build !wireinject
// +build !wireinject

package main

// InitializeApplication is the injector Wire generated from the wire.Build
// declaration. It calls each provider in dependency order and stops at the
// first error; only provideDatabase can fail here.
func InitializeApplication() (*Application, error) {
    config := provideConfig()
    db, err := provideDatabase(config)
    if err != nil {
        return nil, err
    }
    // Both repositories receive the same db handle.
    transactionRepository := provideTransactionRepo(db)
    ledgerRepository := provideLedgerRepo(db)
    transactionService := provideTransactionService(transactionRepository, ledgerRepository)
    application := NewApplication(config, transactionService)
    return application, nil
}

🏗️ 复杂依赖管理

接口绑定

go
// internal/app/ledger/wire.go
package ledger

import (
    "github.com/google/wire"
    "lzt/internal/pkg/database"
)

// 绑定接口到具体实现
// RepositorySet provides the GORM-backed repositories from the database
// package and binds each ledger repository interface to its implementation,
// so consumers can depend on the interface type alone.
var RepositorySet = wire.NewSet(
    // Constructors for the concrete GORM repositories.
    database.NewTransactionRepository,
    database.NewLedgerRepository,
    database.NewBudgetRepository,

    // Interface -> implementation bindings.
    wire.Bind(new(TransactionRepository), new(*database.GormTransactionRepository)),
    wire.Bind(new(LedgerRepository), new(*database.GormLedgerRepository)),
    wire.Bind(new(BudgetRepository), new(*database.GormBudgetRepository)),
)

// 条件注入
// provideNotificationService selects a NotificationService implementation
// from cfg.Notification.Type. "email" and "slack" pick the matching backend;
// any other value, including empty, falls back to the no-op implementation.
func provideNotificationService(cfg *config.Config) NotificationService {
    kind := cfg.Notification.Type
    if kind == "email" {
        return NewEmailNotificationService(cfg.Notification.Email)
    }
    if kind == "slack" {
        return NewSlackNotificationService(cfg.Notification.Slack)
    }
    return NewNoOpNotificationService()
}

// NotificationSet exposes the config-driven NotificationService provider to Wire.
var NotificationSet = wire.NewSet(provideNotificationService)

可选依赖

go
// 可选依赖处理
// OptionalCache is a minimal key/value cache abstraction. An implementation
// may be a real backend (RedisCache) or a no-op (NoOpCache), so callers must
// treat a miss as a normal outcome.
type OptionalCache interface {
    // Get returns the cached value and true on a hit, or false on a miss.
    Get(key string) (any, bool)
    // Set stores value under key; it reports no error.
    Set(key string, value any)
}

// 默认实现
// NoOpCache is the fallback cache used when no backend is configured: it
// stores nothing and every lookup misses.
type NoOpCache struct{}

// Get always reports a miss.
func (*NoOpCache) Get(string) (any, bool) {
    return nil, false
}

// Set discards the value.
func (*NoOpCache) Set(string, any) {}

// Redis 实现
// RedisCache is a cache backed by a redis client.
type RedisCache struct {
    client *redis.Client // connection used by Get and Set
}

// Get looks key up in Redis. Any error from the client is reported as a miss,
// so a cache failure never fails the caller. On a hit the value returned is
// whatever Result() yields (presumably the stored string, not the original
// type passed to Set — confirm against the redis client in use).
// NOTE(review): the context-free Get(key) form matches the older go-redis API;
// newer major versions take a context first — check go.mod.
func (c *RedisCache) Get(key string) (interface{}, bool) {
    if val, err := c.client.Get(key).Result(); err == nil {
        return val, true
    }
    return nil, false
}

// Set stores value under key with a fixed one-hour expiry. The command result
// is not inspected: caching here is best-effort.
func (c *RedisCache) Set(key string, value interface{}) {
    const ttl = time.Hour
    c.client.Set(key, value, ttl)
}

// 条件性提供缓存
// provideCache returns a Redis-backed cache when cfg.Redis.Enabled is set and
// a no-op cache otherwise, so consumers never need a nil check.
func provideCache(cfg *config.Config) OptionalCache {
    if !cfg.Redis.Enabled {
        return &NoOpCache{}
    }
    opts := &redis.Options{Addr: cfg.Redis.Address}
    return &RedisCache{client: redis.NewClient(opts)}
}

// CacheSet lets Wire inject an OptionalCache chosen from configuration.
var CacheSet = wire.NewSet(provideCache)

泛型依赖注入

go
// 泛型仓储模式
// Repository is a generic CRUD abstraction over entities of type T, keyed by
// a string ID. Every method takes the caller's context.
type Repository[T any] interface {
    // Create inserts entity.
    Create(ctx context.Context, entity *T) error
    // GetByID loads the entity with the given ID.
    GetByID(ctx context.Context, id string) (*T, error)
    // Update persists changes to entity.
    Update(ctx context.Context, entity *T) error
    // Delete removes the entity with the given ID.
    Delete(ctx context.Context, id string) error
}

// 具体实现
// GormRepository is a GORM-backed generic repository for entities of type T.
type GormRepository[T any] struct {
    db *gorm.DB // handle used for every query
}

// NewGormRepository returns a GormRepository[T] that issues its queries
// through db.
func NewGormRepository[T any](db *gorm.DB) *GormRepository[T] {
    repo := new(GormRepository[T])
    repo.db = db
    return repo
}

// Create inserts entity through GORM, bound to ctx, and returns the resulting
// error (nil on success).
func (r *GormRepository[T]) Create(ctx context.Context, entity *T) error {
    result := r.db.WithContext(ctx).Create(entity)
    return result.Error
}

// GetByID loads the entity whose "id" column equals id. On failure it returns
// nil and GORM's error (First reports gorm.ErrRecordNotFound when no row
// matches).
func (r *GormRepository[T]) GetByID(ctx context.Context, id string) (*T, error) {
    var entity T
    if err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error; err != nil {
        return nil, err
    }
    return &entity, nil
}

// Update saves all fields of entity using GORM's Save, which updates by
// primary key or inserts when the primary key is zero. Together with Create,
// GetByID and Delete this completes the Repository[T] method set.
func (r *GormRepository[T]) Update(ctx context.Context, entity *T) error {
    return r.db.WithContext(ctx).Save(entity).Error
}

// Delete removes the entity whose "id" column equals id. Deleting a row that
// does not exist is not an error. If T has a gorm.DeletedAt field, GORM
// performs a soft delete.
func (r *GormRepository[T]) Delete(ctx context.Context, id string) error {
    return r.db.WithContext(ctx).Where("id = ?", id).Delete(new(T)).Error
}

// Wire 配置泛型
// provideTransactionRepository exposes the generic GORM repository as a
// Repository[Transaction] for Wire.
func provideTransactionRepository(db *gorm.DB) Repository[Transaction] {
    repo := NewGormRepository[Transaction](db)
    return repo
}

// provideLedgerRepository exposes the generic GORM repository as a
// Repository[Ledger] for Wire.
func provideLedgerRepository(db *gorm.DB) Repository[Ledger] {
    repo := NewGormRepository[Ledger](db)
    return repo
}

// GenericRepositorySet registers the per-entity generic repository providers.
var GenericRepositorySet = wire.NewSet(
    provideLedgerRepository,
    provideTransactionRepository,
)

🧪 测试中的依赖注入

测试专用的依赖配置

go
// wire_test.go
//go:build wireinject
// +build wireinject

package ledger

import (
    "github.com/google/wire"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
)

// 测试数据库提供者
// provideTestDatabase opens an in-memory SQLite database and migrates the
// Transaction, Ledger and Budget models into it.
//
// database/sql keeps a pool of connections, and every new connection to
// ":memory:" gets its own, empty database. The pool is therefore capped at a
// single connection so the tables created by AutoMigrate stay visible to every
// later query.
func provideTestDatabase() (*gorm.DB, error) {
    db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
    if err != nil {
        return nil, err
    }

    sqlDB, err := db.DB()
    if err != nil {
        return nil, err
    }
    sqlDB.SetMaxOpenConns(1)

    if err := db.AutoMigrate(&Transaction{}, &Ledger{}, &Budget{}); err != nil {
        // Release the connection instead of leaking it on failure.
        _ = sqlDB.Close()
        return nil, err
    }

    return db, nil
}

// 测试配置
// provideTestConfig returns a configuration for tests whose database section
// points at an in-memory SQLite database.
func provideTestConfig() *config.Config {
    cfg := &config.Config{}
    cfg.Database = config.DatabaseConfig{
        Driver: "sqlite",
        DSN:    ":memory:",
    }
    // Add further test-only settings here.
    return cfg
}

// 测试依赖集合
// TestSet combines test-only configuration and database providers with the
// regular repository and service sets.
var TestSet = wire.NewSet(
    // test-only providers
    provideTestConfig,
    provideTestDatabase,

    // shared sets, reused unchanged
    RepositorySet,
    ServiceSet,
)

// 构建测试应用
// InitializeTestApplication is the Wire injector for tests: it builds an
// Application from TestSet instead of the production provider sets. Wire
// generates the real body from the wire.Build call; the placeholder return only
// satisfies the compiler under the wireinject build tag.
func InitializeTestApplication() (*Application, error) {
    wire.Build(
        TestSet,
        NewApplication,
    )
    return nil, nil
}

模拟依赖

go
// mocks/transaction_repository.go
package mocks

import (
    "context"
    "github.com/stretchr/testify/mock"
    "lzt/internal/app/ledger"
)

// MockTransactionRepository is a testify mock of the ledger transaction
// repository. Expectations are registered with On(...) and checked with
// AssertExpectations.
type MockTransactionRepository struct {
    mock.Mock // records calls and serves the configured return values
}

// Create records the call and returns the error configured with
// On("Create", ...).Return(err).
func (m *MockTransactionRepository) Create(ctx context.Context, tx *ledger.Transaction) error {
    return m.Called(ctx, tx).Error(0)
}

// GetByID records the call and returns the values configured with
// On("GetByID", ...).Return(tx, err).
//
// A nil first return value is allowed (e.g. Return(nil, errNotFound)).
// Asserting that untyped nil straight to *ledger.Transaction would panic, so
// it is mapped to a nil pointer instead.
func (m *MockTransactionRepository) GetByID(ctx context.Context, id string) (*ledger.Transaction, error) {
    args := m.Called(ctx, id)

    var tx *ledger.Transaction
    if v := args.Get(0); v != nil {
        tx = v.(*ledger.Transaction)
    }
    return tx, args.Error(1)
}

// 测试中使用模拟
// TestTransactionService_Create checks that CreateTransaction calls the
// repository's Create and returns a non-nil result without error.
//
// NOTE(review): this function refers to mocks.MockTransactionRepository and
// mocks.MockLedgerRepository, so it must live in a _test.go file of a package
// that imports mocks (e.g. package ledger_test), not in package mocks itself.
// It also needs the "testing" and testify "assert" imports, and
// MockLedgerRepository is not defined in this snippet.
func TestTransactionService_Create(t *testing.T) {
    // Build the dependencies by hand; no Wire is needed in unit tests.
    mockRepo := new(mocks.MockTransactionRepository)
    mockLedgerRepo := new(mocks.MockLedgerRepository)

    service := ledger.NewTransactionService(mockRepo, mockLedgerRepo)

    // Expect Create with any context and a *ledger.Transaction; it returns nil.
    mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*ledger.Transaction")).Return(nil)

    // Exercise the service.
    req := &ledger.CreateTransactionRequest{
        LedgerID: "test-ledger",
        Amount:   1000,
        Type:     "expense",
    }

    result, err := service.CreateTransaction(context.Background(), req)

    // Verify the result and that every registered expectation was met.
    assert.NoError(t, err)
    assert.NotNil(t, result)
    mockRepo.AssertExpectations(t)
}

🔄 运行时依赖替换

插件系统

go
// plugin/manager.go
package plugin

import (
    "fmt"
    "plugin"
    "reflect"
)

// PluginManager loads Go plugins and resolves services from them, caching
// each resolved service by its reflect.Type. It holds plain maps and is not
// safe for concurrent use.
type PluginManager struct {
    plugins  map[string]*plugin.Plugin    // loaded plugins, keyed by the name given to LoadPlugin
    services map[reflect.Type]interface{} // services already resolved
}

// NewPluginManager returns a PluginManager with empty plugin and service
// registries.
func NewPluginManager() *PluginManager {
    pm := new(PluginManager)
    pm.plugins = make(map[string]*plugin.Plugin)
    pm.services = make(map[reflect.Type]interface{})
    return pm
}

// LoadPlugin opens the plugin at path and registers it under name. Loading a
// second plugin with the same name replaces the earlier entry.
func (pm *PluginManager) LoadPlugin(name, path string) error {
    opened, openErr := plugin.Open(path)
    if openErr != nil {
        return fmt.Errorf("failed to load plugin %s: %w", name, openErr)
    }
    pm.plugins[name] = opened
    return nil
}

// GetService returns the service registered for serviceType, constructing it
// from a loaded plugin on first use and caching the result.
//
// A plugin provides a service by exporting a function named
// "New" + serviceType.Name() whose type is exactly func() interface{}. A
// constructor with any other signature fails the type assertion and is skipped.
// The constructed value must also be assignable to serviceType; otherwise it
// is skipped too. This way a mismatched plugin cannot poison the cache and make
// callers' type assertions panic.
//
// Plugins are tried in Go map iteration order, which is unspecified. If
// several plugins can provide the same service, which one wins is undefined.
func (pm *PluginManager) GetService(serviceType reflect.Type) (interface{}, error) {
    if service, exists := pm.services[serviceType]; exists {
        return service, nil
    }

    symbolName := "New" + serviceType.Name()
    for name, p := range pm.plugins {
        symbol, err := p.Lookup(symbolName)
        if err != nil {
            continue
        }

        constructor, ok := symbol.(func() interface{})
        if !ok {
            continue
        }

        service := constructor()
        // reflect.TypeOf(nil) is nil; AssignableTo covers both concrete types
        // and interface implementation.
        if st := reflect.TypeOf(service); st == nil || !st.AssignableTo(serviceType) {
            continue
        }
        pm.services[serviceType] = service

        fmt.Printf("Loaded service %s from plugin %s\n", serviceType.Name(), name)
        return service, nil
    }

    return nil, fmt.Errorf("service %s not found", serviceType.Name())
}

// 使用插件系统
// LoadPlugins loads ./plugins/notification.so and installs the
// NotificationService it provides on app.
//
// NOTE(review): Go does not allow methods on types declared in another
// package, so this method must live in the package that declares Application.
// That Application type also needs a NotificationService field.
func (app *Application) LoadPlugins() error {
    pm := NewPluginManager()

    if err := pm.LoadPlugin("notification", "./plugins/notification.so"); err != nil {
        return err
    }

    serviceType := reflect.TypeOf((*NotificationService)(nil)).Elem()
    service, err := pm.GetService(serviceType)
    if err != nil {
        return err
    }

    // Checked assertion: a plugin that returns the wrong type yields an error
    // instead of a panic.
    notificationService, ok := service.(NotificationService)
    if !ok {
        return fmt.Errorf("plugin service %T does not implement NotificationService", service)
    }
    app.NotificationService = notificationService
    return nil
}

动态配置

go
// config/dynamic.go
package config

import (
    "context"
    "encoding/json"
    "sync"
    "time"
)

// DynamicConfig holds a *Config that ConfigWatchers can change at runtime.
//
// Updates are applied copy-on-write. applyUpdate copies the current Config,
// changes the copy and publishes it under mu. A *Config returned by GetConfig
// is therefore never written to afterwards and can be read without locking.
// Call GetConfig again to observe later updates. The Config passed to
// NewDynamicConfig is never modified.
//
// NOTE(review): the copy is shallow. It is race-free only for fields stored by
// value. Database is a value struct (config.DatabaseConfig); confirm that Redis
// is not a pointer field.
type DynamicConfig struct {
    mu         sync.RWMutex // guards config
    config     *Config
    watchers   []ConfigWatcher
    updateChan chan ConfigUpdate
    stopChan   chan struct{}
    stopOnce   sync.Once // makes Stop idempotent
}

// ConfigWatcher produces configuration updates until ctx is done.
type ConfigWatcher interface {
    Watch(ctx context.Context) <-chan ConfigUpdate
}

// ConfigUpdate sets the configuration entry named Key to Value.
type ConfigUpdate struct {
    Key   string
    Value interface{}
}

// NewDynamicConfig wraps initialConfig. Updates are buffered up to 100
// pending entries.
func NewDynamicConfig(initialConfig *Config) *DynamicConfig {
    return &DynamicConfig{
        config:     initialConfig,
        watchers:   make([]ConfigWatcher, 0),
        updateChan: make(chan ConfigUpdate, 100),
        stopChan:   make(chan struct{}),
    }
}

// AddWatcher registers a watcher. It must be called before Start; watchers
// added afterwards are not started.
func (dc *DynamicConfig) AddWatcher(watcher ConfigWatcher) {
    dc.watchers = append(dc.watchers, watcher)
}

// Start launches one forwarding goroutine per watcher plus the goroutine that
// applies updates. All of them exit when ctx is done or Stop is called.
func (dc *DynamicConfig) Start(ctx context.Context) {
    for _, watcher := range dc.watchers {
        go dc.forward(ctx, watcher)
    }

    go dc.processUpdates(ctx)
}

// forward relays updates from one watcher into dc.updateChan until the
// watcher's channel is closed, ctx is done, or Stop is called. Selecting on
// the stop signals while waiting keeps an idle watcher from pinning this
// goroutine forever.
func (dc *DynamicConfig) forward(ctx context.Context, w ConfigWatcher) {
    updates := w.Watch(ctx)
    for {
        select {
        case <-ctx.Done():
            return
        case <-dc.stopChan:
            return
        case update, ok := <-updates:
            if !ok {
                return
            }
            select {
            case dc.updateChan <- update:
            case <-ctx.Done():
                return
            case <-dc.stopChan:
                return
            }
        }
    }
}

// processUpdates applies queued updates one at a time until ctx is done or
// Stop is called.
func (dc *DynamicConfig) processUpdates(ctx context.Context) {
    for {
        select {
        case update := <-dc.updateChan:
            dc.applyUpdate(update)
        case <-ctx.Done():
            return
        case <-dc.stopChan:
            return
        }
    }
}

// applyUpdate publishes a new Config snapshot with update applied. Only the
// keys below are supported. Unknown keys and values of the wrong dynamic type
// are ignored and leave the current snapshot in place.
func (dc *DynamicConfig) applyUpdate(update ConfigUpdate) {
    dc.mu.Lock()
    defer dc.mu.Unlock()

    next := *dc.config // shallow copy; see the type comment
    switch update.Key {
    case "database.max_connections":
        value, ok := update.Value.(int)
        if !ok {
            return
        }
        next.Database.MaxConnections = value
    case "redis.enabled":
        value, ok := update.Value.(bool)
        if !ok {
            return
        }
        next.Redis.Enabled = value
    default:
        return
    }
    dc.config = &next
}

// GetConfig returns the current configuration snapshot. The snapshot is never
// modified after it is returned.
func (dc *DynamicConfig) GetConfig() *Config {
    dc.mu.RLock()
    defer dc.mu.RUnlock()
    return dc.config
}

// Stop signals all goroutines started by Start to exit. It is safe to call
// more than once.
func (dc *DynamicConfig) Stop() {
    dc.stopOnce.Do(func() { close(dc.stopChan) })
}

📊 依赖图可视化

依赖分析工具

go
// tools/dependency_analyzer.go
package main

import (
    "fmt"
    "go/ast"
    "go/parser"
    "go/token"
    "os"
    "path/filepath"
    "strings"
)

// DependencyAnalyzer collects, per Go package name, the import paths used by
// that package's files.
type DependencyAnalyzer struct {
    dependencies map[string][]string     // package name -> imported paths
    packages     map[string]*ast.Package // NOTE(review): never populated by the code shown
}

// NewDependencyAnalyzer returns an analyzer with empty, ready-to-use maps.
func NewDependencyAnalyzer() *DependencyAnalyzer {
    da := &DependencyAnalyzer{}
    da.dependencies = map[string][]string{}
    da.packages = map[string]*ast.Package{}
    return da
}

// AnalyzeProject walks rootPath recursively and records the imports of every
// .go file it finds. Directories named "vendor" are pruned entirely. The walk
// stops at the first error, which is returned.
func (da *DependencyAnalyzer) AnalyzeProject(rootPath string) error {
    return filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
        if err != nil {
            return err
        }

        if info.IsDir() {
            // Matching the directory name works with any path separator,
            // unlike a "vendor/" substring test. It also does not catch names
            // such as "myvendor", and SkipDir avoids visiting vendored files
            // at all.
            if info.Name() == "vendor" {
                return filepath.SkipDir
            }
            return nil
        }

        if !strings.HasSuffix(path, ".go") {
            return nil
        }
        return da.analyzeFile(path)
    })
}

// analyzeFile parses the import section of filePath and records each import
// path under the file's package name. Files belonging to the same package
// are merged, and an import already recorded for the package is not added
// again, so the graph gets one edge per package/import pair.
func (da *DependencyAnalyzer) analyzeFile(filePath string) error {
    fset := token.NewFileSet()
    // ImportsOnly stops parsing after the import declarations; nothing else
    // in the file is needed here.
    node, err := parser.ParseFile(fset, filePath, nil, parser.ImportsOnly)
    if err != nil {
        return err
    }

    packageName := node.Name.Name
    deps := da.dependencies[packageName]

    seen := make(map[string]struct{}, len(deps)+len(node.Imports))
    for _, d := range deps {
        seen[d] = struct{}{}
    }

    for _, imp := range node.Imports {
        // Import paths may be written in double quotes or backquotes.
        importPath := strings.Trim(imp.Path.Value, "\"`")
        if _, dup := seen[importPath]; dup {
            continue
        }
        seen[importPath] = struct{}{}
        deps = append(deps, importPath)
    }

    if len(deps) > 0 {
        da.dependencies[packageName] = deps
    }
    return nil
}

// GenerateDotGraph renders the recorded dependencies as a Graphviz DOT
// digraph. It keeps only imports of the "lzt" module, i.e. the module path
// itself or anything under "lzt/" (assumes the module path is "lzt", as in the
// project's import paths). Each package -> import edge is written once.
//
// NOTE(review): source nodes are package names (e.g. "ledger") while target
// nodes are full import paths (e.g. "lzt/internal/app/ledger"), so the two
// ends of an edge do not share node identities. Edge order follows Go's
// unspecified map iteration order.
func (da *DependencyAnalyzer) GenerateDotGraph() string {
    var builder strings.Builder

    builder.WriteString("digraph dependencies {\n")
    builder.WriteString("  rankdir=LR;\n")
    builder.WriteString("  node [shape=box];\n")

    emitted := make(map[[2]string]bool)
    for pkg, deps := range da.dependencies {
        for _, dep := range deps {
            // Prefix match: a plain substring test for "lzt" would also accept
            // unrelated third-party paths that merely contain those letters.
            if dep != "lzt" && !strings.HasPrefix(dep, "lzt/") {
                continue
            }
            edge := [2]string{pkg, dep}
            if emitted[edge] {
                continue
            }
            emitted[edge] = true
            fmt.Fprintf(&builder, "  \"%s\" -> \"%s\";\n", pkg, dep)
        }
    }

    builder.WriteString("}\n")
    return builder.String()
}

// main analyzes ./internal, writes the internal dependency graph to
// dependencies.dot and prints how to render it with Graphviz. Any error
// aborts the program with a panic.
func main() {
    da := NewDependencyAnalyzer()
    if err := da.AnalyzeProject("./internal"); err != nil {
        panic(err)
    }

    graph := []byte(da.GenerateDotGraph())
    if err := os.WriteFile("dependencies.dot", graph, 0644); err != nil {
        panic(err)
    }

    fmt.Println("依赖图已生成: dependencies.dot")
    fmt.Println("使用 graphviz 生成图片: dot -Tpng dependencies.dot -o dependencies.png")
}

📚 最佳实践

1. 依赖注入设计原则

go
// 好的依赖注入设计
// UserService depends only on interface types, each injected through
// NewUserService, so any of them can be replaced by a fake in tests.
type UserService struct {
    repo         UserRepository // interface dependency: persistence
    emailService EmailService   // interface dependency: e-mail
    logger       Logger         // interface dependency: logging
}

// 构造函数依赖注入
// NewUserService builds a UserService from its three dependencies
// (constructor injection).
func NewUserService(
    repo UserRepository,
    emailService EmailService,
    logger Logger,
) *UserService {
    svc := &UserService{}
    svc.repo = repo
    svc.emailService = emailService
    svc.logger = logger
    return svc
}

// 避免的反模式
// BadUserService is a deliberate anti-pattern example; do not copy it.
type BadUserService struct {
    // ❌ Depends directly on a concrete implementation
    repo *MySQLUserRepository

    // ❌ Depends on global state:
    // methods use globalEmailService directly
}

// CreateUser illustrates two more anti-patterns: constructing a dependency
// inside the method and reaching for a global logger. Neither can be replaced
// in a test.
func (s *BadUserService) CreateUser(user *User) error {
    // ❌ Creates its dependency inside the method
    emailService := NewSMTPEmailService()
    _ = emailService // unused in this example; Go rejects unused local variables

    // ❌ Accesses a global variable directly
    globalLogger.Info("Creating user")

    return s.repo.Save(user)
}

2. 生命周期管理

go
// 单例模式
// DatabaseConnection wraps the process-wide *gorm.DB handle.
type DatabaseConnection struct {
    db *gorm.DB // shared handle
}

// Process-wide database connection, created at most once by
// OpenDatabaseConnection.
var (
    dbInstance *DatabaseConnection
    dbErr      error // outcome of the single open attempt
    dbOnce     sync.Once
)

// OpenDatabaseConnection returns the shared DatabaseConnection, opening it on
// the first call. cfg is only consulted on that first call. Because sync.Once
// never retries, a failed open returns the same error on every later call.
func OpenDatabaseConnection(cfg *config.Config) (*DatabaseConnection, error) {
    dbOnce.Do(func() {
        db, err := gorm.Open(mysql.Open(cfg.Database.DSN), &gorm.Config{})
        if err != nil {
            dbErr = err
            return
        }
        dbInstance = &DatabaseConnection{db: db}
    })
    return dbInstance, dbErr
}

// GetDatabaseConnection is the panicking form of OpenDatabaseConnection for
// callers that treat a missing database as fatal. It panics on every call if
// the connection could not be opened, so it never returns a nil connection.
// A panic inside sync.Once.Do would otherwise mark the Once as done and leave
// later callers with nil.
func GetDatabaseConnection(cfg *config.Config) *DatabaseConnection {
    conn, err := OpenDatabaseConnection(cfg)
    if err != nil {
        panic(err)
    }
    return conn
}

// 工厂模式
// RepositoryFactory creates repositories on demand (factory pattern), as an
// alternative to injecting long-lived repository instances.
type RepositoryFactory interface {
    // CreateUserRepository returns a UserRepository.
    CreateUserRepository() UserRepository
    // CreateOrderRepository returns an OrderRepository.
    CreateOrderRepository() OrderRepository
}

// GormRepositoryFactory creates GORM-backed repositories that all share one
// *gorm.DB.
type GormRepositoryFactory struct {
    db *gorm.DB // handed to every repository the factory creates
}

// CreateUserRepository constructs a fresh GORM user repository on every
// call, sharing the factory's database handle.
func (f *GormRepositoryFactory) CreateUserRepository() UserRepository {
    repo := NewGormUserRepository(f.db)
    return repo
}

// CreateOrderRepository constructs a fresh GORM order repository on every
// call, sharing the factory's database handle.
func (f *GormRepositoryFactory) CreateOrderRepository() OrderRepository {
    repo := NewGormOrderRepository(f.db)
    return repo
}

📚 相关资源

项目文档

外部参考


💡 依赖注入建议: 保持依赖关系简单明确,优先使用接口而非具体实现,合理管理对象生命周期,避免循环依赖。

基于 MIT 许可证发布