// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package unixfd provides support for Dialing and Listening
// on already connected file descriptors (like those returned by socketpair).
package unixfd

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"sync"
	"syscall"
	"time"
	"unsafe"

	"v.io/v23/rpc"
)

// Network is the name of the protocol implemented by this package. It is the
// name registered with the rpc package in init and the value reported by the
// Network method of addresses created by this package.
const Network string = "unixfd"

// init registers the package's dial function (unixFDConn) and listen function
// (unixFDListen) with the rpc package under the Network protocol name.
func init() {
	rpc.RegisterProtocol(Network, unixFDConn, unixFDListen)
}

// singleConnListener is a net.Listener that hands out one socket that is
// already connected. net.FileListener cannot be used for this purpose because
// it calls syscall.Listen, which only makes sense on an unconnected socket.
type singleConnListener struct {
	// c carries the single pending connection; it is nil once the
	// listener has been closed.
	c chan net.Conn
	// addr is the address reported by Addr.
	addr net.Addr
	// The embedded mutex guards c.
	sync.Mutex
}

// getChan returns the current connection channel, read under the lock.
// The result is nil after Close.
func (l *singleConnListener) getChan() chan net.Conn {
	l.Lock()
	ch := l.c
	l.Unlock()
	return ch
}

// Accept returns the pending connection. It fails immediately if the listener
// was closed before the call, and returns io.EOF if the channel is closed
// without a connection left in it.
func (l *singleConnListener) Accept() (net.Conn, error) {
	ch := l.getChan()
	if ch == nil {
		return nil, errors.New("listener closed")
	}
	conn, ok := <-ch
	if !ok {
		return nil, io.EOF
	}
	return conn, nil
}

// Close shuts the listener down. A connection that nobody accepted is closed
// as well, and its Close error is returned. Closing twice is an error.
func (l *singleConnListener) Close() error {
	l.Lock()
	defer l.Unlock()

	if l.c == nil {
		return errors.New("listener already closed")
	}
	pending := l.c
	l.c = nil
	close(pending)

	// A closed channel still yields whatever is buffered in it, so this
	// receive tells us whether the socket was ever handed out.
	unclaimed, ok := <-pending
	if !ok {
		return nil
	}
	return unclaimed.Close()
}

// Addr returns the address the listener was created with.
func (l *singleConnListener) Addr() net.Addr {
	return l.addr
}

// unixFDConn is the dial function for the unixfd protocol. 'address' is the
// decimal text of an open file descriptor; the returned connection reports
// 'address' as both its local and remote address. 'protocol' and 'timeout'
// are currently unused.
func unixFDConn(protocol, address string, timeout time.Duration) (net.Conn, error) {
	// TODO(cnicolaou): have this respect the timeout. Possibly have a helper
	// function that can be generally used for this, but in practice, I think
	// it'll be cleaner to use the underlying protocol's deadline support if it
	// has it.
	num, err := strconv.ParseInt(address, 10, 32)
	if err != nil {
		return nil, err
	}
	sock := os.NewFile(uintptr(num), "tmp")
	inner, err := net.FileConn(sock)
	if err != nil {
		sock.Close()
		return nil, err
	}
	// Nothing reads or writes through 'sock' from here on, but it is kept
	// open (and owned by the wrapper) so that the descriptor named by
	// 'address' stays valid for as long as the connection lives. The wrapper
	// also overrides the reported addresses.
	wrapped := &fdConn{
		addr: addr(address),
		sock: sock,
		Conn: inner,
	}
	return wrapped, nil
}

// fdConn wraps a net.Conn together with the *os.File it was created from.
// It reports a fixed address for both ends and closes the file along with
// the connection.
type fdConn struct {
	addr net.Addr // returned by both LocalAddr and RemoteAddr
	sock *os.File // closed together with the embedded Conn
	net.Conn

	mu     sync.Mutex // guards closed
	closed bool       // true once Close has run
}

// Close closes the embedded connection and then the backing file. Only the
// first call does any work; later calls return nil. The result is the error
// from closing the embedded connection.
func (c *fdConn) Close() (err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		c.closed = true
		// Deferred so the file is released after the connection, whatever
		// the connection's Close reports.
		defer c.sock.Close()
		err = c.Conn.Close()
	}
	return err
}

// LocalAddr returns the address the connection was created with.
func (c *fdConn) LocalAddr() net.Addr {
	return c.addr
}

// RemoteAddr returns the same address as LocalAddr.
func (c *fdConn) RemoteAddr() net.Addr {
	return c.addr
}

// unixFDListen is the listen function for the unixfd protocol. It turns the
// descriptor named by 'address' into a connection and returns a listener whose
// channel already holds that one connection.
func unixFDListen(protocol, address string) (net.Listener, error) {
	conn, err := unixFDConn(protocol, address, 0)
	if err != nil {
		return nil, err
	}
	pending := make(chan net.Conn, 1)
	pending <- conn
	ln := &singleConnListener{c: pending, addr: conn.LocalAddr()}
	return ln, nil
}

