// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package command

import (
	"context"
	"errors"
	"os"
	"os/signal"
	"syscall"

	"github.com/spf13/cobra"

	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/client/v3/concurrency"
	"go.etcd.io/etcd/pkg/v3/cobrautl"
)

// electListen holds the value of the elect command's --listen/-l flag.
// When it is true the command expects only an election name and observes the
// election; when it is false a proposal argument is required and the command
// campaigns with it.
var electListen bool

// NewElectCommand builds the cobra command implementing "elect" and registers
// its --listen/-l flag, which is stored in electListen.
func NewElectCommand() *cobra.Command {
	electCmd := new(cobra.Command)
	electCmd.Use = "elect <election-name> [proposal]"
	electCmd.Short = "Observes and participates in leader election"
	electCmd.Run = electCommandFunc
	electCmd.GroupID = groupConcurrencyID

	flags := electCmd.Flags()
	flags.BoolVarP(&electListen, "listen", "l", false, "observation mode")

	return electCmd
}

// electCommandFunc is the Run handler of the elect command.
//
// It accepts an election name and an optional proposal. With only the
// election name the -l flag must be set and the election is observed; with a
// proposal the -l flag must not be set and the command campaigns using that
// proposal. Invalid argument combinations and any error returned by the
// chosen mode are reported through cobrautl.ExitWithError.
func electCommandFunc(cmd *cobra.Command, args []string) {
	nargs := len(args)
	if nargs < 1 || nargs > 2 {
		cobrautl.ExitWithError(cobrautl.ExitBadArgs, errors.New("elect takes one election name argument and an optional proposal argument"))
	}
	client := mustClientFromCmd(cmd)

	var runErr error
	switch nargs {
	case 1:
		// Election name only: valid solely in observation mode.
		if !electListen {
			cobrautl.ExitWithError(cobrautl.ExitBadArgs, errors.New("no proposal argument but -l not set"))
		}
		runErr = observe(client, args[0])
	default:
		// Election name plus proposal: valid solely when campaigning.
		if electListen {
			cobrautl.ExitWithError(cobrautl.ExitBadArgs, errors.New("proposal given but -l is set"))
		}
		runErr = campaign(client, args[0], args[1])
	}

	if runErr != nil {
		cobrautl.ExitWithError(cobrautl.ExitError, runErr)
	}
}

// observe prints every response delivered on the election's Observe channel
// until that channel is closed.
//
// A SIGINT or SIGTERM cancels the context passed to Observe; when the channel
// closes after such a cancellation observe returns nil. If the channel closes
// while the context is still live, observe returns an "elect: observer lost"
// error. An error from creating the session is returned unchanged.
func observe(c *clientv3.Client, election string) error {
	s, err := concurrency.NewSession(c)
	if err != nil {
		return err
	}
	e := concurrency.NewElection(s, election)

	ctx, cancel := context.WithCancel(context.TODO())
	// Release the context on every return path, not only on a signal. This
	// also unblocks the signal goroutine below when the observer is lost.
	defer cancel()

	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM)
	// Undo the signal registration once observe is done with it.
	defer signal.Stop(sigc)

	go func() {
		select {
		case <-sigc:
			cancel()
		case <-ctx.Done():
			// observe is returning; nothing left to cancel.
		}
	}()

	for resp := range e.Observe(ctx) {
		display.Get(resp)
	}

	// The deferred cancel has not run yet, so a live context here means the
	// channel closed without a signal having requested shutdown.
	if ctx.Err() == nil {
		return errors.New("elect: observer lost")
	}
	return nil
}

// campaign runs for leadership of the named election using prop as the
// proposal, prints the election key once Campaign succeeds, and then waits.
//
// A SIGINT or SIGTERM cancels the context passed to Campaign and to the key
// lookup. If the signal arrives while waiting after a successful election,
// campaign resigns and returns the result of Resign. If the session's Done
// channel fires first, it returns an "elect: session expired" error without
// resigning. Errors from creating the session, from Campaign, and from
// fetching the election key are returned unchanged.
func campaign(c *clientv3.Client, election string, prop string) error {
	s, err := concurrency.NewSession(c)
	if err != nil {
		return err
	}
	e := concurrency.NewElection(s, election)

	ctx, cancel := context.WithCancel(context.TODO())
	// Release the context on every return path, not only on a signal. This
	// also unblocks the signal goroutine below on early returns.
	defer cancel()

	donec := make(chan struct{})
	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM)
	// Undo the signal registration once campaign is done with it.
	defer signal.Stop(sigc)

	go func() {
		select {
		case <-sigc:
			cancel()
			close(donec)
		case <-ctx.Done():
			// campaign is returning without a signal; donec is never
			// waited on again, so it is left open.
		}
	}()

	if err = e.Campaign(ctx, prop); err != nil {
		return err
	}

	// print key since elected
	resp, err := c.Get(ctx, e.Key())
	if err != nil {
		return err
	}
	display.Get(*resp)

	select {
	case <-donec:
	case <-s.Done():
		return errors.New("elect: session expired")
	}

	// ctx has been cancelled by the signal handler at this point, so the
	// resignation uses a fresh context.
	return e.Resign(context.TODO())
}
