package escpos

import (
	"context"
	"fmt"
	"image"
	"io"
	"math"
	"runtime"

	"git.sr.ht/~guacamolie/faxmachine/escpos/printer"
	"git.sr.ht/~guacamolie/faxmachine/escpos/protocol"
)

// Printer sends instructions encoded by a protocol.Protocol to a
// printer.Printer and dispatches the printer's responses, which are read by a
// background goroutine started in StartPrinter. Create one with StartPrinter
// or StartUSBPrinter and release it with Close.
type Printer struct {
	// printer is the device instructions are written to and responses are
	// read from; proto encodes those instructions. ownsPrinter reports
	// whether Close should also close printer (StartUSBPrinter sets it).
	printer     printer.Printer
	proto       protocol.Protocol
	ownsPrinter bool

	// lineDirty is maintained by Print and reflects whether the text it last
	// wrote ended in '\n'. PrintImage and CutPaper refuse to run while it is
	// set.
	lineDirty bool

	// ctx is canceled by Close, or with the read error as its cause when
	// reading from printer fails. statusResponse carries single-byte replies
	// (consumed by Wait) and asbStatus carries parsed 4-byte Automatic Status
	// Back reports; the background reader closes both when it exits.
	ctx            context.Context
	cancelCtx      context.CancelCauseFunc
	statusResponse chan byte
	asbStatus      chan protocol.ASBStatus
}

// StartPrinter sends the protocol's initialize instruction to printer and
// starts a background goroutine that reads the printer's responses.
//
// Single-byte responses are delivered to Wait; 4-byte responses are parsed as
// Automatic Status Back reports and delivered on the ASBStatus channel. The
// goroutine stops when Close is called or a read fails (the failure becomes
// the cause returned by Wait), and it closes both channels on exit.
//
// The returned Printer does not own printer: Close will not close it.
func StartPrinter(printer printer.Printer, proto protocol.Protocol) (*Printer, error) {
	p := &Printer{
		printer:        printer,
		proto:          proto,
		ownsPrinter:    false,
		lineDirty:      false,
		statusResponse: make(chan byte),
		asbStatus:      make(chan protocol.ASBStatus),
	}

	if err := p.writeInstr(mustInstr(p.proto.InitializePrinter())); err != nil {
		return nil, err
	}

	p.ctx, p.cancelCtx = context.WithCancelCause(context.Background())

	go func() {
		// Closing the channels tells receivers that no more responses will
		// arrive. The context is always canceled before this happens.
		defer close(p.asbStatus)
		defer close(p.statusResponse)

		for p.ctx.Err() == nil {
			var buf [4]byte
			n, err := p.printer.Read(buf[:])

			// The io.Reader contract allows data to accompany an error, so
			// dispatch whatever was read before inspecting err. Each send also
			// watches the context so Close can stop the goroutine even when
			// nobody is receiving.
			switch n {
			case 1:
				select {
				case p.statusResponse <- buf[0]:
				case <-p.ctx.Done():
					return
				}
			case 4:
				select {
				case p.asbStatus <- p.proto.ParseASBStatus(buf):
				case <-p.ctx.Done():
					return
				}
			}

			switch {
			case err == io.EOF:
				// Nothing to read right now; yield and poll again.
				// NOTE(review): this polls without sleeping, so it can keep a
				// CPU busy while the printer is idle.
				runtime.Gosched()
			case err != nil:
				p.cancelCtx(fmt.Errorf("failed to read from printer: %w", err))
				return
			}
		}
	}()

	return p, nil
}

// Option bits for the flags argument of StartUSBPrinter, which passes them
// unchanged to printer.OpenUSBPrinter.
const (
	// FlagNone selects no options.
	FlagNone = 0

	// FlagDebug presumably enables debug behavior in the USB printer driver
	// (NOTE(review): confirm its meaning in package printer).
	FlagDebug = 0x1
)

// StartUSBPrinter opens the USB printer at path (passing flags through to
// printer.OpenUSBPrinter) and starts it with StartPrinter. The returned
// Printer owns the device, so Close also closes it. If starting fails, the
// device is closed before the error is returned.
func StartUSBPrinter(path string, proto protocol.Protocol, flags int) (*Printer, error) {
	dev, openErr := printer.OpenUSBPrinter(path, flags)
	if openErr != nil {
		return nil, openErr
	}

	started, startErr := StartPrinter(dev, proto)
	if startErr != nil {
		// Best effort cleanup; the start error is what the caller needs.
		dev.Close()
		return nil, startErr
	}

	started.ownsPrinter = true
	return started, nil
}

