🔄 重构案例实践
重构是改善代码内部结构而不改变外部行为的过程。本文档展示 lzt 项目中的实际重构案例和最佳实践。
🎯 重构原则
重构时机
重构步骤
- 确保测试覆盖 - 重构前必须有充分的测试
- 小步快走 - 每次只做一个小改动
- 频繁测试 - 每次改动后立即运行测试
- 及时提交 - 每个稳定状态都要提交
📝 实际重构案例
案例1: Bubble 进度条 View 函数重构
重构前 - 违反单一职责原则
go
// pkg/bubble/progress.go (重构前)
func (m ProgressModel) View() string {
var out strings.Builder
// 构建标题
out.WriteString(titleStyle.Render("Processing Tasks"))
out.WriteString("\n\n")
// 构建任务列表
for i, task := range m.tasks {
var status string
var icon string
// 状态判断逻辑
if task.IsCompleted() {
status = "completed"
icon = "✅"
} else if task.IsProcessing() {
status = "processing"
icon = "🔄"
} else if task.IsFailed() {
status = "failed"
icon = "❌"
} else {
status = "pending"
icon = "⏳"
}
// 任务渲染
taskText := fmt.Sprintf("%s %s", icon, task.Name())
if i == m.currentTaskIndex && task.IsProcessing() {
taskText = processingStyle.Render(taskText)
} else if task.IsCompleted() {
taskText = completedStyle.Render(taskText)
} else if task.IsFailed() {
taskText = errorStyle.Render(taskText)
} else {
taskText = pendingStyle.Render(taskText)
}
out.WriteString(taskText)
out.WriteString("\n")
// 进度条渲染
if task.IsProcessing() && task.HasProgress() {
progress := task.Progress()
progressBar := renderProgressBar(progress.Current, progress.Total)
out.WriteString(" " + progressBar)
out.WriteString(fmt.Sprintf(" %d/%d", progress.Current, progress.Total))
out.WriteString("\n")
}
}
// 总体进度
completed := 0
for _, task := range m.tasks {
if task.IsCompleted() {
completed++
}
}
totalProgress := renderProgressBar(completed, len(m.tasks))
out.WriteString("\n")
out.WriteString("Overall Progress: ")
out.WriteString(totalProgress)
out.WriteString(fmt.Sprintf(" %d/%d", completed, len(m.tasks)))
return out.String()
}重构后 - 职责分离
go
// pkg/bubble/progress.go (after refactoring)

// View delegates all rendering to the model's ProgressRenderer.
//
// A model built without a renderer (for example the zero ProgressModel) gets a
// default one. The pre-refactor View worked for any model, and a refactor must
// not turn that into a nil-pointer panic.
func (m ProgressModel) View() string {
	renderer := m.renderer
	if renderer == nil {
		renderer = NewProgressRenderer()
	}
	return renderer.Render(m)
}

// ProgressRenderer composes the three sub-renderers that together produce the
// progress screen: title, task list and overall statistics.
type ProgressRenderer struct {
	titleRenderer *TitleRenderer
	taskRenderer  *TaskRenderer
	statsRenderer *StatsRenderer
}

// NewProgressRenderer returns a ProgressRenderer wired with the default
// sub-renderers.
func NewProgressRenderer() *ProgressRenderer {
	return &ProgressRenderer{
		titleRenderer: NewTitleRenderer(),
		taskRenderer:  NewTaskRenderer(),
		statsRenderer: NewStatsRenderer(),
	}
}

// Render produces the full progress view for model:
//   - the title followed by a blank line;
//   - the task list;
//   - the overall progress summary.
func (r *ProgressRenderer) Render(model ProgressModel) string {
	var out strings.Builder
	// Title, followed by a blank line.
	out.WriteString(r.titleRenderer.Render("Processing Tasks"))
	out.WriteString("\n\n")
	// Task list.
	out.WriteString(r.taskRenderer.RenderTasks(model.tasks, model.currentTaskIndex))
	// Overall statistics.
	out.WriteString(r.statsRenderer.RenderOverallProgress(model.tasks))
	return out.String()
}
// Task renderer.

