package phrasestream

import (
	"bytes"
	"io"
	"sync"
)

type Phrasestream struct {
	w  io.Writer
	mu sync.Mutex

	rad      *tree
	n        *node
	last     *node
	matching int

	buf []byte

	callback callback
	cbbuf    []byte
}

// Callback provides the underlying writer, the phrase that triggered the
// callback, and any data after the phrase was matched based on the stop
// function provided.
type Callback func(w io.Writer, phrase string, buf []byte) error

// Stop is a function provided with a phrase callback that collects any data
// after a phrase is matched until the stop function returns true.
type Stop func(byte) bool

type callback struct {
	phrase string
	fn     Callback
	stop   func(byte) bool
}

// New returns a new Phrasestream.
func New(w io.Writer) *Phrasestream {
	ps := &Phrasestream{}
	ps.Reset(w)
	return ps
}

// Reset resets the Phrasestream.
func (ps *Phrasestream) Reset(w io.Writer) {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	ps.reset(w)
}

// Add adds a new phrase with callback and optional stop.
func (ps *Phrasestream) Add(phrase string, fn Callback, stop Stop) {
	if len(phrase) == 0 {
		return
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	hasN := ps.n != nil
	var prefix string
	if hasN {
		prefix = ps.n.prefix
	}

	ps.rad.Insert(phrase, callback{phrase, fn, stop})

	// adjust if our tracked node was split
	if hasN && ps.n.prefix != prefix {
		ps.matching = len(prefix) - len(ps.n.prefix)
	}
}

// Write processes p and writes it to the underlying writer.
func (ps *Phrasestream) Write(p []byte) (int, error) {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	return ps.write(p)
}

// Close flushes any data in internal buffers.
func (ps *Phrasestream) Close() error {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	defer ps.reset(nil)

	return ps.flush()
}

func (ps *Phrasestream) reset(w io.Writer) {
	clear(ps.buf[:cap(ps.buf)])
	clear(ps.cbbuf[:cap(ps.cbbuf)])

	ps.rad = newTree()
	ps.w = w
	ps.buf = ps.buf[:0]
	ps.cbbuf = ps.cbbuf[:0]
	ps.last = nil
	ps.n = nil
	ps.matching = 0
	ps.callback.stop = nil
	ps.callback.fn = nil
}

func indexFunc(s []byte, fn func(byte) bool) int {
	for i, b := range s {
		if fn(b) {
			return i
		}
	}
	return -1
}

// callbackAndReplay calls any associated callback with a previous
// match and then replays any remaining buffer (that we thought we had a
// match for) to be processed again.
func (ps *Phrasestream) callbackAndReplay() error {
	if ps.last.leaf.val != nil {
		if cb, ok := ps.last.leaf.val.(callback); ok && cb.fn != nil {
			if cb.stop == nil {
				if err := cb.fn(ps.w, cb.phrase, nil); err != nil {
					return err
				}
			} else {
				ps.callback = cb
			}
		}
	}

	ps.buf = ps.buf[len(ps.last.leaf.key):]
	ps.last = nil
	ps.matching = 0

	if ps.callback.stop != nil {
		end := indexFunc(ps.buf, ps.callback.stop)

		if end == -1 {
			ps.cbbuf = append(ps.cbbuf[:0], ps.buf...)
			ps.buf = ps.buf[:0]
		} else {
			ps.cbbuf = append(ps.cbbuf[:0], ps.buf[:end]...)
			ps.buf = ps.buf[end:]

			if err := ps.callback.fn(ps.w, ps.callback.phrase, ps.cbbuf); err != nil {
				return err
			}

			ps.cbbuf = ps.cbbuf[:0]
			ps.callback.stop = nil
			ps.callback.fn = nil
		}
	}

	// replay any remaining buffer
	b := ps.buf
	ps.buf = ps.buf[:0]
	if _, err := ps.write(b); err != nil {
		return err
	}

	return nil
}

// advance advances one byte of unmatched plaintext and replays the
// remaining buffer to be processed again.
func (ps *Phrasestream) advance() error {
	if _, err := ps.w.Write(ps.buf[:1]); err != nil {
		return err
	}

	b := ps.buf[1:]
	ps.buf = ps.buf[:0]

	if _, err := ps.write(b); err != nil {
		return err
	}

	return nil
}

func (ps *Phrasestream) write(p []byte) (int, error) {
	if ps.n == nil {
		ps.n = ps.rad.root
	}

	var last int
	for i := 0; i < len(p); i++ {
		if ps.callback.stop != nil {
			if ps.callback.stop(p[i]) {
				if err := ps.callback.fn(ps.w, ps.callback.phrase, ps.cbbuf); err != nil {
					return i, err
				}
				ps.cbbuf = ps.cbbuf[:0]
				ps.callback.stop = nil
				ps.callback.fn = nil
			} else {
				ps.cbbuf = append(ps.cbbuf, p[i])
				last++
				continue
			}
		}

		if ps.n.prefix == "" {
			ps.n = ps.n.getEdge(p[i])
			if ps.n == nil {
				ps.n = ps.rad.root
			}
		}

		if ps.n.prefix != "" && ps.matching == len(ps.n.prefix) {
			ps.matching = 0
			ps.n = ps.n.getEdge(p[i])

			if ps.n == nil {
				ps.n = ps.rad.root

				if len(ps.buf) > 0 {
					ps.matching = 0
					ps.n = ps.rad.root

					if ps.last != nil {
						if err := ps.callbackAndReplay(); err != nil {
							return i, err
						}
					} else {
						if err := ps.advance(); err != nil {
							return i, err
						}
					}
					i--
					continue
				}
			}
		}

		if ps.n.prefix != "" && ps.matching != len(ps.n.prefix) {
			m := min(len(p[i:]), len(ps.n.prefix[ps.matching:]))

			if bytes.HasPrefix(p[i:], []byte(ps.n.prefix[ps.matching:ps.matching+m])) {
				if len(ps.buf) == 0 {
					if _, err := ps.w.Write(p[last:i]); err != nil {
						return i, err
					}
				}

				ps.buf = append(ps.buf, ps.n.prefix[ps.matching:ps.matching+m]...)
				ps.matching += m
				i += m - 1
				last = i + 1

				if ps.matching == len(ps.n.prefix) && ps.n.isLeaf() && ps.n.leaf.key == string(ps.buf) {
					ps.last = ps.n
				}

				continue
			}

			if len(ps.buf) > 0 {
				ps.matching = 0
				ps.n = ps.rad.root

				if ps.last != nil {
					if err := ps.callbackAndReplay(); err != nil {
						return i, err
					}
				} else {
					if err := ps.advance(); err != nil {
						return i, err
					}
				}
				i--
				continue
			}

			ps.n = ps.rad.root
		}
	}

	if last < len(p) {
		n, err := ps.w.Write(p[last:])
		if err != nil {
			return last + n, err
		}
	}

	return len(p), nil
}

func (ps *Phrasestream) flush() error {
	if ps.callback.stop != nil {
		if err := ps.callback.fn(ps.w, ps.callback.phrase, ps.cbbuf); err != nil {
			return err
		}
	}

	ps.n = ps.rad.root
	ps.matching = 0

	if len(ps.buf) > 0 {
		if ps.last != nil {
			if err := ps.callbackAndReplay(); err != nil {
				return err
			}
		} else {
			if err := ps.advance(); err != nil {
				return err
			}
		}
		return ps.flush()
	}

	return nil
}
