Skip to content

🔄 重构案例实践

重构是改善代码内部结构而不改变外部行为的过程。本文档展示 lzt 项目中的实际重构案例和最佳实践。

🎯 重构原则

重构时机

重构步骤

  1. 确保测试覆盖 - 重构前必须有充分的测试
  2. 小步快走 - 每次只做一个小改动
  3. 频繁测试 - 每次改动后立即运行测试
  4. 及时提交 - 每个稳定状态都要提交

📝 实际重构案例

案例1: Bubble 进度条 View 函数重构

重构前 - 违反单一职责原则

go
// pkg/bubble/progress.go (before refactoring)
//
// View renders the whole progress screen in one function: a title, one line
// per task (plus a per-task progress bar for processing tasks that report
// progress) and an overall "completed/total" bar. Status classification,
// styling and layout are all mixed together here, which is the
// single-responsibility violation this case study addresses.
func (m ProgressModel) View() string {
    var out strings.Builder

    // Title.
    out.WriteString(titleStyle.Render("Processing Tasks"))
    out.WriteString("\n\n")

    // Task list.
    for i, task := range m.tasks {
        // Pick the status icon. Only the icon is needed later; a separate
        // status-name variable that is never read would not compile in Go.
        var icon string
        if task.IsCompleted() {
            icon = "✅"
        } else if task.IsProcessing() {
            icon = "🔄"
        } else if task.IsFailed() {
            icon = "❌"
        } else {
            icon = "⏳"
        }

        // Style the line. Only the current task gets processingStyle while it
        // is processing; any other processing task falls through to pendingStyle.
        taskText := fmt.Sprintf("%s %s", icon, task.Name())
        if i == m.currentTaskIndex && task.IsProcessing() {
            taskText = processingStyle.Render(taskText)
        } else if task.IsCompleted() {
            taskText = completedStyle.Render(taskText)
        } else if task.IsFailed() {
            taskText = errorStyle.Render(taskText)
        } else {
            taskText = pendingStyle.Render(taskText)
        }

        out.WriteString(taskText)
        out.WriteString("\n")

        // Per-task progress bar.
        if task.IsProcessing() && task.HasProgress() {
            progress := task.Progress()
            progressBar := renderProgressBar(progress.Current, progress.Total)
            out.WriteString("  " + progressBar)
            out.WriteString(fmt.Sprintf(" %d/%d", progress.Current, progress.Total))
            out.WriteString("\n")
        }
    }

    // Overall progress: number of completed tasks out of all tasks.
    completed := 0
    for _, task := range m.tasks {
        if task.IsCompleted() {
            completed++
        }
    }

    totalProgress := renderProgressBar(completed, len(m.tasks))
    out.WriteString("\n")
    out.WriteString("Overall Progress: ")
    out.WriteString(totalProgress)
    out.WriteString(fmt.Sprintf(" %d/%d", completed, len(m.tasks)))

    return out.String()
}

重构后 - 职责分离

go
// pkg/bubble/progress.go (after refactoring)

// View renders the model by handing the whole model to its ProgressRenderer;
// the model itself no longer contains any layout or styling logic.
func (m ProgressModel) View() string {
    renderer := m.renderer
    return renderer.Render(m)
}

// ProgressRenderer is the rendering component extracted from View. It
// composes three sub-renderers: one for the title, one for the task list and
// one for the overall statistics.
type ProgressRenderer struct {
    titleRenderer *TitleRenderer
    taskRenderer  *TaskRenderer
    statsRenderer *StatsRenderer
}

// NewProgressRenderer builds a ProgressRenderer with freshly constructed
// title, task and stats renderers.
func NewProgressRenderer() *ProgressRenderer {
    r := new(ProgressRenderer)
    r.titleRenderer = NewTitleRenderer()
    r.taskRenderer = NewTaskRenderer()
    r.statsRenderer = NewStatsRenderer()
    return r
}

// Render produces the full screen for model. It outputs the title followed by
// a blank line, then the task list, then the overall progress footer.
func (r *ProgressRenderer) Render(model ProgressModel) string {
    // Composite-literal elements are evaluated left to right, so the
    // sub-renderers run in the same order as the sections appear.
    sections := []string{
        r.titleRenderer.Render("Processing Tasks") + "\n\n",
        r.taskRenderer.RenderTasks(model.tasks, model.currentTaskIndex),
        r.statsRenderer.RenderOverallProgress(model.tasks),
    }
    return strings.Join(sections, "")
}