// TaskRenderer renders individual task lines. Icons and styles are looked up
// per TaskStatus from the tables below, replacing the if/else chains the
// pre-refactor View used.
type TaskRenderer struct {
	iconMap  map[TaskStatus]string
	styleMap map[TaskStatus]lipgloss.Style
}

// NewTaskRenderer returns a TaskRenderer whose tables cover four statuses:
// completed, processing, failed and pending.
func NewTaskRenderer() *TaskRenderer {
	return &TaskRenderer{
		iconMap: map[TaskStatus]string{
			StatusCompleted:  "✅",
			StatusProcessing: "🔄",
			StatusFailed:     "❌",
			StatusPending:    "⏳",
		},
		// The value type must be lipgloss.Style to match the field declaration
		// above. A map[TaskStatus]string literal does not compile here.
		styleMap: map[TaskStatus]lipgloss.Style{
			StatusCompleted:  completedStyle,
			StatusProcessing: processingStyle,
			StatusFailed:     errorStyle,
			StatusPending:    pendingStyle,
		},
	}
}
// RenderTasks renders every task on its own line, in order. currentIndex is
// the index of the task the model is working on. A task that is processing
// and reports progress gets one extra line holding its progress bar.
func (r *TaskRenderer) RenderTasks(tasks []Task, currentIndex int) string {
	var sb strings.Builder
	for idx, t := range tasks {
		sb.WriteString(r.renderSingleTask(t, idx == currentIndex) + "\n")
		if !t.IsProcessing() || !t.HasProgress() {
			continue
		}
		sb.WriteString(r.renderTaskProgress(t) + "\n")
	}
	return sb.String()
}
// renderSingleTask renders one task as "<icon> <name>", styled by its status.
//
// The processing style is reserved for the current task. A processing task
// that is not current gets the pending style, which is what the pre-refactor
// View produced: its if/else chain only styled the current task as
// processing. Without this the refactor would silently change the output.
// A status missing from the tables yields an empty icon and the zero-value
// lipgloss.Style.
func (r *TaskRenderer) renderSingleTask(task Task, isCurrent bool) string {
	status := task.Status()
	taskText := fmt.Sprintf("%s %s", r.iconMap[status], task.Name())
	if task.IsProcessing() {
		if isCurrent {
			return processingStyle.Render(taskText)
		}
		return pendingStyle.Render(taskText)
	}
	style := r.styleMap[status]
	return style.Render(taskText)
}
// renderTaskProgress renders the indented progress line for a task in the
// form " <bar> current/total".
func (r *TaskRenderer) renderTaskProgress(task Task) string {
	p := task.Progress()
	return fmt.Sprintf(" %s %d/%d", renderProgressBar(p.Current, p.Total), p.Current, p.Total)
}
// Statistics renderer.

// StatsRenderer renders aggregate statistics across all tasks. It carries no
// state.
type StatsRenderer struct{}

// NewStatsRenderer returns a ready-to-use StatsRenderer.
func NewStatsRenderer() *StatsRenderer {
	return new(StatsRenderer)
}

// RenderOverallProgress renders a leading newline followed by
// "Overall Progress: <bar> completed/total".
func (r *StatsRenderer) RenderOverallProgress(tasks []Task) string {
	done, total := r.countCompleted(tasks), len(tasks)
	bar := renderProgressBar(done, total)
	return "\nOverall Progress: " + bar + fmt.Sprintf(" %d/%d", done, total)
}
func (r *StatsRenderer) countCompleted(tasks []Task) int {
count := 0
for _, task := range tasks {
if task.IsCompleted() {
count++
}
}
return count
}重构测试
go
// pkg/bubble/progress_test.go