// addr is the net.Addr implementation used by this package. Its value is the
// address text itself.
type addr string

// Network returns the package's protocol name.
func (a addr) Network() string {
	return Network
}

// String returns the address text unchanged.
func (a addr) String() string {
	return string(a)
}

// Addr builds the unixfd net.Addr for a file descriptor: the address text is
// the descriptor number written in decimal.
func Addr(fd uintptr) net.Addr {
	text := strconv.FormatUint(uint64(fd), 10)
	return addr(text)
}

// fileDescriptor holds a raw file descriptor that can be handed out at most
// once. newFd places the descriptor in a one-element buffered channel and then
// closes the channel, so the first receive yields the descriptor and every
// later receive reports that the channel is closed and empty.
type fileDescriptor struct {
	fd   chan int // yields the descriptor to the first receiver only
	name string   // name given to the *os.File created by releaseFile
}

// newFd wraps a raw descriptor in a fileDescriptor. The descriptor is parked
// in a closed, one-slot channel so that it can be taken out exactly once.
func newFd(fd int, name string) *fileDescriptor {
	slot := make(chan int, 1)
	slot <- fd
	close(slot)
	return &fileDescriptor{fd: slot, name: name}
}

// releaseAddr takes the descriptor out of f and returns it as a unixfd
// address. It returns nil if the descriptor was already released.
func (f *fileDescriptor) releaseAddr() net.Addr {
	raw, ok := <-f.fd
	if !ok {
		return nil
	}
	return Addr(uintptr(raw))
}

// releaseFile takes the descriptor out of f and returns it wrapped in an
// *os.File named after f. It returns nil if the descriptor was already
// released.
func (f *fileDescriptor) releaseFile() *os.File {
	raw, ok := <-f.fd
	if !ok {
		return nil
	}
	return os.NewFile(uintptr(raw), f.name)
}

// maybeClose closes the descriptor held by f unless it has already been
// released, in which case it does nothing.
func (f *fileDescriptor) maybeClose() {
	file := f.releaseFile()
	if file == nil {
		return
	}
	file.Close()
}

// Socketpair creates a connected pair of sockets for talking to a child
// process. The first result is the local end as a *net.UnixConn; the second
// is the other end as an *os.File.
func Socketpair() (*net.UnixConn, *os.File, error) {
	local, remote, err := socketpair()
	if err != nil {
		return nil, nil, err
	}
	// Safety net: closes the remote end on any path that has not released it.
	defer remote.maybeClose()

	localFile := local.releaseFile()
	// net.FileConn works on a duplicate of the descriptor, so the original
	// is always closed on the way out.
	defer localFile.Close()

	conn, err := net.FileConn(localFile)
	if err != nil {
		return nil, nil, err
	}
	unixConn := conn.(*net.UnixConn)
	remoteFile := remote.releaseFile()
	return unixConn, remoteFile, nil
}

// socketpair creates a pair of connected AF_UNIX stream sockets, both marked
// close-on-exec, and returns them as single-release descriptors named "local"
// and "remote". The sockets are created and marked while holding a read lock
// on syscall.ForkLock.
func socketpair() (local, remote *fileDescriptor, err error) {
	syscall.ForkLock.RLock()
	pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
	if err != nil {
		syscall.ForkLock.RUnlock()
		return nil, nil, err
	}
	for _, fd := range pair {
		syscall.CloseOnExec(fd)
	}
	syscall.ForkLock.RUnlock()
	return newFd(pair[0], "local"), newFd(pair[1], "remote"), nil
}

// SendConnection creates a new connected socket and sends
// one end over 'conn', along with 'data'. It returns the address for
// the local end of the socketpair.
// Note that the returned address is an open file descriptor,
// which you must close if you do not Dial or Listen to the address.
//
// 'data' must contain at least one byte. After sending, SendConnection
// performs a one-byte read on the local end of the new socket pair before
// returning; this is the acknowledgement described in the body. On any error
// every descriptor created by this call is closed and a nil address is
// returned.
func SendConnection(conn *net.UnixConn, data []byte) (addr net.Addr, err error) {
	if len(data) < 1 {
		return nil, errors.New("cannot send a socket without data.")
	}
	remote, local, err := socketpair()
	if err != nil {
		return nil, err
	}
	// No-op once the local end has been released below.
	defer local.maybeClose()
	rfile := remote.releaseFile()
	// Our copy of the end being sent is closed on every path, but only
	// after the acknowledgement read below has finished.
	defer rfile.Close()

	rights := syscall.UnixRights(int(rfile.Fd()))
	n, oobn, err := conn.WriteMsgUnix(data, rights, nil)
	if err != nil {
		return nil, err
	}
	if n != len(data) || oobn != len(rights) {
		return nil, fmt.Errorf("expected to send %d, %d bytes,  sent %d, %d", len(data), len(rights), n, oobn)
	}

	// Duplicate the local end: the duplicate 'fd' is what the caller gets
	// back as an address, while 'f' is only used for the acknowledgement.
	f := local.releaseFile()
	syscall.ForkLock.Lock()
	fd, err := syscall.Dup(int(f.Fd()))
	if err != nil {
		syscall.ForkLock.Unlock()
		f.Close()
		return nil, err
	}
	syscall.CloseOnExec(fd)
	syscall.ForkLock.Unlock()

	// FileConn works on its own copy of the descriptor, so 'f' can be
	// closed as soon as the connection exists.
	newConn, err := net.FileConn(f)
	f.Close()
	if err != nil {
		// The duplicate is not returned on this path, so it must be
		// closed here or it would leak. Best effort: the FileConn error
		// is the one worth reporting.
		syscall.Close(fd)
		return nil, err
	}

	// Wait for the other side to acknowledge.
	// This is to work around a race on OS X where it appears we can close
	// the file descriptor before it gets transferred over the socket.
	// The result of the read is deliberately ignored: whether it returns a
	// byte or an error, there is nothing left to wait for.
	newConn.Read(make([]byte, 1))
	newConn.Close()

	return Addr(uintptr(fd)), nil
}