// TaskRenderer renders individual task lines. It looks up the icon and the
// lipgloss style to use for each TaskStatus.
type TaskRenderer struct {
    iconMap  map[TaskStatus]string
    styleMap map[TaskStatus]lipgloss.Style
}

// NewTaskRenderer returns a TaskRenderer whose icon and style tables cover
// the four statuses used in this package: completed, processing, failed and
// pending.
func NewTaskRenderer() *TaskRenderer {
    return &TaskRenderer{
        iconMap: map[TaskStatus]string{
            StatusCompleted:  "✅",
            StatusProcessing: "🔄",
            StatusFailed:     "❌",
            StatusPending:    "⏳",
        },
        // The map's value type must match the styleMap field
        // (lipgloss.Style); a map[TaskStatus]string literal holding styles
        // does not compile.
        styleMap: map[TaskStatus]lipgloss.Style{
            StatusCompleted:  completedStyle,
            StatusProcessing: processingStyle,
            StatusFailed:     errorStyle,
            StatusPending:    pendingStyle,
        },
    }
}

// RenderTasks renders every task on its own line. The task at currentIndex is
// flagged as current. Each line for a processing task that reports progress
// is followed by an indented progress-bar line.
func (r *TaskRenderer) RenderTasks(tasks []Task, currentIndex int) string {
    var b strings.Builder

    for idx := range tasks {
        t := tasks[idx]
        b.WriteString(r.renderSingleTask(t, idx == currentIndex) + "\n")

        // Only processing tasks with progress information get a bar;
        // HasProgress is consulted only when the task is processing.
        if !t.IsProcessing() || !t.HasProgress() {
            continue
        }
        b.WriteString(r.renderTaskProgress(t) + "\n")
    }

    return b.String()
}

// renderSingleTask renders one task as "<icon> <name>", styled by status.
//
// Styling matches the pre-refactoring View, so the refactoring does not change
// what users see:
//   - the current task uses processingStyle while it is processing;
//   - a processing task that is NOT the current one uses pendingStyle;
//   - every other status uses its styleMap entry.
//
// A status missing from iconMap/styleMap falls back to the pending icon and
// pendingStyle, mirroring the final else-branch of the original View.
func (r *TaskRenderer) renderSingleTask(task Task, isCurrent bool) string {
    status := task.Status()

    icon, ok := r.iconMap[status]
    if !ok {
        icon = r.iconMap[StatusPending]
    }
    taskText := fmt.Sprintf("%s %s", icon, task.Name())

    if isCurrent && task.IsProcessing() {
        return processingStyle.Render(taskText)
    }

    style, ok := r.styleMap[status]
    if !ok || status == StatusProcessing {
        // Non-current processing tasks were rendered with pendingStyle before
        // the refactoring; keep that so the change stays behaviour-preserving.
        style = pendingStyle
    }
    return style.Render(taskText)
}

// renderTaskProgress renders the indented progress line for a task:
// two spaces, the progress bar, then " current/total".
func (r *TaskRenderer) renderTaskProgress(task Task) string {
    p := task.Progress()
    bar := renderProgressBar(p.Current, p.Total)
    return "  " + bar + fmt.Sprintf(" %d/%d", p.Current, p.Total)
}

// StatsRenderer renders aggregate statistics for a task list. It holds no
// state.
type StatsRenderer struct{}

// NewStatsRenderer returns a ready-to-use StatsRenderer.
func NewStatsRenderer() *StatsRenderer {
    return new(StatsRenderer)
}

// RenderOverallProgress renders "\nOverall Progress: <bar> <done>/<total>",
// where done is the number of completed tasks and total is len(tasks).
func (r *StatsRenderer) RenderOverallProgress(tasks []Task) string {
    done, total := r.countCompleted(tasks), len(tasks)
    bar := renderProgressBar(done, total)
    return "\nOverall Progress: " + bar + fmt.Sprintf(" %d/%d", done, total)
}

// countCompleted reports how many tasks return true from IsCompleted.
func (r *StatsRenderer) countCompleted(tasks []Task) int {
    n := 0
    for i := range tasks {
        if tasks[i].IsCompleted() {
            n++
        }
    }
    return n
}

重构测试