// TestProgressRenderer_Render renders a three-task model through View and
// checks that the title, every task line and the overall summary appear.
func TestProgressRenderer_Render(t *testing.T) {
	r := NewProgressRenderer()
	model := ProgressModel{
		tasks: []Task{
			NewMockTask("Task 1", StatusCompleted),
			NewMockTask("Task 2", StatusProcessing),
			NewMockTask("Task 3", StatusPending),
		},
		currentTaskIndex: 1,
		renderer:         r,
	}

	view := model.View()

	for _, want := range []string{
		"Processing Tasks",
		"✅ Task 1",
		"🔄 Task 2",
		"⏳ Task 3",
		"Overall Progress:",
	} {
		assert.Contains(t, view, want)
	}
}
func TestTaskRenderer_RenderSingleTask(t *testing.T) {
renderer := NewTaskRenderer()
tests := []struct {
name string
task Task
isCurrent bool
expected string
}{
{
name: "completed task",
task: NewMockTask("Test Task", StatusCompleted),
isCurrent: false,
expected: "✅ Test Task",
},
{
name: "processing current task",
task: NewMockTask("Test Task", StatusProcessing),
isCurrent: true,
expected: "🔄 Test Task", // 会被 processingStyle 包装
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := renderer.renderSingleTask(tt.task, tt.isCurrent)
assert.Contains(t, result, tt.expected)
})
}
}案例2: Ledger 服务重构 - 从贫血模型到富领域模型
重构前 - 贫血模型
go
// internal/app/ledger/service.go (before refactoring)

type (
	// TransactionService holds every transaction business rule. The
	// Transaction type below is only a data carrier (an "anemic" model).
	TransactionService struct {
		repo TransactionRepository
	}

	// Transaction is a passive record. Its type is a free-form string and the
	// amount a bare int64, so nothing stops invalid values from being stored.
	Transaction struct {
		ID          string    `json:"id"`
		LedgerID    string    `json:"ledger_id"`
		Type        string    `json:"type"`
		Amount      int64     `json:"amount"`
		Currency    string    `json:"currency"`
		Description string    `json:"description"`
		CreatedAt   time.Time `json:"created_at"`
	}
)
// CreateTransaction validates req, builds a Transaction with a generated ID and
// the current time, and stores it through the repository. The validation rules
// live here in the service layer, which is the smell this case removes.
func (s *TransactionService) CreateTransaction(ctx context.Context, req *CreateTransactionRequest) (*Transaction, error) {
	switch {
	case req.LedgerID == "":
		return nil, errors.New("ledger_id is required")
	case req.Amount <= 0:
		return nil, errors.New("amount must be positive")
	case req.Currency == "":
		return nil, errors.New("currency is required")
	case req.Type != "income" && req.Type != "expense" && req.Type != "transfer":
		return nil, errors.New("invalid transaction type")
	}

	created := &Transaction{
		ID:          generateID(),
		LedgerID:    req.LedgerID,
		Type:        req.Type,
		Amount:      req.Amount,
		Currency:    req.Currency,
		Description: req.Description,
		CreatedAt:   time.Now(),
	}
	err := s.repo.Create(ctx, created)
	if err != nil {
		return nil, err
	}
	return created, nil
}
func (s *TransactionService) CalculateBalance(ctx context.Context, ledgerID string) (int64, error) {
transactions, err := s.repo.ListByLedger(ctx, ledgerID)
if err != nil {
return 0, err
}
var balance int64
for _, tx := range transactions {
if tx.Type == "income" {
balance += tx.Amount
} else if tx.Type == "expense" {
balance -= tx.Amount
}
// transfer 逻辑更复杂,此处简化
}
return balance, nil
}重构后 - 富领域模型
go
// internal/app/ledger/domain/transaction.go (after refactoring)
package domain

// Value object.

// Money pairs an amount with its Currency. NewMoney only produces
// non-negative amounts; other code in this package may build signed values
// directly.
type Money struct {
	Amount   int64
	Currency Currency
}

