// timer/wheel.go

package timer
import (
"container/list"
"sync"
"time"
)
// TimeWheel is a single-level timing wheel. It holds size buckets; on every
// tick the wheel processes the bucket under its cursor and then advances the
// cursor by one. A task whose delay spans more than one full rotation
// (size*interval) waits out the extra rotations via its circle counter.
//
// Bucket, cursor and circle state is guarded by mu, so AfterFunc may be called
// concurrently with the running wheel, including from inside a callback.
type TimeWheel struct {
	interval time.Duration
	ticker   *time.Ticker
	buckets  []*list.List
	size     int
	cursor   int
	tasks    sync.Map // key: *Task, value: bucket index
	stop     chan struct{}

	mu        sync.Mutex // guards buckets, cursor and the circle of every queued Task
	startOnce sync.Once  // makes Start launch the tick goroutine at most once
	stopOnce  sync.Once  // makes Stop close the stop channel at most once
}

// Task is a callback scheduled on a TimeWheel by AfterFunc.
type Task struct {
	delay    time.Duration // requested delay; AfterFunc raises it to at least one interval
	circle   int           // full rotations still to wait before firing; guarded by TimeWheel.mu
	callback func()
}

// NewTimeWheel creates a time wheel with the given tick interval and number of
// buckets. The wheel does not advance until Start is called.
//
// It panics if interval or size is not positive. Such a wheel can never be
// used: time.NewTicker rejects non-positive intervals and bucket positions are
// computed by dividing by interval and taking the result modulo size. Failing
// here reports the mistake at its source instead of inside the tick goroutine.
func NewTimeWheel(interval time.Duration, size int) *TimeWheel {
	if interval <= 0 {
		panic("timer: NewTimeWheel: interval must be positive")
	}
	if size <= 0 {
		panic("timer: NewTimeWheel: size must be positive")
	}
	tw := &TimeWheel{
		interval: interval,
		size:     size,
		buckets:  make([]*list.List, size),
		stop:     make(chan struct{}),
	}
	for i := range tw.buckets {
		tw.buckets[i] = list.New()
	}
	return tw
}

// Start launches the goroutine that drives the wheel, processing one bucket
// per interval. Calling Start more than once has no additional effect. The
// goroutine exits, and releases its ticker, once Stop is called.
func (tw *TimeWheel) Start() {
	tw.startOnce.Do(func() {
		tw.ticker = time.NewTicker(tw.interval)
		go func() {
			defer tw.ticker.Stop()
			for {
				select {
				case <-tw.ticker.C:
					tw.tick()
				case <-tw.stop:
					return
				}
			}
		}()
	})
}

// Stop halts the wheel. Once the tick goroutine observes the stop signal it
// processes no further buckets, so tasks still pending will not fire.
// Stop is safe to call more than once, and before Start.
func (tw *TimeWheel) Stop() {
	tw.stopOnce.Do(func() { close(tw.stop) })
}

// tick fires the due tasks in the bucket under the cursor and then advances
// the cursor to the next bucket. It holds mu for the whole step so it cannot
// interleave with addTask.
func (tw *TimeWheel) tick() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.execute(tw.buckets[tw.cursor])
	tw.cursor = (tw.cursor + 1) % tw.size
}

// execute walks bucket. Tasks whose circle count has reached zero are fired
// and removed; every other task has its count decremented. The caller must
// hold mu. Callbacks run on their own goroutines, so a slow callback cannot
// stall the wheel and a callback that calls AfterFunc cannot deadlock on mu.
func (tw *TimeWheel) execute(bucket *list.List) {
	for e := bucket.Front(); e != nil; {
		next := e.Next() // capture before Remove, which clears e's links
		task := e.Value.(*Task)
		if task.circle > 0 {
			task.circle--
		} else {
			go task.callback()
			bucket.Remove(e)
			tw.tasks.Delete(task)
		}
		e = next
	}
}

// AfterFunc schedules callback to run on its own goroutine after delay, at the
// wheel's tick granularity. Delays shorter than one interval are raised to one
// interval.
//
// With steps = delay/interval (integer division), the task fires on tick
// steps+1 counted from the call. Assuming no ticks are dropped, the actual wait
// is therefore between steps and steps+1 intervals.
//
// AfterFunc is safe to call from any goroutine, including from a callback.
func (tw *TimeWheel) AfterFunc(delay time.Duration, callback func()) *Task {
	if delay < tw.interval {
		delay = tw.interval
	}
	task := &Task{
		delay:    delay,
		callback: callback,
	}
	tw.addTask(task)
	return task
}

// addTask places task in the bucket and rotation count that match its delay
// relative to the current cursor, and records the bucket index in tasks.
func (tw *TimeWheel) addTask(task *Task) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	pos, circle := tw.getPositionAndCircle(task.delay)
	task.circle = circle
	tw.buckets[pos].PushBack(task)
	tw.tasks.Store(task, pos)
}

// getPositionAndCircle maps delay onto the wheel and returns the target bucket
// and the number of extra full rotations to wait. The caller must hold mu.
//
// How the formula follows from the wheel's mechanics:
//   - The bucket under the cursor is processed by the very next tick.
//   - A task placed steps%size buckets ahead with circle extra rotations
//     therefore fires on tick steps%size + 1 + circle*size.
//   - Choosing circle = steps/size makes that exactly tick steps+1 for every
//     steps >= 0, including exact multiples of size.
//
// Non-positive delays land in the cursor's bucket and fire on the next tick.
func (tw *TimeWheel) getPositionAndCircle(delay time.Duration) (pos int, circle int) {
	steps := int(delay / tw.interval)
	if steps <= 0 {
		return tw.cursor, 0
	}
	return (tw.cursor + steps) % tw.size, steps / tw.size
}