package vpn

import (
	"encoding/json"
	"github.com/google/go-cmp/cmp"
	"github.com/net-byte/vtun/common/config"
	"github.com/net-byte/vtun/tun"
	_ "github.com/net-byte/vtun/tun"
	"github.com/net-byte/water"
	"github.com/openziti/sdk-golang/ziti"
	"github.com/openziti/sdk-golang/ziti/edge"
	"github.com/openziti/zrok/endpoints"
	cmap "github.com/orcaman/concurrent-map/v2"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/songgao/water/waterutil"
	"io"
	"net"
	"strconv"
	"sync/atomic"
	"time"
)

// BackendConfig holds the settings NewBackend uses to construct a Backend.
//
// IdentityPath is the file the ziti configuration is loaded from. ShrToken is
// the service name the ziti listener is opened on. EndpointAddress, when not
// empty, is parsed as a CIDR string and overrides the default IPv4 server
// address and subnet. RequestsChan receives a "CONNECTED" or "DISCONNECTED"
// request whenever a client joins or leaves; sends on it are blocking.
type BackendConfig struct {
	IdentityPath    string
	EndpointAddress string
	ShrToken        string
	RequestsChan    chan *endpoints.Request
}

// client is the per-connection state the Backend keeps for one VPN peer: the
// connection that packets read from the tun device are written to.
type client struct {
	conn net.Conn
}

// Backend is the server side of the VPN: it accepts client connections from a
// ziti listener and moves packets between those connections and a local tun
// device. Create one with NewBackend and start it with Run.
type Backend struct {
	// Configuration passed to NewBackend and the listener clients connect on.
	cfg      *BackendConfig
	listener edge.Listener

	// Server addresses, the subnets client addresses are drawn from, the tun
	// device (set by Run) and the MTU advertised to clients.
	addr    net.IP
	addr6   net.IP
	subnet  *net.IPNet
	subnet6 *net.IPNet
	tun     *water.Interface
	mtu     int

	// counter drives address allocation in nextIP; clients maps an assigned
	// IPv4 destination to the client that owns it.
	counter atomic.Uint32
	clients cmap.ConcurrentMap[dest, *client]
}

// NewBackend opens a ziti listener on the service named by cfg.ShrToken, using
// the identity loaded from cfg.IdentityPath, and returns a Backend that is
// ready to be started with Run.
//
// The package defaults (zrokIPv4Addr, zrokIPv4, zrokIPv6Addr, zrokIPv6) are
// used for the server addresses and subnets. When cfg.EndpointAddress is not
// empty it must be a CIDR string; its address and network replace the IPv4
// defaults.
//
// An error is returned if the endpoint address cannot be parsed, the identity
// cannot be loaded, or the listener cannot be established.
func NewBackend(cfg *BackendConfig) (*Backend, error) {
	addr4, sub4 := zrokIPv4Addr, zrokIPv4
	addr6, sub6 := zrokIPv6Addr, zrokIPv6

	// Validate the cheap, local input before touching the network, so a bad
	// EndpointAddress cannot leave an established listener behind.
	if cfg.EndpointAddress != "" {
		ip, ipNet, err := net.ParseCIDR(cfg.EndpointAddress)
		if err != nil {
			return nil, errors.Wrap(err, "failed to parse VPN subnet config")
		}
		addr4, sub4 = ip, ipNet
	}

	options := ziti.ListenOptions{
		ConnectTimeout:               5 * time.Minute,
		WaitForNEstablishedListeners: 1,
	}
	zcfg, err := ziti.NewConfigFromFile(cfg.IdentityPath)
	if err != nil {
		return nil, errors.Wrap(err, "error loading config")
	}
	zctx, err := ziti.NewContext(zcfg)
	if err != nil {
		return nil, errors.Wrap(err, "error loading ziti context")
	}
	listener, err := zctx.ListenWithOptions(cfg.ShrToken, &options)
	if err != nil {
		return nil, errors.Wrap(err, "error listening")
	}

	b := &Backend{
		cfg:      cfg,
		listener: listener,
		mtu:      ZROK_VPN_MTU,
		addr:     addr4,
		addr6:    addr6,
		subnet:   sub4,
		subnet6:  sub6,
		clients: cmap.NewWithCustomShardingFunction[dest, *client](func(key dest) uint32 {
			return key.toInt32()
		}),
	}
	// nextIP increments before use, so the first client is allocated from
	// counter value 2.
	b.counter.Store(1)
	return b, nil
}