// NewMoney parses currency and builds a Money value. The currency is checked
// before the amount, so an invalid currency is reported even when the amount
// is also negative.
func NewMoney(amount int64, currency string) (Money, error) {
	parsed, err := ParseCurrency(currency)
	switch {
	case err != nil:
		return Money{}, err
	case amount < 0:
		return Money{}, errors.New("amount cannot be negative")
	default:
		return Money{Amount: amount, Currency: parsed}, nil
	}
}
// Add returns m + other. Both values must share a currency. Go's signed
// integer arithmetic wraps silently, so the sum is checked for int64 overflow
// instead of returning a corrupted amount.
func (m Money) Add(other Money) (Money, error) {
	if m.Currency != other.Currency {
		return Money{}, errors.New("cannot add different currencies")
	}
	sum := m.Amount + other.Amount
	// Overflow happened iff adding a positive value made the result smaller,
	// or adding a negative value made it larger.
	if (other.Amount > 0 && sum < m.Amount) || (other.Amount < 0 && sum > m.Amount) {
		return Money{}, errors.New("amount overflow")
	}
	return Money{Amount: sum, Currency: m.Currency}, nil
}

// Subtract returns m - other. Both values must share a currency and the result
// must not be negative. Overflow is detected before the sign check; otherwise a
// wrapped (e.g. very negative minus positive) result could come back as a bogus
// positive amount.
func (m Money) Subtract(other Money) (Money, error) {
	if m.Currency != other.Currency {
		return Money{}, errors.New("cannot subtract different currencies")
	}
	result := m.Amount - other.Amount
	// Overflow happened iff subtracting a positive value made the result
	// larger, or subtracting a negative value made it smaller.
	if (other.Amount > 0 && result > m.Amount) || (other.Amount < 0 && result < m.Amount) {
		return Money{}, errors.New("amount overflow")
	}
	if result < 0 {
		return Money{}, errors.New("insufficient funds")
	}
	return Money{Amount: result, Currency: m.Currency}, nil
}
// Enum type.

// TransactionType enumerates the kinds of ledger transaction. Its zero value
// is TransactionTypeIncome.
type TransactionType int

const (
	TransactionTypeIncome TransactionType = iota
	TransactionTypeExpense
	TransactionTypeTransfer
)

// transactionTypesByName maps the string spelling of each TransactionType to
// its value.
var transactionTypesByName = map[string]TransactionType{
	"income":   TransactionTypeIncome,
	"expense":  TransactionTypeExpense,
	"transfer": TransactionTypeTransfer,
}

// ParseTransactionType converts "income", "expense" or "transfer" into the
// matching TransactionType. Matching is exact and case-sensitive. Any other
// input yields 0 and an error naming the rejected value.
func ParseTransactionType(s string) (TransactionType, error) {
	if tt, ok := transactionTypesByName[s]; ok {
		return tt, nil
	}
	return 0, fmt.Errorf("invalid transaction type: %s", s)
}
// Entity.

// Transaction is the ledger entry entity. Its fields are unexported, so code
// outside this package cannot set them directly and must go through
// NewTransaction and the methods below.
type Transaction struct {
	id          TransactionID
	ledgerID    LedgerID
	txType      TransactionType
	amount      Money
	description string
	createdAt   time.Time
}

// NewTransaction validates its arguments with validateTransactionCreation and,
// if they pass, returns a Transaction with a fresh ID and createdAt set to
// time.Now().
func NewTransaction(ledgerID LedgerID, txType TransactionType, amount Money, description string) (*Transaction, error) {
	err := validateTransactionCreation(ledgerID, txType, amount, description)
	if err != nil {
		return nil, err
	}
	tx := new(Transaction)
	tx.id = NewTransactionID()
	tx.ledgerID = ledgerID
	tx.txType = txType
	tx.amount = amount
	tx.description = description
	tx.createdAt = time.Now()
	return tx, nil
}
// Domain methods.

// ChangeDescription replaces the description after validating it.
//   - The length check runs first; len counts bytes, so the limit is 500
//     bytes, not 500 characters.
//   - The description must then contain something other than whitespace.
//   - On error t is left unchanged.
//
// NOTE(review): NewTransaction accepts an empty description while this method
// rejects one; confirm that asymmetry is intended.
func (t *Transaction) ChangeDescription(newDesc string) error {
	switch {
	case len(newDesc) > 500:
		return errors.New("description too long")
	case strings.TrimSpace(newDesc) == "":
		return errors.New("description cannot be empty")
	}
	t.description = newDesc
	return nil
}
// IsIncome reports whether t is an income transaction.
func (t *Transaction) IsIncome() bool { return t.txType == TransactionTypeIncome }

