Chore: use generics as possible

This commit is contained in:
yaling888
2022-04-24 02:07:57 +08:00
committed by adlyq
parent dee1aeb6c3
commit 4fd7d0f707
29 changed files with 500 additions and 267 deletions

View File

@ -5,10 +5,10 @@ import (
"sync"
)
type Option = func(b *Batch)
type Option[T any] func(b *Batch[T])
type Result struct {
Value any
type Result[T any] struct {
Value T
Err error
}
@ -17,8 +17,8 @@ type Error struct {
Err error
}
func WithConcurrencyNum(n int) Option {
return func(b *Batch) {
func WithConcurrencyNum[T any](n int) Option[T] {
return func(b *Batch[T]) {
q := make(chan struct{}, n)
for i := 0; i < n; i++ {
q <- struct{}{}
@ -28,8 +28,8 @@ func WithConcurrencyNum(n int) Option {
}
// Batch similar to errgroup, but can control the maximum number of concurrent
type Batch struct {
result map[string]Result
type Batch[T any] struct {
result map[string]Result[T]
queue chan struct{}
wg sync.WaitGroup
mux sync.Mutex
@ -38,7 +38,7 @@ type Batch struct {
cancel func()
}
func (b *Batch) Go(key string, fn func() (any, error)) {
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.wg.Add(1)
go func() {
defer b.wg.Done()
@ -59,14 +59,14 @@ func (b *Batch) Go(key string, fn func() (any, error)) {
})
}
ret := Result{value, err}
ret := Result[T]{value, err}
b.mux.Lock()
defer b.mux.Unlock()
b.result[key] = ret
}()
}
func (b *Batch) Wait() *Error {
func (b *Batch[T]) Wait() *Error {
b.wg.Wait()
if b.cancel != nil {
b.cancel()
@ -74,26 +74,26 @@ func (b *Batch) Wait() *Error {
return b.err
}
func (b *Batch) WaitAndGetResult() (map[string]Result, *Error) {
func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) {
err := b.Wait()
return b.Result(), err
}
func (b *Batch) Result() map[string]Result {
func (b *Batch[T]) Result() map[string]Result[T] {
b.mux.Lock()
defer b.mux.Unlock()
copy := map[string]Result{}
copyM := map[string]Result[T]{}
for k, v := range b.result {
copy[k] = v
copyM[k] = v
}
return copy
return copyM
}
func New(ctx context.Context, opts ...Option) (*Batch, context.Context) {
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
ctx, cancel := context.WithCancel(ctx)
b := &Batch{
result: map[string]Result{},
b := &Batch[T]{
result: map[string]Result[T]{},
}
for _, o := range opts {

View File

@ -11,14 +11,14 @@ import (
)
func TestBatch(t *testing.T) {
b, _ := New(context.Background())
b, _ := New[string](context.Background())
now := time.Now()
b.Go("foo", func() (any, error) {
b.Go("foo", func() (string, error) {
time.Sleep(time.Millisecond * 100)
return "foo", nil
})
b.Go("bar", func() (any, error) {
b.Go("bar", func() (string, error) {
time.Sleep(time.Millisecond * 150)
return "bar", nil
})
@ -32,20 +32,20 @@ func TestBatch(t *testing.T) {
for k, v := range result {
assert.NoError(t, v.Err)
assert.Equal(t, k, v.Value.(string))
assert.Equal(t, k, v.Value)
}
}
func TestBatchWithConcurrencyNum(t *testing.T) {
b, _ := New(
b, _ := New[string](
context.Background(),
WithConcurrencyNum(3),
WithConcurrencyNum[string](3),
)
now := time.Now()
for i := 0; i < 7; i++ {
idx := i
b.Go(strconv.Itoa(idx), func() (any, error) {
b.Go(strconv.Itoa(idx), func() (string, error) {
time.Sleep(time.Millisecond * 100)
return strconv.Itoa(idx), nil
})
@ -57,21 +57,21 @@ func TestBatchWithConcurrencyNum(t *testing.T) {
for k, v := range result {
assert.NoError(t, v.Err)
assert.Equal(t, k, v.Value.(string))
assert.Equal(t, k, v.Value)
}
}
func TestBatchContext(t *testing.T) {
b, ctx := New(context.Background())
b, ctx := New[string](context.Background())
b.Go("error", func() (any, error) {
b.Go("error", func() (string, error) {
time.Sleep(time.Millisecond * 100)
return nil, errors.New("test error")
return "", errors.New("test error")
})
b.Go("ctx", func() (any, error) {
b.Go("ctx", func() (string, error) {
<-ctx.Done()
return nil, ctx.Err()
return "", ctx.Err()
})
result, err := b.WaitAndGetResult()