// cmsgDataLength is the size in bytes of a Go int. ReadConnection passes it to
// syscall.CmsgLen to size the out-of-band buffer it reads control messages into.
const cmsgDataLength = int(unsafe.Sizeof(int(1)))

// ReadConnection reads a connection and additional data sent on 'conn' via a call to SendConnection.
// 'buf' must be large enough to hold the data.
//
// It returns the address of the received file descriptor, the number of data
// bytes read into 'buf', an acknowledgement function, and an error.
// The returned function must be called when you are ready for the other side
// to start sending data, but before writing anything to the connection.
// It writes a single byte on a duplicate of the received socket and then
// closes that duplicate.
// If there is an error you must still call the function before closing the connection.
//
// NOTE(review): the sentence above appears to be at odds with the code: every
// error return below yields a nil function. The same is true when the message
// carried no file descriptor, in which case the address and the error are nil
// as well. Callers should nil-check the function before calling it.
func ReadConnection(conn *net.UnixConn, buf []byte) (net.Addr, int, func(), error) {
	// Out-of-band buffer sized for one control message with cmsgDataLength
	// bytes of payload.
	oob := make([]byte, syscall.CmsgLen(cmsgDataLength))
	n, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
	if err != nil {
		return nil, n, nil, err
	}
	if oobn > len(oob) {
		return nil, n, nil, fmt.Errorf("received too large oob data (%d, max %d)", oobn, len(oob))
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, n, nil, err
	}
	// fd stays -1 until the first descriptor is seen.
	fd := -1
	// Loop through any file descriptors we are sent, and close
	// all extras.
	for _, scm := range scms {
		fds, err := syscall.ParseUnixRights(&scm)
		if err != nil {
			return nil, n, nil, err
		}
		for _, f := range fds {
			if fd == -1 {
				fd = f
			} else if f != -1 {
				syscall.Close(f)
			}
		}
	}
	// No descriptor was sent: report the data length only.
	if fd == -1 {
		return nil, n, nil, nil
	}
	// result names the received descriptor itself; it is what the caller
	// gets back and becomes responsible for.
	result := Addr(uintptr(fd))
	// From here on 'fd' is a duplicate used only for the acknowledgement.
	// The duplicate is created and marked close-on-exec while holding
	// ForkLock, presumably so a concurrent fork/exec cannot inherit it
	// before the flag is set.
	syscall.ForkLock.Lock()
	fd, err = syscall.Dup(fd)
	if err != nil {
		syscall.ForkLock.Unlock()
		// Close the received descriptor so it does not leak.
		CloseUnixAddr(result)
		return nil, n, nil, err
	}
	syscall.CloseOnExec(fd)
	syscall.ForkLock.Unlock()
	file := os.NewFile(uintptr(fd), "newconn")
	newconn, err := net.FileConn(file)
	// 'file' is not needed once FileConn has returned, whether or not it
	// succeeded.
	file.Close()
	if err != nil {
		// Close the received descriptor so it does not leak.
		CloseUnixAddr(result)
		return nil, n, nil, err
	}
	return result, n, func() {
		// Send the one-byte acknowledgement and release the duplicate.
		newconn.Write(make([]byte, 1))
		newconn.Close()
	}, nil
}

// CloseUnixAddr closes the file descriptor named by a unixfd address. It
// returns an error if 'addr' belongs to a different network, if its text is
// not a decimal descriptor number, or if closing the descriptor fails.
func CloseUnixAddr(addr net.Addr) error {
	if network := addr.Network(); network != Network {
		return errors.New("invalid network")
	}
	num, err := strconv.ParseInt(addr.String(), 10, 32)
	if err != nil {
		return err
	}
	return syscall.Close(int(num))
}