// IsExpense reports whether t is an expense transaction.
func (t *Transaction) IsExpense() bool { return t.txType == TransactionTypeExpense }

// CalculateImpactOnBalance returns the signed effect of t on its ledger's
// balance, always in t's currency:
//   - +amount for income;
//   - -amount for expense;
//   - zero for transfers and any other type.
//
// The result may be negative, so it is built directly rather than via
// NewMoney.
func (t *Transaction) CalculateImpactOnBalance() Money {
	impact := Money{Currency: t.amount.Currency}
	if t.IsIncome() {
		impact.Amount = t.amount.Amount
	} else if t.IsExpense() {
		impact.Amount = -t.amount.Amount
	}
	return impact
}
// Domain validation.

// validateTransactionCreation enforces the invariants every new Transaction
// must satisfy:
//   - a ledger ID;
//   - a strictly positive amount;
//   - a description of at most 500 bytes (len counts bytes);
//   - one of the three defined TransactionType values.
//
// The type check runs last, so inputs that were already invalid for another
// reason keep reporting the same error as before.
func validateTransactionCreation(ledgerID LedgerID, txType TransactionType, amount Money, description string) error {
	if ledgerID == "" {
		return errors.New("ledger ID is required")
	}
	if amount.Amount <= 0 {
		return errors.New("amount must be positive")
	}
	if len(description) > 500 {
		return errors.New("description too long")
	}
	// TransactionType is a plain int, so any integer can be converted into it.
	// Reject values outside the enum; ParseTransactionType never produces them.
	switch txType {
	case TransactionTypeIncome, TransactionTypeExpense, TransactionTypeTransfer:
	default:
		return errors.New("invalid transaction type")
	}
	return nil
}
// 聚合根
type Ledger struct {
id LedgerID
name string
transactions []*Transaction
balance Money
}
func (l *Ledger) AddTransaction(tx *Transaction) error {
if tx.ledgerID != l.id {
return errors.New("transaction belongs to different ledger")
}
l.transactions = append(l.transactions, tx)
// 更新余额
impact := tx.CalculateImpactOnBalance()
newBalance, err := l.balance.Add(impact)
if err != nil {
// 如果是支出导致余额不足,可能需要特殊处理
if tx.IsExpense() && strings.Contains(err.Error(), "insufficient") {
return errors.New("insufficient balance for this expense")
}
return err
}
l.balance = newBalance
return nil
}
func (l *Ledger) CalculateBalance() Money {
var balance Money
for _, tx := range l.transactions {
impact := tx.CalculateImpactOnBalance()
if impact.Amount > 0 {
balance, _ = balance.Add(impact)
} else {
balance, _ = balance.Subtract(Money{Amount: -impact.Amount, Currency: impact.Currency})
}
}
return balance
}重构后的应用服务
go
// internal/app/ledger/application/transaction_service.go
type TransactionApplicationService struct {
transactionRepo domain.TransactionRepository
ledgerRepo domain.LedgerRepository
eventBus EventBus
}
func (s *TransactionApplicationService) CreateTransaction(ctx context.Context, cmd CreateTransactionCommand) (*TransactionDTO, error) {
// 1. 构建值对象
amount, err := domain.NewMoney(cmd.Amount, cmd.Currency)
if err != nil {
return nil, fmt.Errorf("invalid amount: %w", err)
}
txType, err := domain.ParseTransactionType(cmd.Type)
if err != nil {
return nil, fmt.Errorf("invalid transaction type: %w", err)
}
// 2. 创建领域对象
tx, err := domain.NewTransaction(
domain.LedgerID(cmd.LedgerID),
txType,
amount,
cmd.Description,
)
if err != nil {
return nil, fmt.Errorf("create transaction failed: %w", err)
}
// 3. 检查业务规则(如果需要)
ledger, err := s.ledgerRepo.GetByID(ctx, domain.LedgerID(cmd.LedgerID))
if err != nil {
return nil, fmt.Errorf("ledger not found: %w", err)
}
if err := ledger.AddTransaction(tx); err != nil {
return nil, fmt.Errorf("add transaction to ledger failed: %w", err)
}
// 4. 持久化
if err := s.transactionRepo.Create(ctx, tx); err != nil {
return nil, fmt.Errorf("save transaction failed: %w", err)
}
if err := s.ledgerRepo.Update(ctx, ledger); err != nil {
return nil, fmt.Errorf("update ledger failed: %w", err)
}
// 5. 发布事件
event := NewTransactionCreatedEvent(tx)
s.eventBus.Publish(event)
return s.toDTO(tx), nil
}重构测试
go
// TestTransaction_NewTransaction drives NewMoney, ParseTransactionType and
// NewTransaction with table cases. Errors may come from NewMoney (checked
// first) or from NewTransaction.
func TestTransaction_NewTransaction(t *testing.T) {
	tests := []struct {
		name        string
		ledgerID    string
		txType      string
		amount      int64
		currency    string
		description string
		wantErr     bool
		errContains string
	}{
		{
			name:        "valid transaction",
			ledgerID:    "ledger-123",
			txType:      "expense",
			amount:      1000,
			currency:    "CNY",
			description: "Lunch",
			wantErr:     false,
		},
		{
			name:        "invalid amount",
			ledgerID:    "ledger-123",
			txType:      "expense",
			amount:      -1000,
			currency:    "CNY",
			description: "Invalid",
			wantErr:     true,
			errContains: "amount cannot be negative",
		},
		{
			// NewMoney accepts zero, so this is the case that reaches the
			// NewTransaction error branch below; no other case exercises it.
			name:        "zero amount",
			ledgerID:    "ledger-123",
			txType:      "income",
			amount:      0,
			currency:    "CNY",
			description: "Zero",
			wantErr:     true,
			errContains: "amount must be positive",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			amount, err := domain.NewMoney(tt.amount, tt.currency)
			if tt.wantErr && err != nil {
				assert.Contains(t, err.Error(), tt.errContains)
				return
			}
			require.NoError(t, err)
			txType, err := domain.ParseTransactionType(tt.txType)
			require.NoError(t, err)
			tx, err := domain.NewTransaction(
				domain.LedgerID(tt.ledgerID),
				txType,
				amount,
				tt.description,
			)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Contains(t, err.Error(), tt.errContains)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, tx)
			}
		})
	}
}
func TestLedger_AddTransaction(t *testing.T) {
ledger := domain.NewLedger("test-ledger", "Test Ledger")
amount, _ := domain.NewMoney(1000, "CNY")
tx, _ := domain.NewTransaction(
domain.LedgerID("test-ledger"),
domain.TransactionTypeExpense,
amount,
"Test expense",
)
err := ledger.AddTransaction(tx)
assert.NoError(t, err)
balance := ledger.CalculateBalance()
assert.Equal(t, int64(-1000), balance.Amount)
}🛠️ 重构技巧
1. 提取方法 (Extract Method)
go
// Before refactoring: one long method.

