⚡ Go 并发编程

并发是 Go 语言的核心特性之一,lzt 项目充分利用 Go 的并发能力来实现高性能的数据处理和用户界面。

🎯 并发编程原理

Goroutine 基础

CSP 模型 (Communicating Sequential Processes)

Go 采用 CSP 并发模型,核心思想是:

  • 不要通过共享内存来通信,而是通过通信来共享内存
  • 使用 channel 进行 goroutine 之间的通信
  • 避免共享状态,减少竞态条件
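
下面用一个与 lzt 项目无关的最小示例说明这条原则:状态(一个 map)只属于一个 goroutine,其他 goroutine 通过 channel 把数据"交"给它,而不是加锁后直接读写。

go
package main

import "fmt"

// 通过通信来共享内存:counts 只被统计协程访问,外部只能通过 channel 通信
func wordCounter() (chan<- string, <-chan map[string]int) {
    words := make(chan string)
    result := make(chan map[string]int)

    go func() {
        counts := make(map[string]int) // 状态只属于这个 goroutine,无需加锁
        for w := range words {
            counts[w]++
        }
        result <- counts // 输入结束后,把结果交还给调用者
    }()

    return words, result
}

func main() {
    words, result := wordCounter()
    for _, w := range []string{"go", "csp", "go"} {
        words <- w
    }
    close(words)          // 关闭输入,通知统计协程结束
    fmt.Println(<-result) // map[csp:1 go:2]
}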

🔄 实际应用案例

案例1: Bubble 并发处理

go
// pkg/bubble/concurrent.go
package bubble

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// 并发处理选项
type Options struct {
    Workers     int           // 工作协程数量
    Timeout     time.Duration // 处理超时时间
    BufferSize  int           // 缓冲区大小
}

// 任务接口
type Task interface {
    Process() error
    ID() string
    Name() string
}

// 处理结果
type Result struct {
    TaskID string
    Error  error
    Data   interface{}
}

// 并发处理器
type ConcurrentProcessor struct {
    options Options
    workers []*Worker
    taskCh  chan Task
    resultCh chan Result
    ctx     context.Context
    cancel  context.CancelFunc
    wg      sync.WaitGroup
}

func NewConcurrentProcessor(opts Options) *ConcurrentProcessor {
    ctx, cancel := context.WithCancel(context.Background())
    
    return &ConcurrentProcessor{
        options:  opts,
        taskCh:   make(chan Task, opts.BufferSize),
        resultCh: make(chan Result, opts.BufferSize),
        ctx:      ctx,
        cancel:   cancel,
    }
}

// 启动工作协程
func (p *ConcurrentProcessor) Start() {
    for i := 0; i < p.options.Workers; i++ {
        worker := &Worker{
            id:       i,
            taskCh:   p.taskCh,
            resultCh: p.resultCh,
            ctx:      p.ctx,
        }
        
        p.workers = append(p.workers, worker)
        p.wg.Add(1)
        
        go func(w *Worker) {
            defer p.wg.Done()
            w.Run()
        }(worker)
    }
}

// 提交任务
func (p *ConcurrentProcessor) Submit(task Task) error {
    select {
    case p.taskCh <- task:
        return nil
    case <-p.ctx.Done():
        return p.ctx.Err()
    case <-time.After(p.options.Timeout):
        return fmt.Errorf("submit task timeout")
    }
}

// 获取结果
func (p *ConcurrentProcessor) Results() <-chan Result {
    return p.resultCh
}

// 停止处理器
func (p *ConcurrentProcessor) Stop() {
    close(p.taskCh)  // 关闭任务通道
    p.wg.Wait()      // 等待所有工作协程完成
    close(p.resultCh) // 关闭结果通道
    p.cancel()       // 取消上下文
}

// 工作协程
type Worker struct {
    id       int
    taskCh   <-chan Task
    resultCh chan<- Result
    ctx      context.Context
}

func (w *Worker) Run() {
    for {
        select {
        case task, ok := <-w.taskCh:
            if !ok {
                return // 任务通道已关闭
            }
            
            result := Result{
                TaskID: task.ID(),
                Error:  task.Process(),
                Data:   nil,
            }
            
            select {
            case w.resultCh <- result:
            case <-w.ctx.Done():
                return
            }
            
        case <-w.ctx.Done():
            return
        }
    }
}

// 便捷函数:并发处理字符串列表
func ProcessStrings(items []string, processFunc func(string) error) error {
    processor := NewConcurrentProcessor(Options{
        Workers:    4,
        Timeout:    30 * time.Second,
        BufferSize: len(items),
    })
    
    processor.Start()
    defer processor.Stop()
    
    // 提交任务
    for _, item := range items {
        task := &StringTask{
            item:    item,
            process: processFunc,
        }
        
        if err := processor.Submit(task); err != nil {
            return fmt.Errorf("submit task failed: %w", err)
        }
    }
    
    // 收集结果
    var errors []error
    for i := 0; i < len(items); i++ {
        result := <-processor.Results()
        if result.Error != nil {
            errors = append(errors, result.Error)
        }
    }
    
    if len(errors) > 0 {
        return fmt.Errorf("processing failed: %v", errors)
    }
    
    return nil
}

// 字符串任务实现
type StringTask struct {
    item    string
    process func(string) error
    id      string
}

func (t *StringTask) Process() error {
    return t.process(t.item)
}

func (t *StringTask) ID() string {
    if t.id == "" {
        t.id = fmt.Sprintf("string-task-%s", t.item)
    }
    return t.id
}

func (t *StringTask) Name() string {
    return fmt.Sprintf("Processing: %s", t.item)
}
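
上面的 ProcessStrings 封装了"启动-提交-收集-停止"的完整流程,调用方只需提供处理函数。下面是一个假设性的调用草图(downloadFile 仅为示意,并非项目中的真实函数):

go
// 假设性用法:并发处理一组 URL
urls := []string{
    "https://example.com/a.json",
    "https://example.com/b.json",
    "https://example.com/c.json",
}

err := bubble.ProcessStrings(urls, func(url string) error {
    // 真正的处理逻辑放在这里,例如下载、解析或校验
    return downloadFile(url)
})
if err != nil {
    log.Fatalf("并发处理失败: %v", err)
}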

案例2: Ledger 并发事务处理

go
// internal/app/ledger/concurrent_processor.go
package ledger

import (
    "context"
    "log"
    "sync"
    "time"
)

// 事务处理器
type TransactionProcessor struct {
    batchSize    int
    flushTimeout time.Duration
    repo         TransactionRepository
    
    transactions chan *Transaction
    mu           sync.RWMutex
    batch        []*Transaction
    
    ctx    context.Context
    cancel context.CancelFunc
    wg     sync.WaitGroup
}

func NewTransactionProcessor(repo TransactionRepository) *TransactionProcessor {
    ctx, cancel := context.WithCancel(context.Background())
    
    processor := &TransactionProcessor{
        batchSize:    100,
        flushTimeout: 5 * time.Second,
        repo:         repo,
        transactions: make(chan *Transaction, 1000),
        batch:        make([]*Transaction, 0, 100),
        ctx:          ctx,
        cancel:       cancel,
    }
    
    processor.start()
    return processor
}

func (p *TransactionProcessor) start() {
    p.wg.Add(1)
    go func() {
        defer p.wg.Done()
        p.processBatches()
    }()
}

func (p *TransactionProcessor) processBatches() {
    ticker := time.NewTicker(p.flushTimeout)
    defer ticker.Stop()
    
    for {
        select {
        case tx, ok := <-p.transactions:
            if !ok {
                // 事务通道已关闭,刷新剩余事务后退出
                if len(p.batch) > 0 {
                    p.flushBatch()
                }
                return
            }

            p.addToBatch(tx)
            if len(p.batch) >= p.batchSize {
                p.flushBatch()
            }
            
        case <-ticker.C:
            if len(p.batch) > 0 {
                p.flushBatch()
            }
            
        case <-p.ctx.Done():
            // 处理剩余的事务
            if len(p.batch) > 0 {
                p.flushBatch()
            }
            return
        }
    }
}

func (p *TransactionProcessor) addToBatch(tx *Transaction) {
    p.mu.Lock()
    defer p.mu.Unlock()
    
    p.batch = append(p.batch, tx)
}

func (p *TransactionProcessor) flushBatch() {
    p.mu.Lock()
    currentBatch := make([]*Transaction, len(p.batch))
    copy(currentBatch, p.batch)
    p.batch = p.batch[:0] // 清空batch但保留容量
    p.mu.Unlock()
    
    if len(currentBatch) == 0 {
        return
    }
    
    // 批量写入数据库
    if err := p.repo.BatchCreate(p.ctx, currentBatch); err != nil {
        // 错误处理:可以重试或记录日志
        for _, tx := range currentBatch {
            // 单独重试每个事务
            if retryErr := p.repo.Create(p.ctx, tx); retryErr != nil {
                // 记录失败的事务
                log.Printf("Failed to save transaction %s: %v", tx.ID, retryErr)
            }
        }
    }
}

func (p *TransactionProcessor) Submit(tx *Transaction) error {
    select {
    case p.transactions <- tx:
        return nil
    case <-p.ctx.Done():
        return p.ctx.Err()
    }
}

func (p *TransactionProcessor) Stop() {
    close(p.transactions) // 先停止接收新事务,让处理协程排空通道并完成最后一次 flush
    p.wg.Wait()
    p.cancel()
}
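
上面的代码依赖 TransactionRepository 提供批量与单条写入能力,本页没有给出它的定义;按 flushBatch 的调用方式推测,它至少需要类似下面的接口(接口与用法均为假设性草图):

go
// 按 flushBatch 的调用方式推测的最小接口(假设)
type TransactionRepository interface {
    Create(ctx context.Context, tx *Transaction) error
    BatchCreate(ctx context.Context, txs []*Transaction) error
}

// 假设性用法:repo 为满足上述接口的实现
processor := NewTransactionProcessor(repo)
defer processor.Stop()

if err := processor.Submit(&Transaction{ /* 字段略 */ }); err != nil {
    log.Printf("submit transaction failed: %v", err)
}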

案例3: 实时进度更新

go
// pkg/bubble/realtime_progress.go
package bubble

import (
    "context"
    "sync"
    "time"
)

// 进度更新事件
type ProgressEvent struct {
    TaskID   string
    Progress int
    Total    int
    Message  string
    Error    error
}

// 实时进度管理器
type RealtimeProgressManager struct {
    events   chan ProgressEvent
    watchers map[string][]chan ProgressEvent
    mu       sync.RWMutex
    
    ctx    context.Context
    cancel context.CancelFunc
    wg     sync.WaitGroup
}

func NewRealtimeProgressManager() *RealtimeProgressManager {
    ctx, cancel := context.WithCancel(context.Background())
    
    manager := &RealtimeProgressManager{
        events:   make(chan ProgressEvent, 1000),
        watchers: make(map[string][]chan ProgressEvent),
        ctx:      ctx,
        cancel:   cancel,
    }
    
    manager.start()
    return manager
}

func (m *RealtimeProgressManager) start() {
    m.wg.Add(1)
    go func() {
        defer m.wg.Done()
        m.dispatchEvents()
    }()
}

func (m *RealtimeProgressManager) dispatchEvents() {
    for {
        select {
        case event, ok := <-m.events:
            if !ok {
                return // 事件通道已关闭
            }
            m.broadcastEvent(event)

        case <-m.ctx.Done():
            return
        }
    }
}

func (m *RealtimeProgressManager) broadcastEvent(event ProgressEvent) {
    m.mu.RLock()
    watchers := m.watchers[event.TaskID]
    // 创建副本避免死锁
    channels := make([]chan ProgressEvent, len(watchers))
    copy(channels, watchers)
    m.mu.RUnlock()
    
    for _, ch := range channels {
        select {
        case ch <- event:
        default:
            // 如果通道满了,跳过这个观察者
            // 避免阻塞其他观察者
        }
    }
}

func (m *RealtimeProgressManager) Watch(taskID string) <-chan ProgressEvent {
    ch := make(chan ProgressEvent, 10)
    
    m.mu.Lock()
    m.watchers[taskID] = append(m.watchers[taskID], ch)
    m.mu.Unlock()
    
    return ch
}

func (m *RealtimeProgressManager) Unwatch(taskID string, ch <-chan ProgressEvent) {
    m.mu.Lock()
    defer m.mu.Unlock()
    
    watchers := m.watchers[taskID]
    for i, watcher := range watchers {
        if watcher == ch {
            // 移除这个观察者
            m.watchers[taskID] = append(watchers[:i], watchers[i+1:]...)
            close(watcher)
            break
        }
    }
    
    // 如果没有观察者了,删除这个任务的映射
    if len(m.watchers[taskID]) == 0 {
        delete(m.watchers, taskID)
    }
}

func (m *RealtimeProgressManager) UpdateProgress(taskID string, progress, total int, message string) {
    event := ProgressEvent{
        TaskID:   taskID,
        Progress: progress,
        Total:    total,
        Message:  message,
    }
    
    select {
    case m.events <- event:
    default:
        // 事件队列满了,可能需要记录警告
    }
}

func (m *RealtimeProgressManager) ReportError(taskID string, err error) {
    event := ProgressEvent{
        TaskID: taskID,
        Error:  err,
    }
    
    select {
    case m.events <- event:
    default:
        // 事件队列已满,错误事件被丢弃;实际使用中应考虑阻塞发送或至少记录日志
    }
}

func (m *RealtimeProgressManager) Stop() {
    m.cancel()
    close(m.events)

    // 先等待分发协程退出,避免 broadcastEvent 向已关闭的观察者通道发送
    m.wg.Wait()

    // 关闭所有观察者通道
    m.mu.Lock()
    for taskID, watchers := range m.watchers {
        for _, ch := range watchers {
            close(ch)
        }
        delete(m.watchers, taskID)
    }
    m.mu.Unlock()
}
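
一个可能的使用方式如下(任务 ID 与总量仅为示意):工作协程在处理过程中调用 UpdateProgress,界面侧通过 Watch 订阅该任务的进度事件。

go
// 假设性用法
manager := NewRealtimeProgressManager()
defer manager.Stop()

events := manager.Watch("import-2024")

// 工作协程:汇报进度
go func() {
    total := 100
    for i := 1; i <= total; i++ {
        manager.UpdateProgress("import-2024", i, total, "importing records")
    }
}()

// 消费协程:把进度事件转成界面更新(例如 Bubble Tea 消息)
go func() {
    for ev := range events {
        fmt.Printf("%s: %d/%d %s\n", ev.TaskID, ev.Progress, ev.Total, ev.Message)
    }
}()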

🛡️ 并发安全模式

1. 互斥锁模式

go
// 安全的计数器
type SafeCounter struct {
    mu    sync.RWMutex
    count int64
}

func (c *SafeCounter) Increment() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.count++
}

func (c *SafeCounter) Get() int64 {
    c.mu.RLock()
    defer c.mu.RUnlock()
    return c.count
}

func (c *SafeCounter) Add(delta int64) {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.count += delta
}

2. Channel 模式

go
// 使用 channel 实现安全的计数器:所有状态只由 run 协程访问
type ChannelCounter struct {
    addCh chan int64
    getCh chan chan int64
}

func NewChannelCounter() *ChannelCounter {
    counter := &ChannelCounter{
        addCh: make(chan int64),
        getCh: make(chan chan int64),
    }

    go counter.run()
    return counter
}

func (c *ChannelCounter) run() {
    var value int64
    for {
        select {
        case delta := <-c.addCh:
            value += delta
        case reply := <-c.getCh:
            reply <- value
        }
    }
}

func (c *ChannelCounter) Add(delta int64) {
    c.addCh <- delta
}

// Get 通过"请求/响应"通道读取当前值,避免与 run 协程产生数据竞争
func (c *ChannelCounter) Get() int64 {
    reply := make(chan int64)
    c.getCh <- reply
    return <-reply
}

3. 原子操作模式

go
// 使用原子操作的计数器
type AtomicCounter struct {
    count int64
}

func (c *AtomicCounter) Increment() {
    atomic.AddInt64(&c.count, 1)
}

func (c *AtomicCounter) Get() int64 {
    return atomic.LoadInt64(&c.count)
}

func (c *AtomicCounter) Add(delta int64) {
    atomic.AddInt64(&c.count, delta)
}

func (c *AtomicCounter) CompareAndSwap(old, new int64) bool {
    return atomic.CompareAndSwapInt64(&c.count, old, new)
}
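
补充一点:从 Go 1.19 起,sync/atomic 提供了 atomic.Int64 等类型化封装,不再需要手动对字段取地址,误用的机会更少。下面是等价实现的一个草图:

go
// 使用 Go 1.19+ 的类型化原子值
type TypedAtomicCounter struct {
    count atomic.Int64
}

func (c *TypedAtomicCounter) Increment() {
    c.count.Add(1)
}

func (c *TypedAtomicCounter) Get() int64 {
    return c.count.Load()
}

func (c *TypedAtomicCounter) Add(delta int64) {
    c.count.Add(delta)
}

func (c *TypedAtomicCounter) CompareAndSwap(old, new int64) bool {
    return c.count.CompareAndSwap(old, new)
}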

⚡ 性能优化技巧

1. 工作池模式

go
// 可伸缩的工作池
type WorkerPool struct {
    minWorkers int
    maxWorkers int

    taskQueue  chan Task
    workerChan chan chan Task
    quit       chan bool

    mu            sync.RWMutex
    workers       []*Worker // 存活的工作者,缩容时用于发送退出信号
    activeWorkers int
}

func NewWorkerPool(min, max int) *WorkerPool {
    pool := &WorkerPool{
        minWorkers: min,
        maxWorkers: max,
        taskQueue:  make(chan Task, 1000),
        workerChan: make(chan chan Task, max),
        quit:       make(chan bool),
    }
    
    pool.start()
    return pool
}

func (p *WorkerPool) start() {
    // 启动最小数量的工作者
    for i := 0; i < p.minWorkers; i++ {
        p.createWorker()
    }
    
    // 启动调度器
    go p.dispatch()
    go p.monitor()
}

func (p *WorkerPool) dispatch() {
    for {
        select {
        case task := <-p.taskQueue:
            // 尝试分配给空闲的工作者
            select {
            case jobChannel := <-p.workerChan:
                jobChannel <- task
            default:
                // 没有空闲工作者:未达上限则先创建一个,然后等待空闲的 JobChan
                p.mu.RLock()
                canGrow := p.activeWorkers < p.maxWorkers
                p.mu.RUnlock()

                if canGrow {
                    p.createWorker()
                }

                jobChannel := <-p.workerChan
                jobChannel <- task
            }
        case <-p.quit:
            return
        }
    }
}

func (p *WorkerPool) monitor() {
    ticker := time.NewTicker(30 * time.Second)
    defer ticker.Stop()
    
    for {
        select {
        case <-ticker.C:
            p.scaleDown()
        case <-p.quit:
            return
        }
    }
}

func (p *WorkerPool) scaleDown() {
    p.mu.Lock()
    defer p.mu.Unlock()

    // 简单的缩容策略:任务队列为空且工作者数量大于最小值时,停掉一个工作者
    if len(p.taskQueue) == 0 && p.activeWorkers > p.minWorkers {
        // 注意:退出信号要发给具体工作者的 QuitChan,而不是池共用的 quit 通道
        last := len(p.workers) - 1
        worker := p.workers[last]
        p.workers = p.workers[:last]
        p.activeWorkers--

        close(worker.QuitChan)
    }
}

func (p *WorkerPool) createWorker() {
    p.mu.Lock()
    defer p.mu.Unlock()

    p.activeWorkers++
    worker := &Worker{
        ID:         p.activeWorkers,
        WorkerChan: p.workerChan,
        JobChan:    make(chan Task),
        QuitChan:   make(chan bool),
    }
    p.workers = append(p.workers, worker)

    go worker.Start()
}

func (p *WorkerPool) Submit(task Task) {
    p.taskQueue <- task
}
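
上面的 WorkerPool 引用了一个带 ID、WorkerChan、JobChan、QuitChan 字段和 Start 方法的 Worker,本页没有给出它的定义(与案例 1 中 pkg/bubble 的 Worker 不是同一个类型)。下面是一个按这些假设补全的最小草图;为保持简短,退出时没有把 JobChan 从池中注销:

go
// 与上述 WorkerPool 配套的 Worker 草图(字段名按 createWorker 的用法假设)
type Worker struct {
    ID         int
    WorkerChan chan chan Task // 空闲时把自己的 JobChan 注册回池中
    JobChan    chan Task
    QuitChan   chan bool
}

func (w *Worker) Start() {
    for {
        // 注册自己,表示可以接收新任务
        w.WorkerChan <- w.JobChan

        select {
        case task := <-w.JobChan:
            if err := task.Process(); err != nil {
                log.Printf("worker %d: task %s failed: %v", w.ID, task.ID(), err)
            }
        case <-w.QuitChan:
            return
        }
    }
}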

2. 流水线模式

go
// 数据处理流水线
type Pipeline struct {
    stages []Stage
}

type Stage interface {
    Process(input <-chan interface{}) <-chan interface{}
}

func NewPipeline(stages ...Stage) *Pipeline {
    return &Pipeline{stages: stages}
}

func (p *Pipeline) Execute(input <-chan interface{}) <-chan interface{} {
    current := input
    
    for _, stage := range p.stages {
        current = stage.Process(current)
    }
    
    return current
}

// 验证阶段
type ValidationStage struct{}

func (v *ValidationStage) Process(input <-chan interface{}) <-chan interface{} {
    output := make(chan interface{})
    
    go func() {
        defer close(output)
        
        for data := range input {
            if v.validate(data) {
                output <- data
            }
        }
    }()
    
    return output
}

func (v *ValidationStage) validate(data interface{}) bool {
    // 验证逻辑
    return true
}

// 转换阶段
type TransformStage struct {
    transformer func(interface{}) interface{}
}

func (t *TransformStage) Process(input <-chan interface{}) <-chan interface{} {
    output := make(chan interface{})
    
    go func() {
        defer close(output)
        
        for data := range input {
            output <- t.transformer(data)
        }
    }()
    
    return output
}
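
把验证和转换两个阶段串成流水线的用法大致如下(strings.ToUpper 只是示意的转换逻辑,示例假定在同一包内使用):

go
// 假设性用法:输入 -> 验证 -> 转换 -> 输出
input := make(chan interface{})

pipeline := NewPipeline(
    &ValidationStage{},
    &TransformStage{
        transformer: func(v interface{}) interface{} {
            return strings.ToUpper(v.(string))
        },
    },
)

output := pipeline.Execute(input)

go func() {
    defer close(input) // 关闭输入后,各阶段依次关闭自己的输出
    for _, s := range []string{"alpha", "beta"} {
        input <- s
    }
}()

for v := range output {
    fmt.Println(v) // ALPHA、BETA
}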

🧪 并发测试

竞态条件检测

go
func TestConcurrentAccess(t *testing.T) {
    counter := &SafeCounter{}
    
    const numGoroutines = 100
    const incrementsPerGoroutine = 1000
    
    var wg sync.WaitGroup
    
    for i := 0; i < numGoroutines; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            
            for j := 0; j < incrementsPerGoroutine; j++ {
                counter.Increment()
            }
        }()
    }
    
    wg.Wait()
    
    expected := int64(numGoroutines * incrementsPerGoroutine)
    actual := counter.Get()
    
    assert.Equal(t, expected, actual)
}

// 使用 go test -race 检测竞态条件
func TestRaceCondition(t *testing.T) {
    data := make(map[string]int)
    
    go func() {
        for i := 0; i < 1000; i++ {
            data["key"] = i
        }
    }()
    
    go func() {
        for i := 0; i < 1000; i++ {
            _ = data["key"]
        }
    }()
    
    time.Sleep(100 * time.Millisecond)
}
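
作为对照,给被竞争的 map 加上读写锁(或改用 sync.Map)后,go test -race 就不会再报告问题。下面是与上面测试对应的一个修复版草图:

go
// 修复版:用读写锁保护 map,并用 WaitGroup 代替 Sleep 等待
func TestNoRaceWithMutex(t *testing.T) {
    var mu sync.RWMutex
    data := make(map[string]int)

    var wg sync.WaitGroup
    wg.Add(2)

    go func() {
        defer wg.Done()
        for i := 0; i < 1000; i++ {
            mu.Lock()
            data["key"] = i
            mu.Unlock()
        }
    }()

    go func() {
        defer wg.Done()
        for i := 0; i < 1000; i++ {
            mu.RLock()
            _ = data["key"]
            mu.RUnlock()
        }
    }()

    wg.Wait()
}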

死锁检测

go
func TestDeadlockPrevention(t *testing.T) {
    mu1 := &sync.Mutex{}
    mu2 := &sync.Mutex{}
    
    done := make(chan bool, 2)
    
    // Goroutine 1: 先锁 mu1,再锁 mu2
    go func() {
        mu1.Lock()
        time.Sleep(10 * time.Millisecond)
        mu2.Lock()
        
        // 临界区
        
        mu2.Unlock()
        mu1.Unlock()
        done <- true
    }()
    
    // Goroutine 2: 先锁 mu1,再锁 mu2 (相同顺序,避免死锁)
    go func() {
        mu1.Lock()
        time.Sleep(10 * time.Millisecond)
        mu2.Lock()
        
        // 临界区
        
        mu2.Unlock()
        mu1.Unlock()
        done <- true
    }()
    
    // 等待完成或超时
    timeout := time.After(5 * time.Second)
    for i := 0; i < 2; i++ {
        select {
        case <-done:
            // 成功完成
        case <-timeout:
            t.Fatal("Deadlock detected")
        }
    }
}
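
除了统一加锁顺序,Go 1.18+ 的 sync.Mutex 还提供 TryLock,可以在拿不到第二把锁时释放已持有的锁并退避重试,作为另一种避免互相等待的手段。下面是一个简化示意(非项目代码):

go
// 用 TryLock 获取两把锁;拿不到第二把就释放第一把并退避,避免互相等待
func lockBoth(a, b *sync.Mutex) {
    for {
        a.Lock()
        if b.TryLock() {
            return // 两把锁都已持有
        }
        a.Unlock()                   // 释放已持有的锁,给对方让路
        time.Sleep(time.Millisecond) // 简单退避后重试
    }
}

func unlockBoth(a, b *sync.Mutex) {
    b.Unlock()
    a.Unlock()
}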


💡 并发建议: 并发编程需要仔细设计,优先使用 channel 和 goroutine,避免过度使用锁。始终使用 go test -race 检测竞态条件。
