package structr

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// TestQueueCtx runs the generic QueueCtx test suite once per
// test struct type, each as its own named subtest.
func TestQueueCtx(t *testing.T) {
	subtests := []struct {
		name string
		run  func(*testing.T)
	}{
		{name: "structA", run: func(t *testing.T) { testQueueCtx(t, testStructA) }},
		{name: "structB", run: func(t *testing.T) { testQueueCtx(t, testStructB) }},
		{name: "structC", run: func(t *testing.T) { testQueueCtx(t, testStructC) }},
	}
	for _, st := range subtests {
		t.Run(st.name, st.run)
	}
}

// testQueueCtx exercises a QueueCtx configured with the indices from test.
// It checks that:
//   - looking up an unknown index name panics, as does popping with an
//     index that does not belong to the queue;
//   - values pushed with PushFront come out of PopBack in the given order,
//     and values pushed with PushBack come out of PopFront in the given order;
//   - popping by every indexed key removes all values from the queue;
//   - under concurrent pushes and pops, every value sent is received.
func testQueueCtx[T any](t *testing.T, test test[T]) {
	var q QueueCtx[*T]

	// Initialize the struct queue.
	q.Init(QueueConfig[*T]{
		Indices: test.indices,
	})

	// Check that fake indices cause panic.
	for _, index := range test.indices {
		fake := index.Fields + "!"
		catchpanic(t, func() {
			q.Index(fake)
		}, "unknown index: "+fake)
	}

	// Check that an index not created by
	// this queue causes a panic on use.
	catchpanic(t, func() {
		wrong := new(Index)
		q.Pop(wrong)
	}, "invalid index for queue")

	// Push all values to front.
	t.Logf("PushFront: %v", test.values)
	q.PushFront(test.values...)

	// Ensure queue length of expected size.
	if l := q.Len(); l != len(test.values) {
		t.Fatalf("queue not of expected length: have=%d want=%d", l, len(test.values))
	}

	// Values pushed to the front are expected
	// back out of the back in the same order.
	for _, value := range test.values {
		check, ok := q.PopBack(context.TODO())
		t.Logf("PopBack: %+v", check)
		if !ok || !test.equalfn(value, check) {
			t.Fatalf("value not at expected location: value=%+v check=%+v", value, check)
		}
	}

	// Check that queue is empty.
	if l := q.Len(); l != 0 {
		t.Fatalf("queue should be empty: was=%d", l)
	}

	// Push all values to back.
	t.Logf("PushBack: %v", test.values)
	q.PushBack(test.values...)

	// Ensure queue length of expected size.
	if l := q.Len(); l != len(test.values) {
		t.Fatalf("queue not of expected length: have=%d want=%d", l, len(test.values))
	}

	// Values pushed to the back are expected
	// out of the front in the same order.
	for _, value := range test.values {
		check, ok := q.PopFront(context.TODO())
		t.Logf("PopFront: %+v", check)
		if !ok || !test.equalfn(value, check) {
			t.Fatalf("value not at expected location: value=%+v check=%+v", value, check)
		}
	}

	// Check that queue is empty.
	if l := q.Len(); l != 0 {
		t.Fatalf("queue should be empty: was=%d", l)
	}

	// Push all values to front.
	t.Logf("PushFront: %v", test.values)
	q.PushFront(test.values...)

	// Pop each of the values from the queue
	// by their indexed key. It's easier to just
	// iterate through all the values for all indices
	// instead of getting particular about which
	// value is stored in which particular index.
	for _, index := range test.indices {
		keys := make([]Key, 0, len(test.values))

		// Get associated structr index.
		idx := q.Index(index.Fields)

		for _, value := range test.values {
			// Extract key parts for value; values
			// that cannot be keyed by this index
			// are skipped.
			parts, ok := indexkey(idx, value)
			if !ok {
				continue
			}

			// Generate key from parts and collect it.
			keys = append(keys, idx.Key(parts...))
		}

		// Pop all keys in index. The popped values
		// themselves are not inspected here.
		t.Logf("Pop: %s %v", index.Fields, keys)
		_ = q.Pop(idx, keys...)
	}

	// Every value should now have been removed by one of
	// its indexed keys. Leftovers would otherwise be counted
	// as received by the concurrent test below.
	if l := q.Len(); l != 0 {
		t.Fatalf("queue should be empty after indexed pops: was=%d", l)
	}

	testQueueCtxConcurrent(t, &q, test.values)

	// print final debug.
	fmt.Println(q.Debug())
}

// testQueueCtxConcurrent stress-tests q with two producers (PushFront and
// PushBack) and two consumers (PopFront and PopBack) running concurrently
// for about one second, then checks that the number of values received
// equals the number sent.
//
// Producers and consumers use separate contexts. Producers are stopped and
// joined first so the sent count is final. The queue is then drained, with a
// bounded wait. Only after that are consumers stopped and joined. This avoids
// counting values that were pushed after the consumers had already exited.
func testQueueCtxConcurrent[T any](t *testing.T, q *QueueCtx[*T], values []*T) {
	var sent, rcvd int64

	// Consumers run until popCtx is canceled.
	popCtx, popCancel := context.WithCancel(context.Background())
	defer popCancel()

	var consumers sync.WaitGroup
	consumers.Add(2)

	go func() {
		defer consumers.Done()
		for {
			// Keep popping + incrementing
			// until consumer context canceled.
			if _, ok := q.PopFront(popCtx); !ok {
				if popCtx.Err() == nil {
					// t.Error (not panic / t.Fatal) is
					// safe to call from a goroutine.
					t.Error("returned no value without ctx cancel")
				}
				return
			}
			atomic.AddInt64(&rcvd, 1)
		}
	}()

	go func() {
		defer consumers.Done()
		for {
			// Keep popping + incrementing
			// until consumer context canceled.
			if _, ok := q.PopBack(popCtx); !ok {
				if popCtx.Err() == nil {
					t.Error("returned no value without ctx cancel")
				}
				return
			}
			atomic.AddInt64(&rcvd, 1)
		}
	}()

	// Producers run until pushCtx is canceled.
	pushCtx, pushCancel := context.WithCancel(context.Background())
	defer pushCancel()

	var producers sync.WaitGroup
	producers.Add(2)

	go func() {
		defer producers.Done()
		for pushCtx.Err() == nil {
			q.PushFront(values...)
			atomic.AddInt64(&sent, int64(len(values)))
		}
	}()

	go func() {
		defer producers.Done()
		for pushCtx.Err() == nil {
			q.PushBack(values...)
			atomic.AddInt64(&sent, int64(len(values)))
		}
	}()

	// Give goroutines some
	// time to send + receive.
	time.Sleep(time.Second)

	// Stop producers and wait for their
	// final push so the sent count is final.
	pushCancel()
	producers.Wait()
	total := atomic.LoadInt64(&sent)

	// Give consumers a bounded amount of time
	// to drain everything that was sent.
	deadline := time.Now().Add(10 * time.Second)
	for atomic.LoadInt64(&rcvd) < total && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
	}

	// Stop consumers and wait for them to exit
	// so the received count is final.
	popCancel()
	consumers.Wait()
	got := atomic.LoadInt64(&rcvd)

	// Check that final counts match.
	t.Logf("sent=%d rcvd=%d", total, got)
	if total != got {
		t.Fatal("sent and received did not match")
	}
}
