package api

import (
	"context"
	"fmt"
	"reflect"
	"sync"

	"github.com/alecthomas/units"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/common/model"

	"github.com/grafana/alloy/internal/component"
	"github.com/grafana/alloy/internal/component/common/loki"
	fnet "github.com/grafana/alloy/internal/component/common/net"
	"github.com/grafana/alloy/internal/component/common/relabel"
	"github.com/grafana/alloy/internal/component/loki/source"
	"github.com/grafana/alloy/internal/component/loki/source/api/internal/lokipush"
	"github.com/grafana/alloy/internal/featuregate"
	"github.com/grafana/alloy/internal/util"
)

// init registers the loki.source.api component with the component registry.
func init() {
	component.Register(component.Registration{
		Name:      "loki.source.api",
		Stability: featuregate.StabilityGenerallyAvailable,
		Args:      Arguments{},
		Build: func(opts component.Options, args component.Arguments) (component.Component, error) {
			c, err := New(opts, args.(Arguments))
			if err != nil {
				// Return an untyped nil: converting a nil *Component to the
				// component.Component interface would produce a non-nil
				// interface value holding a nil pointer.
				return nil, err
			}
			return c, nil
		},
	})
}

// Arguments holds the user-facing configuration of the loki.source.api
// component. Default values are supplied by SetToDefault.
type Arguments struct {
	// Server configures the embedded server. The ",squash" tag places its
	// attributes and blocks directly in this component's block. SetToDefault
	// sets it to fnet.DefaultServerConfig(); Update treats a nil Server as an
	// empty config and fills in a 127.0.0.1, random-port gRPC listener when
	// GRPC is unset.
	Server               *fnet.ServerConfig  `alloy:",squash"`
	// ForwardTo lists the receivers that consumed log batches are fanned out to.
	ForwardTo            []loki.LogsReceiver `alloy:"forward_to,attr"`
	// Labels are static labels handed to the push server via SetLabels.
	Labels               map[string]string   `alloy:"labels,attr,optional"`
	// RelabelRules are handed to the push server via SetRelabelRules.
	RelabelRules         relabel.Rules       `alloy:"relabel_rules,attr,optional"`
	// UseIncomingTimestamp is handed to the push server via SetKeepTimestamp.
	UseIncomingTimestamp bool                `alloy:"use_incoming_timestamp,attr,optional"`
	// MaxSendMessageSize defaults to 100 MiB and is passed (as int64 bytes) to
	// lokipush.NewPushAPIServer. NOTE(review): presumably caps the size of a
	// single push request/message — confirm in lokipush.
	MaxSendMessageSize   units.Base2Bytes    `alloy:"max_send_message_size,attr,optional"`
}

// SetToDefault implements syntax.Defaulter.
//
// Every field is first reset to its zero value. Then two defaults are applied:
// the standard server configuration from fnet.DefaultServerConfig, and a
// maximum send message size of 100 MiB.
func (a *Arguments) SetToDefault() {
	var defaults Arguments
	defaults.Server = fnet.DefaultServerConfig()
	defaults.MaxSendMessageSize = 100 * units.MiB
	*a = defaults
}

// labelSet converts the configured static labels into a model.LabelSet.
// The returned set is always non-nil, even when no labels are configured.
func (a *Arguments) labelSet() model.LabelSet {
	result := make(model.LabelSet, len(a.Labels))
	for rawName, rawValue := range a.Labels {
		name, value := model.LabelName(rawName), model.LabelValue(rawValue)
		result[name] = value
	}
	return result
}

// Component implements the loki.source.api component. It runs an embedded
// push API server and forwards the log batches it receives to the configured
// receivers.
type Component struct {
	// opts are the options given to New; Update passes opts.Logger to the
	// push server.
	opts               component.Options
	// handler is passed to every PushAPIServer that Update creates and is
	// drained by Run via source.ConsumeBatch.
	handler            loki.LogsBatchReceiver
	// uncheckedCollector is registered once with opts.Registerer in New.
	// Update points it at a fresh registry each time the server is recreated.
	uncheckedCollector *util.UncheckedCollector

	// serverMut guards server.
	serverMut sync.Mutex
	// server is the currently running push API server. It is nil until
	// Update creates one, and is set back to nil when Run exits.
	server    *lokipush.PushAPIServer

	// fanout forwards batches consumed in Run to Arguments.ForwardTo; its
	// children are refreshed on every Update.
	fanout *loki.Fanout
}