// readTun forwards packets read from the tun device to the client that owns
// the packet's destination address. It runs until reading from the tun device
// fails.
//
// Only IPv4 packets are forwarded. If writing to a client fails, a
// "DISCONNECTED" request is sent on cfg.RequestsChan, the client's connection
// is closed and the client is removed from the table. Packets for an address
// inside the subnet that has no client are logged; everything else is dropped
// silently.
func (b *Backend) readTun() {
	buf := make([]byte, ZROK_VPN_MTU)
	for {
		n, err := b.tun.Read(buf)
		if err != nil {
			// Run closes the tun device when it returns, which surfaces here
			// as a read error, so this must not panic. Closing the listener
			// makes a still-running Run return instead of accepting clients
			// that can no longer receive traffic.
			logrus.WithError(err).Error("failed to read tun device")
			_ = b.listener.Close()
			return
		}
		if n == 0 {
			// An empty read carries no IP header to inspect.
			continue
		}
		pkt := packet(buf[:n])
		if !waterutil.IsIPv4(pkt) {
			continue
		}

		logrus.WithField("packet", pkt).Trace("read from tun device")
		dst := pkt.destination()
		clt, ok := b.clients.Get(dst)
		if !ok {
			if b.subnet.Contains(net.IP(dst.addr[:])) {
				logrus.Errorf("no client with address[%v]", dst)
			}
			continue
		}

		if _, err := clt.conn.Write(pkt); err != nil {
			b.cfg.RequestsChan <- &endpoints.Request{
				Stamp:      time.Now(),
				RemoteAddr: dst.String(),
				Method:     "DISCONNECTED",
			}

			logrus.WithError(err).Errorf("failed to write packet to clt[%v]", dst)
			_ = clt.conn.Close()
			b.clients.Remove(dst)
		}
	}
}

// Run brings up the tun device, starts the goroutine that forwards tun
// traffic to clients, and then accepts client connections until the listener
// fails. Each accepted connection is served on its own goroutine.
//
// The tun device is closed when Run returns. The returned error is the one
// reported by the listener's Accept.
func (b *Backend) Run() error {
	logrus.Info("started")
	defer logrus.Info("exited")

	serverIP4 := b.addr.String()
	serverIP6 := b.addr6.String()
	prefix4, _ := b.subnet.Mask.Size()
	prefix6, _ := b.subnet6.Mask.Size()

	tunCfg := config.Config{
		ServerIP:   serverIP4,
		ServerIPv6: serverIP6,
		CIDR:       serverIP4 + "/" + strconv.Itoa(prefix4),
		CIDRv6:     serverIP6 + "/" + strconv.Itoa(prefix6),
		MTU:        ZROK_VPN_MTU,
		Verbose:    true,
	}
	logrus.Infof("%+v", tunCfg)

	b.tun = tun.CreateTun(tunCfg)
	defer func() { _ = b.tun.Close() }()

	go b.readTun()

	for {
		conn, err := b.listener.Accept()
		if err != nil {
			return err
		}
		go b.handle(conn)
	}
}