go
// pkg/bubble/progress_test.go
func TestProgressRenderer_Render(t *testing.T) {
    renderer := NewProgressRenderer()
    
    tasks := []Task{
        NewMockTask("Task 1", StatusCompleted),
        NewMockTask("Task 2", StatusProcessing),
        NewMockTask("Task 3", StatusPending),
    }
    
    model := ProgressModel{
        tasks:            tasks,
        currentTaskIndex: 1,
        renderer:         renderer,
    }
    
    result := model.View()
    
    assert.Contains(t, result, "Processing Tasks")
    assert.Contains(t, result, "✅ Task 1")
    assert.Contains(t, result, "🔄 Task 2")
    assert.Contains(t, result, "⏳ Task 3")
    assert.Contains(t, result, "Overall Progress:")
}

// TestTaskRenderer_RenderSingleTask checks that a rendered task line contains
// "<icon> <name>" for the task's status. Contains is used so that any styling
// wrapped around the text does not affect the check.
func TestTaskRenderer_RenderSingleTask(t *testing.T) {
    r := NewTaskRenderer()

    cases := []struct {
        name      string
        task      Task
        isCurrent bool
        expected  string
    }{
        {
            name:     "completed task",
            task:     NewMockTask("Test Task", StatusCompleted),
            expected: "✅ Test Task",
        },
        {
            name:      "processing current task",
            task:      NewMockTask("Test Task", StatusProcessing),
            isCurrent: true,
            expected:  "🔄 Test Task", // wrapped in processingStyle
        },
    }

    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            got := r.renderSingleTask(tc.task, tc.isCurrent)
            assert.Contains(t, got, tc.expected)
        })
    }
}

案例2: Ledger 服务重构 - 从贫血模型到富领域模型

重构前 - 贫血模型

go
// internal/app/ledger/service.go (before refactoring)

type (
    // TransactionService is the application service of the anemic design:
    // it holds the repository and carries all the business rules itself.
    TransactionService struct {
        repo TransactionRepository
    }

    // Transaction is a plain data record: exported fields with JSON tags and
    // no behaviour of its own. Type holds a free-form string such as
    // "income", "expense" or "transfer".
    Transaction struct {
        ID          string    `json:"id"`
        LedgerID    string    `json:"ledger_id"`
        Type        string    `json:"type"`
        Amount      int64     `json:"amount"`
        Currency    string    `json:"currency"`
        Description string    `json:"description"`
        CreatedAt   time.Time `json:"created_at"`
    }
)

// CreateTransaction validates req, builds a Transaction with a generated ID
// and the current time, and stores it through the repository. Repository
// errors are returned unchanged.
//
// All validation lives here in the service layer rather than in the model,
// which is the smell this case study removes.
func (s *TransactionService) CreateTransaction(ctx context.Context, req *CreateTransactionRequest) (*Transaction, error) {
    switch {
    case req.LedgerID == "":
        return nil, errors.New("ledger_id is required")
    case req.Amount <= 0:
        return nil, errors.New("amount must be positive")
    case req.Currency == "":
        return nil, errors.New("currency is required")
    }

    switch req.Type {
    case "income", "expense", "transfer":
        // accepted
    default:
        return nil, errors.New("invalid transaction type")
    }

    tx := &Transaction{
        ID:          generateID(),
        LedgerID:    req.LedgerID,
        Type:        req.Type,
        Amount:      req.Amount,
        Currency:    req.Currency,
        Description: req.Description,
        CreatedAt:   time.Now(),
    }

    if err := s.repo.Create(ctx, tx); err != nil {
        return nil, err
    }
    return tx, nil
}

// CalculateBalance loads every transaction of ledgerID and sums them. Income
// adds its amount and expense subtracts it. Any other type (e.g. transfer)
// is ignored in this simplified version. Repository errors are returned
// with a zero balance.
func (s *TransactionService) CalculateBalance(ctx context.Context, ledgerID string) (int64, error) {
    txs, err := s.repo.ListByLedger(ctx, ledgerID)
    if err != nil {
        return 0, err
    }

    var total int64
    for _, tx := range txs {
        switch tx.Type {
        case "income":
            total += tx.Amount
        case "expense":
            total -= tx.Amount
        }
        // Transfers need more involved handling and are skipped here.
    }

    return total, nil
}

重构后 - 富领域模型

go
// internal/app/ledger/domain/transaction.go (重构后)
package domain