// New builds a loki.source.api component. It performs these steps:
//   - creates the batch receiver,
//   - creates the fanout to args.ForwardTo,
//   - registers the component's unchecked collector with opts.Registerer,
//   - applies args through Update, which starts the embedded push API server.
//
// Any error from Update is returned as-is, together with a nil component.
func New(opts component.Options, args Arguments) (*Component, error) {
	collector := util.NewUncheckedCollector(nil)
	c := &Component{
		opts:               opts,
		handler:            loki.NewLogsBatchReceiver(),
		uncheckedCollector: collector,
		fanout:             loki.NewFanout(args.ForwardTo),
	}
	opts.Registerer.MustRegister(collector)

	if err := c.Update(args); err != nil {
		return nil, err
	}
	return c, nil
}

// Run passes log batches from the handler to the fanout by blocking in
// source.ConsumeBatch (presumably until ctx is done — confirm in the source
// package). When that call returns, the embedded server is force-shut-down
// and cleared. Run always returns nil.
func (c *Component) Run(ctx context.Context) (err error) {
	// Tear down the embedded server once consumption stops. Draining is not
	// needed here because ForceShutdown cancels all in-flight requests.
	defer func() {
		c.serverMut.Lock()
		defer c.serverMut.Unlock()
		if c.server == nil {
			return
		}
		c.server.ForceShutdown()
		c.server = nil
	}()

	source.ConsumeBatch(ctx, c.handler, c.fanout)
	return nil
}

// Update applies a new set of Arguments to the component.
//
// It does three things, in order:
//  1. Refreshes the fanout destinations.
//  2. (Re)creates and starts the embedded push API server. This happens when
//     no server is running yet, or when the server configuration differs from
//     the new one.
//  3. Hands the static labels, relabel rules and timestamp setting to the
//     running server.
//
// It returns an error when args is not an Arguments value, or when the
// embedded server cannot be created or started. In the second case no server
// is kept, so the next call to Update tries to start one again.
func (c *Component) Update(args component.Arguments) error {
	newArgs, ok := args.(Arguments)
	if !ok {
		return fmt.Errorf("invalid type of arguments: %T", args)
	}

	// Server is a pointer shared with the caller's Arguments. Defaults are
	// filled in on a shallow copy so the caller's configuration is never
	// mutated. If no server config is provided, we start from an empty one.
	var serverCfg fnet.ServerConfig
	if newArgs.Server != nil {
		serverCfg = *newArgs.Server
	}
	// If no GRPC is configured, use a random port to avoid port conflicts.
	// Also use the localhost IP, so we don't require root to run.
	if serverCfg.GRPC == nil {
		serverCfg.GRPC = &fnet.GRPCConfig{
			ListenPort:    0,
			ListenAddress: "127.0.0.1",
		}
	}
	newArgs.Server = &serverCfg

	c.fanout.UpdateChildren(newArgs.ForwardTo)

	c.serverMut.Lock()
	defer c.serverMut.Unlock()

	serverNeedsRestarting := c.server == nil || !reflect.DeepEqual(c.server.ServerConfig(), *newArgs.Server)
	if serverNeedsRestarting {
		if c.server != nil {
			c.server.Shutdown()
			c.server = nil
		}

		// [server.Server] registers new metrics every time it is created. To
		// avoid issues with re-registering metrics with the same name, we create a
		// new registry for the server every time we create one, and pass it to an
		// unchecked collector to bypass uniqueness checking.
		serverRegistry := prometheus.NewRegistry()
		c.uncheckedCollector.SetCollector(serverRegistry)

		srv, err := lokipush.NewPushAPIServer(c.opts.Logger, newArgs.Server, c.handler, serverRegistry, int64(newArgs.MaxSendMessageSize))
		if err != nil {
			return fmt.Errorf("failed to create embedded server: %w", err)
		}
		if err = srv.Run(); err != nil {
			// Do not keep a server that never started. If we kept it, a later
			// Update with an identical config would see matching configs and
			// never retry. Release whatever it set up and leave c.server nil.
			srv.Shutdown()
			return fmt.Errorf("failed to run embedded server: %w", err)
		}
		c.server = srv
	}

	c.server.SetLabels(newArgs.labelSet())
	c.server.SetRelabelRules(newArgs.RelabelRules)
	c.server.SetKeepTimestamp(newArgs.UseIncomingTimestamp)

	return nil
}