// handle serves one client connection until it ends.
//
// It allocates an address pair for the client, reports a "CONNECTED" request
// on cfg.RequestsChan, sends the client its configuration as JSON, registers
// the client under its IPv4 address and then copies every packet read from
// the connection to the tun device.
//
// On every exit path after the "CONNECTED" report, a matching "DISCONNECTED"
// request is sent, the client's registration is removed and the connection is
// closed.
func (b *Backend) handle(conn net.Conn) {
	defer func() {
		_ = conn.Close()
	}()

	ipv4, ipv6 := b.nextIP()
	ip := ipToDest(ipv4)

	bits, _ := b.subnet.Mask.Size()
	bits6, _ := b.subnet6.Mask.Size()

	cfg := &ClientConfig{
		Greeting:   "Welcome to zrok VPN",
		ServerIP:   b.addr.String(),
		ServerIPv6: b.addr6.String(),
		CIDR:       ipv4.String() + "/" + strconv.Itoa(bits),
		CIDR6:      ipv6.String() + "/" + strconv.Itoa(bits6),
		MTU:        b.mtu,
	}

	b.cfg.RequestsChan <- &endpoints.Request{
		Stamp:      time.Now(),
		RemoteAddr: ipv4.String(),
		Method:     "CONNECTED",
		Path:       cfg.ServerIP,
	}
	// Pair the CONNECTED report with a DISCONNECTED one however this handler
	// ends, not only when reading from the connection fails.
	defer func() {
		b.cfg.RequestsChan <- &endpoints.Request{
			Stamp:      time.Now(),
			RemoteAddr: ipv4.String(),
			Method:     "DISCONNECTED",
		}
	}()

	j, err := json.Marshal(cfg)
	if err != nil {
		logrus.WithError(err).Error("failed to marshal client VPN config")
		return
	}
	if _, err = conn.Write(j); err != nil {
		logrus.WithError(err).Error("failed to write client VPN config")
		return
	}

	clt := &client{conn: conn}
	b.clients.Set(ip, clt)
	// Release the address when the connection ends; otherwise the entry stays
	// behind with a closed connection and the address is never reused. Only
	// remove the entry if it is still ours, in case the address was already
	// freed and handed to another client.
	defer b.clients.RemoveCb(ip, func(_ dest, cur *client, exists bool) bool {
		return exists && cur == clt
	})

	buf := make([]byte, b.mtu)
	for {
		read, err := conn.Read(buf)
		if err != nil {
			// EOF is the normal way for a client to leave.
			if err != io.EOF {
				logrus.WithError(err).Error("read error")
			}
			return
		}
		pkt := packet(buf[:read])
		logrus.WithField("packet", pkt).Trace("read from ziti")
		if _, err = b.tun.Write(pkt); err != nil {
			logrus.WithError(err).Error("failed to write packet to tun")
			return
		}
	}
}

// nextIP allocates the next free client address and returns it as an IPv4
// address together with an IPv6 address built from the IPv6 subnet prefix
// with the IPv4 address copied into its last four bytes.
//
// Candidates are produced by XOR-ing an incrementing counter, byte by byte,
// into the IPv4 subnet's base address. A candidate is rejected when it falls
// outside the subnet (the counter then restarts), equals the server's own
// address, or is already registered to a client. The loop does not terminate
// while every address in the subnet is taken.
//
// NOTE(review): the all-ones (broadcast) host address is not excluded.
func (b *Backend) nextIP() (net.IP, net.IP) {
	ip4 := make(net.IP, len(b.subnet.IP))
	for {
		copy(ip4, b.subnet.IP)
		n := b.counter.Add(1)
		if n == 0 {
			// Counter value 0 would yield the subnet's base address.
			continue
		}

		// Spread the counter over the address, least significant byte last.
		// The mask extracts one byte; a modulo here would fold 255 onto 0 and
		// hand out the base address.
		for i := 0; i < len(ip4); i++ {
			octet := byte((n >> (i * 8)) & 0xff)
			ip4[len(ip4)-1-i] ^= octet
		}

		// subnet overflow
		if !b.subnet.Contains(ip4) {
			b.counter.Store(1)
			continue
		}

		// Both operands are net.IP, so cmp.Equal defers to net.IP.Equal, which
		// also matches the 4-byte and 16-byte forms of the same address.
		if cmp.Equal(b.addr, ip4) {
			continue
		}

		if b.clients.Has(ipToDest(ip4)) {
			continue
		}

		break
	}

	ip6 := append(net.IP{}, b.subnet6.IP...)
	copy(ip6[net.IPv6len-net.IPv4len:], ip4)

	return ip4, ip6
}