// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package vom

import (
	"io"
	"sync"

	"v.io/v23/vdl"
	"v.io/v23/verror"
)

// Errors reported by the TypeDecoder. The positional parameters used by each
// format string ({3}, {4}, ...) are supplied by the verror.New calls in this
// file:
//   - errTypeInvalid: the wire type, its type id, WireIdFirstUserType.
//   - errAlreadyDefined: the new wire type, its type id, and the existing
//     definition (either a built *vdl.Type or a pending wire type).
//   - errUnknownType: the type id that has no recorded wire definition.
//   - errUnknownWireTypeDef: the unrecognized wire type value.
//   - errStartNotCalled: no parameters; returned by lookupType when the
//     processing goroutine is not running.
var (
	errTypeInvalid        = verror.Register(pkgPath+".errTypeInvalid", verror.NoRetry, "{1:}{2:} vom: type {3} id {4} invalid, the min user type id is {5}{:_}")
	errAlreadyDefined     = verror.Register(pkgPath+".errAlreadyDefined", verror.NoRetry, "{1:}{2:} vom: type {3} id {4} already defined as {5}{:_}")
	errUnknownType        = verror.Register(pkgPath+".errUnknownType", verror.NoRetry, "{1:}{2:} vom: unknown type id {3}{:_}")
	errUnknownWireTypeDef = verror.Register(pkgPath+".errUnknownWireTypeDef", verror.NoRetry, "{1:}{2:} vom: unknown wire type definition {3}{:_}")
	errStartNotCalled     = verror.Register(pkgPath+".errStartNotCalled", verror.NoRetry, "{1:}{2:} vom: Start has not been called")
)

// TypeDecoder manages the receipt and unmarshalling of types from the other
// side of a connection.  Start must be called to start decoding types,
// and Stop must be called to reclaim resources.
//
// Types are read by a single background goroutine (processLoop, started by
// Start). Each wire type is first recorded in idToWire. Once it is complete,
// it is built into a *vdl.Type and moved to idToType. Callers look types up
// via lookupType, which may wait on buildCond until the wanted type appears.
type TypeDecoder struct {
	// The type decoder uses a 2-lock strategy for decoding. We use typeMu to lock
	// type definitions, and use buildMu to allow only one worker to build types at
	// a time. This simplifies the workflow and avoids unnecessary blocking for
	// type lookups. A third mutex, processingControlMu, guards only the
	// goroutine start/stop flags at the end of the struct.
	//
	// Where two of these locks are held together in this file (lookupType,
	// addWireTypeBuildLocked), buildMu is acquired first.
	//
	// NOTE(review): newDerivedTypeDecoderInternal creates a TypeDecoder that
	// shares idToType and idToWire with another decoder but has its own
	// zero-valued mutexes. The GUARDED_BY annotations therefore do not serialize
	// access across the two decoders. Confirm they never decode concurrently.
	typeMu   sync.RWMutex
	idToType map[TypeId]*vdl.Type // GUARDED_BY(typeMu)

	// buildCond uses buildMu as its Locker (see the constructors). processLoop
	// broadcasts on it after each wire type it reads. err holds the result of
	// the most recent read (nil on success). dec is the decoder wire types are
	// read from.
	buildMu   sync.Mutex
	buildCond *sync.Cond
	err       error               // GUARDED_BY(buildMu)
	idToWire  map[TypeId]wireType // GUARDED_BY(buildMu)
	dec       *xDecoder           // GUARDED_BY(buildMu)

	// goroutineRunning is true while a processLoop goroutine is alive.
	// goroutineShouldStop is set by Stop and checked by processLoop between reads.
	processingControlMu sync.Mutex
	goroutineRunning    bool // GUARDED_BY(processingControlMu)
	goroutineShouldStop bool // GUARDED_BY(processingControlMu)
}

// NewTypeDecoder returns a new TypeDecoder that reads wire types from r.
// The TypeDecoder understands all wire type formats generated by the
// TypeEncoder. The returned decoder does not read anything until Start is
// called.
func NewTypeDecoder(r io.Reader) *TypeDecoder {
	buf := newDecbuf(r)
	return newTypeDecoderInternal(buf)
}