// Money is a value object: an Amount tagged with its Currency. NewMoney only
// produces non-negative amounts, but code in this package also builds Money
// literals with a negative Amount to express signed balance deltas.
type Money struct {
    Amount   int64
    Currency Currency
}

// NewMoney parses currency with ParseCurrency and returns its error unchanged
// on failure. It rejects a negative amount.
func NewMoney(amount int64, currency string) (Money, error) {
    curr, err := ParseCurrency(currency)
    switch {
    case err != nil:
        return Money{}, err
    case amount < 0:
        return Money{}, errors.New("amount cannot be negative")
    }
    return Money{Amount: amount, Currency: curr}, nil
}

// Add returns m + other. Both values must share a currency. Signed amounts
// are added as-is, so the result may be negative.
func (m Money) Add(other Money) (Money, error) {
    if m.Currency == other.Currency {
        return Money{Amount: m.Amount + other.Amount, Currency: m.Currency}, nil
    }
    return Money{}, errors.New("cannot add different currencies")
}

// Subtract returns m - other. Both values must share a currency, and the
// result must not be negative; otherwise an "insufficient funds" error is
// returned.
func (m Money) Subtract(other Money) (Money, error) {
    if m.Currency != other.Currency {
        return Money{}, errors.New("cannot subtract different currencies")
    }
    if diff := m.Amount - other.Amount; diff >= 0 {
        return Money{Amount: diff, Currency: m.Currency}, nil
    }
    return Money{}, errors.New("insufficient funds")
}

// TransactionType enumerates the kinds of ledger transaction. Its zero value
// is TransactionTypeIncome.
type TransactionType int

// Known transaction types, numbered by iota in declaration order.
const (
    TransactionTypeIncome TransactionType = iota
    TransactionTypeExpense
    TransactionTypeTransfer
)

// ParseTransactionType converts a name ("income", "expense" or "transfer")
// into a TransactionType. Any other input yields 0 and an error that quotes
// the rejected value.
func ParseTransactionType(s string) (TransactionType, error) {
    var tt TransactionType
    switch s {
    case "income":
        tt = TransactionTypeIncome
    case "expense":
        tt = TransactionTypeExpense
    case "transfer":
        tt = TransactionTypeTransfer
    default:
        return 0, fmt.Errorf("invalid transaction type: %s", s)
    }
    return tt, nil
}

// Transaction is the domain entity. Its fields are unexported, so state can
// only change through the methods below.
type Transaction struct {
    id          TransactionID
    ledgerID    LedgerID
    txType      TransactionType
    amount      Money
    description string
    createdAt   time.Time
}

// NewTransaction checks its arguments with validateTransactionCreation and,
// if they pass, returns a transaction with a new ID and the current time as
// its creation timestamp.
func NewTransaction(ledgerID LedgerID, txType TransactionType, amount Money, description string) (*Transaction, error) {
    err := validateTransactionCreation(ledgerID, txType, amount, description)
    if err != nil {
        return nil, err
    }

    tx := &Transaction{
        ledgerID:    ledgerID,
        txType:      txType,
        amount:      amount,
        description: description,
    }
    tx.id = NewTransactionID()
    tx.createdAt = time.Now()
    return tx, nil
}

// ChangeDescription replaces the description. The new value must be at most
// 500 bytes and must not be blank (whitespace only).
//
// NOTE(review): len counts bytes, not characters, so for CJK text the
// effective limit is far below 500 characters — confirm this is intended.
func (t *Transaction) ChangeDescription(newDesc string) error {
    switch {
    case len(newDesc) > 500:
        return errors.New("description too long")
    case strings.TrimSpace(newDesc) == "":
        return errors.New("description cannot be empty")
    }
    t.description = newDesc
    return nil
}

// IsIncome reports whether the transaction is of type income.
func (t *Transaction) IsIncome() bool {
    return t.txType == TransactionTypeIncome
}

// IsExpense reports whether the transaction is of type expense.
func (t *Transaction) IsExpense() bool {
    return t.txType == TransactionTypeExpense
}

// CalculateImpactOnBalance returns the signed effect of this transaction on a
// balance, in the transaction's currency:
//   - income: +amount;
//   - expense: -amount;
//   - any other type (e.g. transfer): zero.
func (t *Transaction) CalculateImpactOnBalance() Money {
    impact := Money{Currency: t.amount.Currency}
    switch t.txType {
    case TransactionTypeIncome:
        impact.Amount = t.amount.Amount
    case TransactionTypeExpense:
        impact.Amount = -t.amount.Amount
    }
    return impact
}