// Close cancels the Printer's context, which stops the background reader.
// If the Printer owns its device (see StartUSBPrinter), the device is closed
// too and its error is returned; otherwise Close returns nil.
func (p *Printer) Close() error {
	p.cancelCtx(nil)
	if !p.ownsPrinter {
		return nil
	}
	return p.printer.Close()
}

// mustInstr returns instr, panicking if err is non-nil. It wraps protocol
// calls whose arguments are fixed by this package, so an error from them
// indicates a bug in the library rather than bad user input.
func mustInstr(instr []byte, err error) []byte {
	if err == nil {
		return instr
	}
	panic(fmt.Sprintf("library bug: %v", err))
}

// writeInstr writes one encoded instruction to the printer, wrapping any
// write failure in an *IOError.
func (p *Printer) writeInstr(instr []byte) error {
	if _, writeErr := p.printer.Write(instr); writeErr != nil {
		return &IOError{writeErr}
	}
	return nil
}

// Write prints data as text; it behaves exactly like WriteString.
func (p *Printer) Write(data []byte) (int, error) {
	text := string(data)
	return p.WriteString(text)
}

// WriteString prints s as text and, on success, reports len(s) bytes written.
//
// The number of bytes actually sent can differ from len(s) once the text has
// been converted for the printer. However, callers of Write and WriteString
// commonly expect n to equal len(s) on success, so len(s) is returned. On
// failure n is 0.
func (p *Printer) WriteString(s string) (int, error) {
	if printErr := p.Print(Text(s)); printErr != nil {
		return 0, printErr
	}
	return len(s), nil
}

// Print renders c with the printer's protocol and writes the resulting
// instructions to the printer in order.
//
// It tracks whether the printer is at the beginning of a line: after each
// non-empty instruction, the line is considered dirty unless the instruction
// ends in '\n'. Empty instructions put nothing on the paper and leave that
// state unchanged. Write failures are returned as *IOError.
func (p *Printer) Print(c Component) error {
	instructions, err := c.render(p.proto)
	if err != nil {
		return err
	}

	for _, inst := range instructions {
		if err := p.writeInstr(inst); err != nil {
			return err
		}

		if len(inst) == 0 {
			continue
		}
		p.lineDirty = inst[len(inst)-1] != '\n'
	}

	return nil
}

// SetPrintSpeed asks the printer to use the given print speed.
//
// speed must fit in a uint8 (0-255). Negative values are rejected rather than
// silently wrapped by the conversion. An error from the protocol (speed not
// supported) is wrapped so callers can inspect it with errors.Is/As.
func (p *Printer) SetPrintSpeed(speed int) error {
	if speed < 0 || speed > math.MaxUint8 {
		return fmt.Errorf("invalid print speed %d", speed)
	}
	instr, err := p.proto.SelectPrintSpeed(uint8(speed))
	if err != nil {
		return fmt.Errorf("printer does not support print speed %d: %w", speed, err)
	}
	return p.writeInstr(instr)
}

// PrintImage converts img to printer raster data and prints it.
//
// It returns a *LineDirtyError if Print has left the printer in the middle of
// a line, since the image must start at the beginning of a line.
func (p *Printer) PrintImage(img image.Image) error {
	if !p.lineDirty {
		width, height, raster := getPrintImageData(img)
		return p.printImage(width, height, raster)
	}
	return &LineDirtyError{"print image"}
}

// printImage stores a 1-bit raster image in the printer and prints it.
//
// x and y are the image width and height in dots. data holds the raster rows,
// each padded to a whole number of bytes ((x+7)/8 bytes per row). This is
// assumed to match the layout produced by getPrintImageData; TODO confirm.
//
// Images taller than maxChunkHeight rows are split. The first
// maxChunkHeight rows are printed, then the remainder is printed recursively.
// Wait is called after each part so the printer's buffer is not overflowed.
func (p *Printer) printImage(x int, y int, data []byte) error {
	// NOTE(review): the origin of this limit is not documented; presumably it
	// is the tallest image the printer can buffer at once. Confirm.
	const maxChunkHeight = 1662

	// The protocol encodes the width as a uint16; reject widths that would
	// otherwise be silently truncated.
	if x < 0 || x > math.MaxUint16 {
		return fmt.Errorf("image does not conform to limitations of printer: width %d out of range", x)
	}

	if y > maxChunkHeight {
		bytesPerRow := (x + 7) >> 3
		split := bytesPerRow * maxChunkHeight

		if err := p.printImage(x, maxChunkHeight, data[:split]); err != nil {
			return fmt.Errorf("failed to print first half: %w", err)
		}

		// Wait until we finished to avoid overflowing the printer's buffer.
		if err := p.Wait(); err != nil {
			return fmt.Errorf("error printing first half: %w", err)
		}

		if err := p.printImage(x, y-maxChunkHeight, data[split:]); err != nil {
			return fmt.Errorf("failed to print second half: %w", err)
		}
		if err := p.Wait(); err != nil {
			return fmt.Errorf("error printing second half: %w", err)
		}

		return nil
	}

	storeInstr, err := p.proto.StoreGraphicsData(1, 1, protocol.Color1, uint16(x), uint16(y), data...)
	if err != nil {
		return fmt.Errorf("image does not conform to limitations of printer: %w", err)
	}
	printInstr := mustInstr(p.proto.PrintGraphicsData())

	if err := p.writeInstr(storeInstr); err != nil {
		return fmt.Errorf("failed to send image data to printer: %w", err)
	}
	if err := p.writeInstr(printInstr); err != nil {
		return fmt.Errorf("failed to send print instruction to printer: %w", err)
	}
	return nil
}