// ProcessComplexLogic validates data, then runs two processing steps:
//   - doubles Value, capped at 1000;
//   - adds Offset, floored at 0.
//
// It saves the outcome as a Result through the repository.
func (s *Service) ProcessComplexLogic(data *Data) error {
	// Validate input.
	switch {
	case data == nil:
		return errors.New("data cannot be nil")
	case data.ID == "":
		return errors.New("ID is required")
	case data.Value < 0:
		// NOTE(review): zero passes this check, so "non-negative" would be the
		// accurate wording; the message is kept as-is.
		return errors.New("value must be positive")
	}

	// Step A: double, capped at 1000.
	a := data.Value * 2
	if a > 1000 {
		a = 1000
	}

	// Step B: add the offset, floored at 0.
	b := a + data.Offset
	if b < 0 {
		b = 0
	}

	// Save the result.
	return s.repo.Save(&Result{ID: data.ID, Value: b})
}
// After refactoring: extracted methods.

// ProcessComplexLogic now reads as a summary of its steps: validate,
// transform, build the result, save.
func (s *Service) ProcessComplexLogic(data *Data) error {
	err := s.validateInput(data)
	if err != nil {
		return err
	}
	return s.repo.Save(s.buildResult(data.ID, s.applyProcessingLogic(data)))
}
// validateInput rejects a nil data, an empty ID and a negative Value, in that
// order, and returns nil otherwise.
// NOTE(review): zero is accepted, so "value must be positive" overstates the
// rule; the message is kept as-is.
func (s *Service) validateInput(data *Data) error {
	switch {
	case data == nil:
		return errors.New("data cannot be nil")
	case data.ID == "":
		return errors.New("ID is required")
	case data.Value < 0:
		return errors.New("value must be positive")
	default:
		return nil
	}
}
// applyProcessingLogic runs step A on data.Value and feeds the result, with
// data.Offset, into step B.
func (s *Service) applyProcessingLogic(data *Data) int64 {
	return s.processStepB(s.processStepA(data.Value), data.Offset)
}