// validateTransactionCreation enforces the invariants of a new Transaction
// and returns the first violation found, or nil:
//   - ledgerID must be non-empty;
//   - txType must be one of the declared TransactionType constants.
//     NewTransaction is exported, so callers can bypass ParseTransactionType
//     and pass an arbitrary TransactionType value;
//   - amount.Amount must be strictly positive;
//   - description must be at most 500 bytes (len counts bytes, not runes).
func validateTransactionCreation(ledgerID LedgerID, txType TransactionType, amount Money, description string) error {
    if ledgerID == "" {
        return errors.New("ledger ID is required")
    }

    switch txType {
    case TransactionTypeIncome, TransactionTypeExpense, TransactionTypeTransfer:
        // valid
    default:
        return fmt.Errorf("invalid transaction type: %d", txType)
    }

    if amount.Amount <= 0 {
        return errors.New("amount must be positive")
    }

    if len(description) > 500 {
        return errors.New("description too long")
    }

    return nil
}

// Ledger is the aggregate root of the ledger domain. It owns its
// transactions and keeps a cached running balance alongside them.
type Ledger struct {
    id           LedgerID
    name         string
    transactions []*Transaction
    balance      Money // running total, updated by AddTransaction
}

// AddTransaction records tx in the ledger and updates the cached balance.
//
// tx must be non-nil and belong to this ledger. The new balance is computed
// before tx is appended, so a failure leaves the ledger unchanged. A currency
// mismatch reported by Money.Add is one such failure. Appending first would
// leave a transaction in the ledger that the cached balance does not reflect.
//
// Money.Add accepts signed amounts, so an expense may drive the balance below
// zero. Its only error is a currency mismatch, which is returned as-is.
func (l *Ledger) AddTransaction(tx *Transaction) error {
    if tx == nil {
        return errors.New("transaction is nil")
    }
    if tx.ledgerID != l.id {
        return errors.New("transaction belongs to different ledger")
    }

    newBalance, err := l.balance.Add(tx.CalculateImpactOnBalance())
    if err != nil {
        return err
    }

    l.transactions = append(l.transactions, tx)
    l.balance = newBalance
    return nil
}

// CalculateBalance recomputes the balance from scratch by summing the signed
// impact of every recorded transaction.
//
// Balances may be negative: a lone 1000 expense yields -1000. Impacts are
// therefore accumulated with Add, which accepts signed amounts. Subtract is
// not suitable because it rejects negative results; with its error discarded,
// the running total would silently reset to zero.
//
// The result takes its currency from the first transaction. An empty ledger
// yields the zero Money value.
func (l *Ledger) CalculateBalance() Money {
    var balance Money

    for i, tx := range l.transactions {
        impact := tx.CalculateImpactOnBalance()
        if i == 0 {
            balance.Currency = impact.Currency
        }

        next, err := balance.Add(impact)
        if err != nil {
            // Currency mismatch: the entry cannot be combined with the
            // running total. Skip it rather than discarding the total.
            continue
        }
        balance = next
    }

    return balance
}

重构后的应用服务

go
// internal/app/ledger/application/transaction_service.go

// TransactionApplicationService orchestrates the transaction use cases. It
// turns commands into domain objects and lets the Ledger aggregate enforce
// business rules. It then persists the results and publishes domain events.
type TransactionApplicationService struct {
    transactionRepo domain.TransactionRepository
    ledgerRepo      domain.LedgerRepository
    eventBus        EventBus
}