// newTypeDecoderInternal creates a TypeDecoder over buf with its own empty
// type tables. buildCond is bound to the decoder's own buildMu. No goroutine
// is started here; Start does that.
func newTypeDecoderInternal(buf *decbuf) *TypeDecoder {
	decoder := &TypeDecoder{
		dec:      &xDecoder{old: &ZDecoder{buf: buf}},
		idToWire: make(map[TypeId]wireType),
		idToType: make(map[TypeId]*vdl.Type),
	}
	decoder.buildCond = sync.NewCond(&decoder.buildMu)
	return decoder
}

// newDerivedTypeDecoderInternal creates a TypeDecoder that reads from buf but
// reuses orig's idToType and idToWire maps (the maps themselves, not copies).
// The new decoder gets its own mutexes and condition variable.
//
// NOTE(review): because the locks are not shared with orig, concurrent
// decoding on orig and the derived decoder would race on the shared maps.
// Confirm that callers never run both at once.
func newDerivedTypeDecoderInternal(buf *decbuf, orig *TypeDecoder) *TypeDecoder {
	derived := new(TypeDecoder)
	derived.idToType, derived.idToWire = orig.idToType, orig.idToWire
	derived.dec = &xDecoder{old: &ZDecoder{buf: buf}}
	derived.buildCond = sync.NewCond(&derived.buildMu)
	return derived
}

// processLoop is the body of the goroutine launched by Start. It repeatedly
// reads (and, when possible, builds) one wire type. After each read it stores
// the result in d.err and broadcasts on d.buildCond to wake lookupType callers.
//
// The loop exits when Stop has been requested or a read fails, clearing
// goroutineRunning on the way out. Stop is only observed between reads.
func (d *TypeDecoder) processLoop() {
	var err error
	for {
		d.processingControlMu.Lock()
		if d.goroutineShouldStop || err != nil {
			d.goroutineShouldStop = false
			d.goroutineRunning = false
			d.processingControlMu.Unlock()
			// A lookupType caller may have seen goroutineRunning == true after our
			// last broadcast and gone back to sleep. No further types will be read,
			// so wake it now to let it observe that the loop has stopped.
			// processingControlMu is released first because lookupType acquires it
			// while holding buildMu.
			d.buildMu.Lock()
			d.buildCond.Broadcast()
			d.buildMu.Unlock()
			return
		}
		d.processingControlMu.Unlock()
		// Note that we will block indefinitely if the underlying
		// read blocks on the io.Reader.
		err = d.readSingleType()
		d.buildMu.Lock()
		d.err = err
		d.buildCond.Broadcast()
		d.buildMu.Unlock()
		// TODO(toddw): Reconsider d.err and d.buildCond strategy.
	}
}

// Start must be called to start decoding types. It launches the background
// goroutine that reads wire types, unless one is already running, and in
// either case cancels any outstanding Stop request. Calling Start more than
// once is therefore harmless.
func (d *TypeDecoder) Start() {
	d.processingControlMu.Lock()
	defer d.processingControlMu.Unlock()

	d.goroutineShouldStop = false
	if d.goroutineRunning {
		return
	}
	d.goroutineRunning = true
	go d.processLoop()
}

// Stop must be called after Start, to stop decoding types and reclaim
// resources. It only records the request and returns immediately. The
// decoding goroutine notices the request between reads, so a read that is
// blocked on the underlying io.Reader is not interrupted. Once the goroutine
// has exited, lookups of types that have not already been decoded fail.
func (d *TypeDecoder) Stop() {
	d.processingControlMu.Lock()
	defer d.processingControlMu.Unlock()
	d.goroutineShouldStop = true
}