// processStepA doubles value and caps the result at 1000.
func (s *Service) processStepA(value int64) int64 {
	if doubled := value * 2; doubled <= 1000 {
		return doubled
	}
	return 1000
}

// processStepB adds offset to valueA and floors the result at 0.
func (s *Service) processStepB(valueA, offset int64) int64 {
	if sum := valueA + offset; sum >= 0 {
		return sum
	}
	return 0
}
func (s *Service) buildResult(id string, value int64) *Result {
return &Result{
ID: id,
Value: value,
}
}2. 替换条件表达式为多态 (Replace Conditional with Polymorphism)
go
// Before refactoring: branching on a type code.

// CalculateFee picks the fee rule by inspecting transaction.Type:
//   - credit_card: 3% of the amount, truncated;
//   - bank_transfer: 50 above 10000, otherwise 10;
//   - cash: free.
//
// Any other type is an error.
func CalculateFee(transaction *Transaction) (int64, error) {
	kind := transaction.Type
	if kind == "cash" {
		return 0, nil
	}
	if kind == "bank_transfer" {
		if transaction.Amount > 10000 {
			return 50, nil
		}
		return 10, nil
	}
	if kind == "credit_card" {
		return int64(float64(transaction.Amount) * 0.03), nil
	}
	return 0, errors.New("unknown transaction type")
}
// After refactoring: polymorphism.

// FeeCalculator is the strategy interface: one implementation per payment
// method, each computing the fee for an amount.
type FeeCalculator interface {
	CalculateFee(amount int64) (int64, error)
}

// CreditCardFeeCalculator charges 3% of the amount, computed in float64 and
// truncated toward zero by the int64 conversion.
type CreditCardFeeCalculator struct{}

// CalculateFee implements FeeCalculator.
func (*CreditCardFeeCalculator) CalculateFee(amount int64) (int64, error) {
	const rate = 0.03
	return int64(float64(amount) * rate), nil
}

// BankTransferFeeCalculator charges a flat fee: 10 up to and including 10000,
// 50 above that.
type BankTransferFeeCalculator struct{}

// CalculateFee implements FeeCalculator.
func (*BankTransferFeeCalculator) CalculateFee(amount int64) (int64, error) {
	fee := int64(10)
	if amount > 10000 {
		fee = 50
	}
	return fee, nil
}

// CashFeeCalculator never charges a fee.
type CashFeeCalculator struct{}

// CalculateFee implements FeeCalculator.
func (*CashFeeCalculator) CalculateFee(int64) (int64, error) {
	return 0, nil
}

// Factory.

// feeCalculatorFactories maps each supported transaction type to a
// constructor for its calculator.
var feeCalculatorFactories = map[string]func() FeeCalculator{
	"credit_card":   func() FeeCalculator { return &CreditCardFeeCalculator{} },
	"bank_transfer": func() FeeCalculator { return &BankTransferFeeCalculator{} },
	"cash":          func() FeeCalculator { return &CashFeeCalculator{} },
}