// CreateTransaction executes the "create transaction" use case:
//  1. build value objects from the command;
//  2. create the domain Transaction;
//  3. load the ledger and let it apply its business rules;
//  4. persist the transaction and the updated ledger;
//  5. publish a TransactionCreated event.
//
// Every failure is wrapped with %w and a short description of the step.
func (s *TransactionApplicationService) CreateTransaction(ctx context.Context, cmd CreateTransactionCommand) (*TransactionDTO, error) {
    ledgerID := domain.LedgerID(cmd.LedgerID)

    // Step 1: value objects.
    money, err := domain.NewMoney(cmd.Amount, cmd.Currency)
    if err != nil {
        return nil, fmt.Errorf("invalid amount: %w", err)
    }
    kind, err := domain.ParseTransactionType(cmd.Type)
    if err != nil {
        return nil, fmt.Errorf("invalid transaction type: %w", err)
    }

    // Step 2: domain object.
    tx, err := domain.NewTransaction(ledgerID, kind, money, cmd.Description)
    if err != nil {
        return nil, fmt.Errorf("create transaction failed: %w", err)
    }

    // Step 3: aggregate-level business rules.
    ledger, err := s.ledgerRepo.GetByID(ctx, ledgerID)
    if err != nil {
        return nil, fmt.Errorf("ledger not found: %w", err)
    }
    if err = ledger.AddTransaction(tx); err != nil {
        return nil, fmt.Errorf("add transaction to ledger failed: %w", err)
    }

    // Step 4: persistence.
    // NOTE(review): the two writes are not visibly wrapped in one unit of
    // work. If the ledger update fails, the transaction is already saved —
    // confirm whether the repositories share a database transaction.
    if err = s.transactionRepo.Create(ctx, tx); err != nil {
        return nil, fmt.Errorf("save transaction failed: %w", err)
    }
    if err = s.ledgerRepo.Update(ctx, ledger); err != nil {
        return nil, fmt.Errorf("update ledger failed: %w", err)
    }

    // Step 5: domain event. Any result of Publish is not inspected here.
    s.eventBus.Publish(NewTransactionCreatedEvent(tx))

    return s.toDTO(tx), nil
}

重构测试

go
// TestTransaction_NewTransaction drives the full construction path from raw
// inputs: NewMoney, then ParseTransactionType, then NewTransaction. An
// expected error may come from NewMoney (e.g. a negative amount); in that
// case the sub-test stops there.
func TestTransaction_NewTransaction(t *testing.T) {
    tests := []struct {
        name        string
        ledgerID    string
        txType      string
        amount      int64
        currency    string
        description string
        wantErr     bool
        errContains string
    }{
        {
            name:        "valid transaction",
            ledgerID:    "ledger-123",
            txType:      "expense",
            amount:      1000,
            currency:    "CNY",
            description: "Lunch",
            wantErr:     false,
        },
        {
            name:        "invalid amount",
            ledgerID:    "ledger-123",
            txType:      "expense",
            amount:      -1000,
            currency:    "CNY",
            description: "Invalid",
            wantErr:     true,
            errContains: "amount cannot be negative",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            amount, err := domain.NewMoney(tt.amount, tt.currency)
            if tt.wantErr && err != nil {
                assert.Contains(t, err.Error(), tt.errContains)
                return
            }
            require.NoError(t, err)

            txType, err := domain.ParseTransactionType(tt.txType)
            require.NoError(t, err)

            tx, err := domain.NewTransaction(
                domain.LedgerID(tt.ledgerID),
                txType,
                amount,
                tt.description,
            )

            if tt.wantErr {
                // require, not assert: with a nil err the err.Error() call
                // below would panic instead of reporting a failure.
                require.Error(t, err)
                assert.Contains(t, err.Error(), tt.errContains)
                return
            }
            require.NoError(t, err)
            assert.NotNil(t, tx)
        })
    }
}

// TestLedger_AddTransaction adds a single 1000 CNY expense to a fresh ledger.
// It expects the recomputed balance to be -1000, i.e. ledgers may go negative.
func TestLedger_AddTransaction(t *testing.T) {
    ledger := domain.NewLedger("test-ledger", "Test Ledger")

    // Setup errors are checked so that a failure reports its cause instead
    // of surfacing later as a nil-pointer panic on tx.
    amount, err := domain.NewMoney(1000, "CNY")
    require.NoError(t, err)

    tx, err := domain.NewTransaction(
        domain.LedgerID("test-ledger"),
        domain.TransactionTypeExpense,
        amount,
        "Test expense",
    )
    require.NoError(t, err)

    require.NoError(t, ledger.AddTransaction(tx))

    balance := ledger.CalculateBalance()
    assert.Equal(t, int64(-1000), balance.Amount)
}

🛠️ 重构技巧

1. 提取方法 (Extract Method)