// CutPaper sends a protocol.FeedAndPartialCut cut instruction (with argument
// 0) to the printer.
//
// It returns a *LineDirtyError if Print has left the printer in the middle of
// a line.
func (p *Printer) CutPaper() error {
	if !p.lineDirty {
		return p.writeInstr(mustInstr(p.proto.CutPaper(protocol.FeedAndPartialCut, 0)))
	}
	return &LineDirtyError{"cut paper"}
}

// Wait requests the paper sensor status from the printer and blocks until a
// single-byte response arrives.
//
// Presumably the printer answers only after processing earlier instructions;
// callers such as printImage rely on this to pace their writes.
//
// Wait returns the context's cancellation cause if the Printer was closed or
// the background reader failed. This includes the case where the reader has
// already exited and closed the response channel.
func (p *Printer) Wait() error {
	instr := mustInstr(p.proto.TransmitStatus(protocol.TransmitPaperSensorStatus))
	if err := p.writeInstr(instr); err != nil {
		return fmt.Errorf("failed to request status from printer: %w", err)
	}

	select {
	case <-p.ctx.Done():
		return context.Cause(p.ctx)
	case _, ok := <-p.statusResponse:
		if !ok {
			// The reader closes the channel only after the context has been
			// canceled, so report why instead of pretending to succeed.
			return context.Cause(p.ctx)
		}
		return nil
	}
}

// EnableASB configures Automatic Status Back (ASB) with the given report
// flags. Parsed reports are delivered on the ASBStatus channel.
//
// flags must fit in a uint8. Out-of-range values are rejected instead of
// being silently truncated, since truncation could, for example, turn a
// request into "report nothing".
func (p *Printer) EnableASB(flags int) error {
	if flags < 0 || flags > math.MaxUint8 {
		return fmt.Errorf("invalid Automatic Status Back (ASB) flags %#x", flags)
	}
	instr, err := p.proto.SetAutomaticStatusBack(uint8(flags))
	if err != nil {
		return fmt.Errorf("failed to request Automatic Status Back (ASB) from printer: %w", err)
	}
	return p.writeInstr(instr)
}

// DisableASB turns Automatic Status Back off by passing
// protocol.ASBReportNothing to EnableASB.
func (p *Printer) DisableASB() error {
	reportNothing := protocol.ASBReportNothing
	return p.EnableASB(reportNothing)
}

// ASBStatus returns the receive-only channel on which the background reader
// delivers parsed 4-byte Automatic Status Back reports. The channel is
// unbuffered, so the reader may block until each report is received; keep
// draining it while ASB is enabled. It is closed when the reader stops.
func (p *Printer) ASBStatus() <-chan protocol.ASBStatus {
	var reports <-chan protocol.ASBStatus = p.asbStatus
	return reports
}

// LineDirtyError is returned when an action requires the printer to be at the
// beginning of a line, but the printer will not be in that state when it
// processes the command.
//
// This usually happens when character data is written to the printer without
// a trailing linefeed ('\n').
type LineDirtyError struct {
	attemptedAction string
}

// Error describes which action was refused.
func (e *LineDirtyError) Error() string {
	return "can only " + e.attemptedAction + " when at the beginning of the line"
}

// IOError wraps an error returned by the underlying printer while writing
// instructions to it.
type IOError struct {
	wrapped error
}

// Error prefixes the wrapped error's message.
func (e *IOError) Error() string {
	return fmt.Sprintf("io error when printing: %v", e.wrapped)
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *IOError) Unwrap() error {
	return e.wrapped
}