// NewFeeCalculator returns the calculator for transactionType. For an
// unsupported type it returns a nil calculator and an error.
func NewFeeCalculator(transactionType string) (FeeCalculator, error) {
	build, found := feeCalculatorFactories[transactionType]
	if !found {
		return nil, errors.New("unknown transaction type")
	}
	return build(), nil
}
// 使用
func CalculateFee(transaction *Transaction) (int64, error) {
calculator, err := NewFeeCalculator(transaction.Type)
if err != nil {
return 0, err
}
return calculator.CalculateFee(transaction.Amount)
}3. 移除重复代码 (Remove Duplication)
go
// Before refactoring: duplicated code.

// CreateUser validates the request, builds a User with a freshly generated ID
// and stores it. The same three checks are repeated in UpdateUser.
func (s *Service) CreateUser(req *CreateUserRequest) (*User, error) {
	switch {
	case req.Email == "":
		return nil, errors.New("email is required")
	case !isValidEmail(req.Email):
		return nil, errors.New("invalid email format")
	case req.Name == "":
		return nil, errors.New("name is required")
	}

	u := &User{ID: generateID(), Email: req.Email, Name: req.Name}
	err := s.repo.Create(u)
	if err != nil {
		return nil, err
	}
	return u, nil
}
// UpdateUser validates the request, loads the user by id, overwrites Email
// and Name, and saves it. Its validation duplicates CreateUser's.
func (s *Service) UpdateUser(id string, req *UpdateUserRequest) (*User, error) {
	switch {
	case req.Email == "":
		return nil, errors.New("email is required")
	case !isValidEmail(req.Email):
		return nil, errors.New("invalid email format")
	case req.Name == "":
		return nil, errors.New("name is required")
	}

	u, err := s.repo.GetByID(id)
	if err != nil {
		return nil, err
	}
	u.Email, u.Name = req.Email, req.Name
	if err = s.repo.Update(u); err != nil {
		return nil, err
	}
	return u, nil
}
// After refactoring: extract the shared validation.

// UserValidator holds the user-data checks shared by user creation and
// update. It has no state.
type UserValidator struct{}

// ValidateUserData reports the first problem found, checking in this order:
//   - missing email;
//   - malformed email (per isValidEmail);
//   - missing name.
//
// It returns nil when both fields are acceptable.
func (v *UserValidator) ValidateUserData(email, name string) error {
	switch {
	case email == "":
		return errors.New("email is required")
	case !isValidEmail(email):
		return errors.New("invalid email format")
	case name == "":
		return errors.New("name is required")
	default:
		return nil
	}
}
// CreateUser validates the request through the shared validator, then builds
// and stores a new User with a freshly generated ID.
func (s *Service) CreateUser(req *CreateUserRequest) (*User, error) {
	err := s.validator.ValidateUserData(req.Email, req.Name)
	if err != nil {
		return nil, err
	}
	u := &User{ID: generateID(), Email: req.Email, Name: req.Name}
	if err = s.repo.Create(u); err != nil {
		return nil, err
	}
	return u, nil
}
func (s *Service) UpdateUser(id string, req *UpdateUserRequest) (*User, error) {
if err := s.validator.ValidateUserData(req.Email, req.Name); err != nil {
return nil, err
}
user, err := s.repo.GetByID(id)
if err != nil {
return nil, err
}
user.Email = req.Email
user.Name = req.Name
if err := s.repo.Update(user); err != nil {
return nil, err
}
return user, nil
}📚 相关资源
项目实践文档
外部参考
- Refactoring: Improving the Design of Existing Code - Martin Fowler
- Clean Code - Robert C. Martin
- Working Effectively with Legacy Code - Michael Feathers
💡 重构建议：重构是一个持续的过程，不应该等到代码完全无法维护时才开始。小步快走、频繁重构，保持代码的整洁和可维护性。