// readSingleType decodes the next wire type from d.dec and records it via
// addWireType. Unless the decoder reports the type as incomplete, it then
// builds it with buildType.
//
// Build failures are returned only for streams at Version81 or later. For
// older versions they are ignored here and the wire type stays unbuilt.
func (d *TypeDecoder) readSingleType() error {
	var wt wireType
	id, err := d.dec.decodeWireType(&wt)
	if err != nil {
		return err
	}
	if err = d.addWireType(id, wt); err != nil {
		return err
	}
	if d.dec.old.typeIncomplete {
		// Presumably more wire types must arrive before this one can be built.
		return nil
	}
	if buildErr := d.buildType(id); d.dec.old.buf.version >= Version81 && buildErr != nil {
		return buildErr
	}
	return nil
}

// lookupType returns the type for tid. If the type is not yet available, it
// waits on d.buildCond until the type is built.
//
// It returns early with:
//   - any non-EOF error recorded by the decoding goroutine;
//   - io.EOF if the stream ended without tid having been built;
//   - errStartNotCalled if the decoding goroutine is not running.
func (d *TypeDecoder) lookupType(tid TypeId) (*vdl.Type, error) {
	// Fast path: already-built types need no buildMu.
	if tt := d.lookupKnownType(tid); tt != nil {
		return tt, nil
	}

	d.buildMu.Lock()
	defer d.buildMu.Unlock()
	for {
		// A non-EOF failure is reported right away.
		if err := d.err; err != nil && err != io.EOF {
			return nil, err
		}
		// Even after EOF the type may already have been built, so look again.
		if tt := d.lookupKnownType(tid); tt != nil {
			return tt, nil
		}
		// Only io.EOF can still be set at this point.
		if d.err != nil {
			return nil, d.err
		}

		d.processingControlMu.Lock()
		stopped := !d.goroutineRunning
		d.processingControlMu.Unlock()
		if stopped {
			return nil, verror.New(errStartNotCalled, nil)
		}

		// Wait releases buildMu. processLoop broadcasts after every type it reads.
		d.buildCond.Wait()
	}
}

// addWireType records the wire type wt under type id tid. It holds d.buildMu
// while delegating to addWireTypeBuildLocked, which performs the validation.
func (d *TypeDecoder) addWireType(tid TypeId, wt wireType) error {
	d.buildMu.Lock()
	defer d.buildMu.Unlock()
	return d.addWireTypeBuildLocked(tid, wt)
}

// addWireTypeBuildLocked records wt as the wire definition of tid. The caller
// must hold d.buildMu.
//
// It returns errTypeInvalid for ids below WireIdFirstUserType. It returns
// errAlreadyDefined if tid already has a built type (bootstrap types
// included) or a recorded wire definition.
func (d *TypeDecoder) addWireTypeBuildLocked(tid TypeId, wt wireType) error {
	if tid < WireIdFirstUserType {
		return verror.New(errTypeInvalid, nil, wt, tid, WireIdFirstUserType)
	}
	// TODO(toddw): Allow duplicates according to some heuristic (e.g. only
	// identical, or only if the later one is a "superset", etc).
	known := d.lookupKnownType(tid)
	if known != nil {
		return verror.New(errAlreadyDefined, nil, wt, tid, known)
	}
	queued := d.idToWire[tid]
	if queued == nil {
		d.idToWire[tid] = wt
		return nil
	}
	return verror.New(errAlreadyDefined, nil, wt, tid, queued)
}

// lookupKnownType returns the already-built type for tid, or nil if there is
// none. Bootstrap ids are answered from bootstrapIdToType without taking any
// lock. Every other id is read from d.idToType under a read lock on d.typeMu.
// Unlike lookupType, it never waits for a type that has not arrived.
func (d *TypeDecoder) lookupKnownType(tid TypeId) *vdl.Type {
	if boot := bootstrapIdToType[tid]; boot != nil {
		return boot
	}
	d.typeMu.RLock()
	defer d.typeMu.RUnlock()
	return d.idToType[tid]
}