go
// Before: one long method that validates, transforms and persists in place.
//
// ProcessComplexLogic runs three steps:
//   - validate data;
//   - double Value, capped at 1000;
//   - add Offset, floored at 0.
// It then saves the result under data.ID and returns the repository's error.
func (s *Service) ProcessComplexLogic(data *Data) error {
    // Input validation (cases are checked in order, so data is never
    // dereferenced when nil).
    switch {
    case data == nil:
        return errors.New("data cannot be nil")
    case data.ID == "":
        return errors.New("ID is required")
    case data.Value < 0:
        return errors.New("value must be positive")
    }

    // Step A: double the value, capped at 1000.
    a := data.Value * 2
    if a > 1000 {
        a = 1000
    }

    // Step B: apply the offset, floored at 0.
    b := a + data.Offset
    if b < 0 {
        b = 0
    }

    // Persist the outcome.
    return s.repo.Save(&Result{ID: data.ID, Value: b})
}

// After: the same pipeline expressed as calls to small, single-purpose
// helpers.
//
// ProcessComplexLogic validates data, transforms its value, builds the result
// and saves it. It returns the first error encountered.
func (s *Service) ProcessComplexLogic(data *Data) error {
    err := s.validateInput(data)
    if err != nil {
        return err
    }

    value := s.applyProcessingLogic(data)
    return s.repo.Save(s.buildResult(data.ID, value))
}

// validateInput rejects a nil data, an empty ID or a negative Value. It
// returns the first violation found, or nil. A zero Value is accepted.
func (s *Service) validateInput(data *Data) error {
    switch {
    case data == nil:
        return errors.New("data cannot be nil")
    case data.ID == "":
        return errors.New("ID is required")
    case data.Value < 0:
        return errors.New("value must be positive")
    default:
        return nil
    }
}

// applyProcessingLogic chains the two processing steps. It feeds data.Value
// through step A, then combines that result with data.Offset in step B.
func (s *Service) applyProcessingLogic(data *Data) int64 {
    return s.processStepB(s.processStepA(data.Value), data.Offset)
}

// processStepA doubles value and caps the result at 1000.
func (s *Service) processStepA(value int64) int64 {
    const limit = 1000
    if doubled := value * 2; doubled <= limit {
        return doubled
    }
    return limit
}

// processStepB adds offset to valueA and floors the sum at 0.
func (s *Service) processStepB(valueA, offset int64) int64 {
    if sum := valueA + offset; sum >= 0 {
        return sum
    }
    return 0
}

// buildResult returns a new Result carrying id and value; all other fields
// keep their zero values.
func (s *Service) buildResult(id string, value int64) *Result {
    r := new(Result)
    r.ID, r.Value = id, value
    return r
}

2. 替换条件表达式为多态 (Replace Conditional with Polymorphism)

go
// Before: fee rules selected by checking the transaction type inline.
//
// CalculateFee returns the fee for transaction based on its Type:
//   - credit_card: 3% of the amount, truncated toward zero;
//   - bank_transfer: 50 above 10000, otherwise 10;
//   - cash: free.
// Any other type is an error.
func CalculateFee(transaction *Transaction) (int64, error) {
    kind := transaction.Type

    if kind == "credit_card" {
        return int64(float64(transaction.Amount) * 0.03), nil
    }
    if kind == "bank_transfer" {
        fee := int64(10)
        if transaction.Amount > 10000 {
            fee = 50
        }
        return fee, nil
    }
    if kind == "cash" {
        return 0, nil
    }
    return 0, errors.New("unknown transaction type")
}

// After: polymorphism. Each payment method owns its fee rule behind the
// FeeCalculator interface, and a factory picks the implementation by name.

// FeeCalculator computes the fee charged for a transaction of the given
// amount.
type FeeCalculator interface {
    CalculateFee(amount int64) (int64, error)
}

// CreditCardFeeCalculator charges 3% of the amount. The float64-to-int64
// conversion truncates the fee toward zero.
type CreditCardFeeCalculator struct{}

// CalculateFee implements FeeCalculator.
func (c *CreditCardFeeCalculator) CalculateFee(amount int64) (int64, error) {
    const rate = 0.03
    fee := float64(amount) * rate
    return int64(fee), nil
}

// BankTransferFeeCalculator charges a flat 10, or 50 when the amount exceeds
// 10000.
type BankTransferFeeCalculator struct{}

// CalculateFee implements FeeCalculator.
func (b *BankTransferFeeCalculator) CalculateFee(amount int64) (int64, error) {
    fee := int64(10)
    if amount > 10000 {
        fee = 50
    }
    return fee, nil
}

// CashFeeCalculator charges nothing.
type CashFeeCalculator struct{}

// CalculateFee implements FeeCalculator; cash is always free.
func (c *CashFeeCalculator) CalculateFee(amount int64) (int64, error) {
    return 0, nil
}

// NewFeeCalculator is a factory that returns the calculator for
// transactionType ("credit_card", "bank_transfer" or "cash"). It returns an
// error for any other name.
func NewFeeCalculator(transactionType string) (FeeCalculator, error) {
    var calc FeeCalculator
    switch transactionType {
    case "credit_card":
        calc = &CreditCardFeeCalculator{}
    case "bank_transfer":
        calc = &BankTransferFeeCalculator{}
    case "cash":
        calc = &CashFeeCalculator{}
    default:
        return nil, errors.New("unknown transaction type")
    }
    return calc, nil
}

// Usage: callers go through the factory instead of switching on the type.
//
// CalculateFee obtains the calculator for transaction.Type and applies it to
// transaction.Amount. An unknown type yields the factory's error and a zero
// fee.
func CalculateFee(transaction *Transaction) (int64, error) {
    calc, err := NewFeeCalculator(transaction.Type)
    if err == nil {
        return calc.CalculateFee(transaction.Amount)
    }
    return 0, err
}

3. 移除重复代码 (Remove Duplication)

go
// Before: the same validation is copied into CreateUser and UpdateUser.
//
// CreateUser validates the request, then builds a User with a generated ID
// and stores it. Validation and repository errors are returned unchanged.
func (s *Service) CreateUser(req *CreateUserRequest) (*User, error) {
    switch {
    case req.Email == "":
        return nil, errors.New("email is required")
    case !isValidEmail(req.Email):
        return nil, errors.New("invalid email format")
    case req.Name == "":
        return nil, errors.New("name is required")
    }

    user := &User{ID: generateID(), Email: req.Email, Name: req.Name}
    if err := s.repo.Create(user); err != nil {
        return nil, err
    }
    return user, nil
}

// UpdateUser repeats CreateUser's validation, then loads the user by id,
// overwrites Email and Name, and saves it. Validation and repository errors
// are returned unchanged.
func (s *Service) UpdateUser(id string, req *UpdateUserRequest) (*User, error) {
    switch {
    case req.Email == "":
        return nil, errors.New("email is required")
    case !isValidEmail(req.Email):
        return nil, errors.New("invalid email format")
    case req.Name == "":
        return nil, errors.New("name is required")
    }

    user, err := s.repo.GetByID(id)
    if err != nil {
        return nil, err
    }

    user.Email, user.Name = req.Email, req.Name

    if err = s.repo.Update(user); err != nil {
        return nil, err
    }
    return user, nil
}

// After: the shared validation is extracted into one place.

// UserValidator holds the user-data rules shared by CreateUser and
// UpdateUser.
type UserValidator struct{}

// ValidateUserData requires a non-empty, well-formed email and a non-empty
// name. Rules are checked in that order and the first failure is returned.
func (v *UserValidator) ValidateUserData(email, name string) error {
    switch {
    case email == "":
        return errors.New("email is required")
    case !isValidEmail(email):
        return errors.New("invalid email format")
    case name == "":
        return errors.New("name is required")
    default:
        return nil
    }
}

// CreateUser validates the request with the shared validator, then builds a
// User with a generated ID and stores it. Errors are returned unchanged.
func (s *Service) CreateUser(req *CreateUserRequest) (*User, error) {
    err := s.validator.ValidateUserData(req.Email, req.Name)
    if err != nil {
        return nil, err
    }

    user := &User{ID: generateID(), Email: req.Email, Name: req.Name}
    if err = s.repo.Create(user); err != nil {
        return nil, err
    }
    return user, nil
}

// UpdateUser validates the request with the shared validator, then loads the
// user by id, overwrites Email and Name, and saves it. Errors are returned
// unchanged.
func (s *Service) UpdateUser(id string, req *UpdateUserRequest) (*User, error) {
    if err := s.validator.ValidateUserData(req.Email, req.Name); err != nil {
        return nil, err
    }

    user, err := s.repo.GetByID(id)
    if err != nil {
        return nil, err
    }

    user.Email, user.Name = req.Email, req.Name

    if err = s.repo.Update(user); err != nil {
        return nil, err
    }
    return user, nil
}

📚 相关资源

项目实践文档

外部参考


💡 重构建议：重构是一个持续的过程，不应该等到代码完全无法维护时才开始。小步快走、频繁重构，保持代码的整洁和可维护性。

基于 MIT 许可证发布