// buildType builds the type with id tid from the wire types recorded so far,
// together with every not-yet-built type it references. On success, all of
// those types are moved from d.idToWire to d.idToType. On failure, neither
// map is modified.
//
// buildType holds d.buildMu for its whole duration. idToWire is guarded by
// buildMu, and holding it also guarantees that only one build runs at a time.
// The caller must not already hold d.buildMu.
func (d *TypeDecoder) buildType(tid TypeId) error {
	d.buildMu.Lock()
	defer d.buildMu.Unlock()

	builder := vdl.TypeBuilder{}
	pending := make(map[TypeId]vdl.PendingType)
	if _, err := d.makeType(tid, &builder, pending); err != nil {
		return err
	}
	builder.Build()

	// Collect every built type before publishing any, so a failure leaves
	// idToType untouched.
	built := make(map[TypeId]*vdl.Type, len(pending))
	for id, pt := range pending {
		tt, err := pt.Built()
		if err != nil {
			return err
		}
		built[id] = tt
	}

	// Lock order: buildMu (already held) before typeMu, as in lookupType.
	d.typeMu.Lock()
	for id, tt := range built {
		delete(d.idToWire, id)
		d.idToType[id] = tt
	}
	d.typeMu.Unlock()
	return nil
}

// makeType makes the pending type for tid from its wire type representation.
// Through lookupOrMakeType, it also makes any referenced type that is neither
// already built nor already in pending. Each type it starts is recorded in
// pending under its id before its referenced types are resolved. That way,
// recursive and mutually recursive types refer back to the same pending type
// instead of recursing forever. buildType, the top-level caller, then calls
// builder.Build() and Built() on every entry of pending.
//
// It returns errUnknownType if no wire type has been recorded for tid. It
// propagates errors from startBaseType, finishBaseType and lookupOrMakeType.
//
// NOTE(review): this reads d.idToWire, which the struct annotates
// GUARDED_BY(buildMu). Confirm that every caller holds buildMu or is
// otherwise the only goroutine touching idToWire.
func (d *TypeDecoder) makeType(tid TypeId, builder *vdl.TypeBuilder, pending map[TypeId]vdl.PendingType) (vdl.PendingType, error) {
	wt := d.idToWire[tid]
	if wt == nil {
		return nil, verror.New(errUnknownType, nil, tid)
	}
	// Make the type from its wireType representation.  Both named and unnamed
	// types may be recursive, so we must populate pending before subsequent
	// recursive lookups.  Eventually the built type will be added to d.idToType.
	//
	// NOTE(review): the type assertion below is unchecked and panics if a
	// wireType variant does not implement wireTypeGeneric.
	if name := wt.(wireTypeGeneric).TypeName(); name != "" {
		// Register the named type first, so references back to tid resolve to it.
		namedType := builder.Named(name)
		pending[tid] = namedType
		if wtNamed, ok := wt.(wireTypeNamedT); ok {
			// This is a wireNamed pointing at a base type.
			baseType, err := d.lookupOrMakeType(wtNamed.Value.Base, builder, pending)
			if err != nil {
				return nil, err
			}
			namedType.AssignBase(baseType)
			return namedType, nil
		}
		// This isn't wireNamed, but has a non-empty name. wt itself describes
		// the base type, which is made in two stages and attached to namedType.
		baseType, err := d.startBaseType(wt, builder)
		if err != nil {
			return nil, err
		}
		if err := d.finishBaseType(wt, baseType, builder, pending); err != nil {
			return nil, err
		}
		namedType.AssignBase(baseType)
		return namedType, nil
	}
	// We make unnamed types in two stages, to ensure that we populate pending
	// before any recursive lookups.
	unnamedType, err := d.startBaseType(wt, builder)
	if err != nil {
		return nil, err
	}
	pending[tid] = unnamedType
	if err := d.finishBaseType(wt, unnamedType, builder, pending); err != nil {
		return nil, err
	}
	return unnamedType, nil
}

// startBaseType creates in builder an empty pending type of the kind that wt
// describes: enum, array, list, set, map, struct, union or optional.
// finishBaseType fills it in afterwards. Splitting the work in two lets
// makeType register the pending type before any recursive references are
// resolved. Any other wire type, including wireTypeNamedT, yields
// errUnknownWireTypeDef.
func (d *TypeDecoder) startBaseType(wt wireType, builder *vdl.TypeBuilder) (vdl.PendingType, error) {
	var started vdl.PendingType
	switch wt.(type) {
	case wireTypeEnumT:
		started = builder.Enum()
	case wireTypeArrayT:
		started = builder.Array()
	case wireTypeListT:
		started = builder.List()
	case wireTypeSetT:
		started = builder.Set()
	case wireTypeMapT:
		started = builder.Map()
	case wireTypeStructT:
		started = builder.Struct()
	case wireTypeUnionT:
		started = builder.Union()
	case wireTypeOptionalT:
		started = builder.Optional()
	default:
		return nil, verror.New(errUnknownWireTypeDef, nil, wt)
	}
	return started, nil
}

// finishBaseType completes the pending type p, which startBaseType created for
// the same wire type wt. It fills in enum labels, element/key types, array
// length, or struct/union fields.
//
// Referenced type ids are resolved through lookupOrMakeType in the order they
// appear in wt, so types that are not built yet are made, or taken from
// pending, on demand. Wire types outside the handled kinds are left untouched
// and reported as success.
func (d *TypeDecoder) finishBaseType(wt wireType, p vdl.PendingType, builder *vdl.TypeBuilder, pending map[TypeId]vdl.PendingType) error {
	resolve := func(id TypeId) (vdl.TypeOrPending, error) {
		return d.lookupOrMakeType(id, builder, pending)
	}

	switch w := wt.(type) {
	case wireTypeEnumT:
		for _, label := range w.Value.Labels {
			p.(vdl.PendingEnum).AppendLabel(label)
		}
	case wireTypeArrayT:
		elem, err := resolve(w.Value.Elem)
		if err != nil {
			return err
		}
		p.(vdl.PendingArray).AssignElem(elem).AssignLen(int(w.Value.Len))
	case wireTypeListT:
		elem, err := resolve(w.Value.Elem)
		if err != nil {
			return err
		}
		p.(vdl.PendingList).AssignElem(elem)
	case wireTypeSetT:
		key, err := resolve(w.Value.Key)
		if err != nil {
			return err
		}
		p.(vdl.PendingSet).AssignKey(key)
	case wireTypeMapT:
		// The key is resolved before the element, matching the wire order.
		key, err := resolve(w.Value.Key)
		if err != nil {
			return err
		}
		elem, err := resolve(w.Value.Elem)
		if err != nil {
			return err
		}
		p.(vdl.PendingMap).AssignKey(key).AssignElem(elem)
	case wireTypeStructT:
		for _, f := range w.Value.Fields {
			ft, err := resolve(f.Type)
			if err != nil {
				return err
			}
			p.(vdl.PendingStruct).AppendField(f.Name, ft)
		}
	case wireTypeUnionT:
		for _, f := range w.Value.Fields {
			ft, err := resolve(f.Type)
			if err != nil {
				return err
			}
			p.(vdl.PendingUnion).AppendField(f.Name, ft)
		}
	case wireTypeOptionalT:
		elem, err := resolve(w.Value.Elem)
		if err != nil {
			return err
		}
		p.(vdl.PendingOptional).AssignElem(elem)
	}
	return nil
}

// lookupOrMakeType resolves tid during a build. It returns:
//   - the built type, if one already exists;
//   - otherwise the pending type already started in this build;
//   - otherwise a new pending type made from tid's wire definition via makeType.
func (d *TypeDecoder) lookupOrMakeType(tid TypeId, builder *vdl.TypeBuilder, pending map[TypeId]vdl.PendingType) (vdl.TypeOrPending, error) {
	known := d.lookupKnownType(tid)
	if known != nil {
		return known, nil
	}
	if inProgress, found := pending[tid]; found {
		return inProgress, nil
	}
	return d.makeType(tid, builder, pending)
}
