package cleanup

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"strings"

	"gitlab.com/gitlab-org/gitaly/v16/internal/git"
	"gitlab.com/gitlab-org/gitaly/v16/internal/git/updateref"
	"gitlab.com/gitlab-org/gitaly/v16/internal/log"
)

// forEachFunc is an optional callback that the cleaner invokes for entries of
// the filter-repo or BFG object map file it is processing. It receives the old
// and new object IDs of the entry, and isInternalRef reports whether the
// cleaner's lookup table has references recorded for the old object ID.
// Entries whose old and new object IDs are equal are skipped before the
// callback is reached. Returning an error will stop the cleaner before it has
// processed the entry in question.
type forEachFunc func(ctx context.Context, oldOID, newOID string, isInternalRef bool) error

// cleaner is responsible for updating the internal references in a repository
// as specified by a filter-repo or BFG object map. Currently, internal
// references pointing to a commit that has been rewritten will simply be
// removed.
//
// forEach is an optional per-entry callback (it may be nil), logger receives a
// message for every batch of references that gets removed, and repo is the
// repository the reference updates are executed against. ctx is the context
// handed to newCleaner; the methods visible in this file take their own ctx
// parameter and do not read this field.
type cleaner struct {
	ctx     context.Context
	forEach forEachFunc
	logger  log.Logger

	// table maps an object ID (SHA) to the reference names recorded for it. It
	// is built once by buildLookupTable when the cleaner is created.
	// removedRefs is the set of references that have already been queued for
	// deletion, which ensures each reference is deleted only once.
	table       map[string][]git.ReferenceName
	removedRefs map[git.ReferenceName]struct{}
	repo        git.RepositoryExecutor
}

// errInvalidObjectMap is returned with descriptive text if the supplied object
// map file is in the wrong format, i.e. a line cannot be split into an old and
// a new object ID, or one of the two IDs fails validation.
//
// NOTE(review): this is a defined interface type whose method set is identical
// to that of error. Converting an error to it therefore does not change the
// error's dynamic type, and a type assertion or errors.As against
// errInvalidObjectMap succeeds for every non-nil error. Callers that rely on it
// to tell malformed input apart from other failures should be double-checked.
type errInvalidObjectMap error

// newCleaner creates a cleaner that can apply a filter-repo or BFG object map
// to repo. The object ID -> references lookup table is built up front; if that
// fails, the error is returned as is and no cleaner is created. forEach may be
// nil, in which case no per-entry callback is invoked.
func newCleaner(ctx context.Context, logger log.Logger, repo git.RepositoryExecutor, forEach forEachFunc) (*cleaner, error) {
	refsByOID, err := buildLookupTable(ctx, logger, repo)
	if err != nil {
		return nil, err
	}

	return &cleaner{
		ctx:         ctx,
		forEach:     forEach,
		logger:      logger,
		table:       refsByOID,
		removedRefs: make(map[git.ReferenceName]struct{}),
		repo:        repo,
	}, nil
}

// applyObjectMap processes an object map file generated by git filter-repo, or
// BFG, removing any internal references that point to a rewritten commit.
//
// The map is read from reader line by line. The filter-repo commit map header
// is skipped, as are entries whose old and new object IDs are equal. All other
// entries are handed to processEntry, which queues the reference deletions on
// a reference updater that is started before the first line is read and
// committed only after the whole map has been consumed without error. The
// updater is always closed on return; a Close error is reported only when no
// other error occurred.
//
// Malformed lines are reported as errInvalidObjectMap. The line number in those
// messages is zero-based and counts the header line. Failures while reading
// from reader, including lines that exceed the scanner's buffer, abort the
// operation before anything is committed.
func (c *cleaner) applyObjectMap(ctx context.Context, reader io.Reader) (returnedErr error) {
	objectHash, err := c.repo.ObjectHash(ctx)
	if err != nil {
		return fmt.Errorf("detecting object hash: %w", err)
	}

	updater, err := updateref.New(ctx, c.repo)
	if err != nil {
		return fmt.Errorf("new updater: %w", err)
	}
	defer func() {
		// Do not let a Close failure mask the error that made us bail out.
		if err := updater.Close(); err != nil && returnedErr == nil {
			returnedErr = fmt.Errorf("close updater: %w", err)
		}
	}()

	if err := updater.Start(); err != nil {
		return fmt.Errorf("start reference transaction: %w", err)
	}

	scanner := bufio.NewScanner(reader)
	for i := int64(0); scanner.Scan(); i++ {
		line := scanner.Text()

		const filterRepoCommitMapHeader = "old                                      new"
		if line == filterRepoCommitMapHeader {
			continue
		}

		// Each line consists of two SHAs: the SHA of the original object, and
		// the SHA of a replacement object in the new repository history. For
		// now, the new SHA is ignored, but it may be used to rewrite (rather
		// than remove) some references in the future.
		oldOID, newOID, ok := strings.Cut(line, " ")
		if !ok {
			return errInvalidObjectMap(fmt.Errorf("object map invalid at line %d", i))
		}
		if err := objectHash.ValidateHex(oldOID); err != nil {
			return errInvalidObjectMap(fmt.Errorf("invalid old object ID at line %d", i))
		}
		if err := objectHash.ValidateHex(newOID); err != nil {
			return errInvalidObjectMap(fmt.Errorf("invalid new object ID at line %d", i))
		}

		// References to unchanged objects do not need to be removed. When the old
		// SHA and new SHA are the same, this means the object was considered but
		// not modified.
		if oldOID == newOID {
			continue
		}

		if err := c.processEntry(ctx, updater, oldOID, newOID); err != nil {
			return err
		}
	}

	// Scan returns false both at EOF and when reading fails or a line is too
	// long for the scanner's buffer. Without this check a truncated object map
	// would be committed as if it had been processed completely.
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("scanning object map: %w", err)
	}

	return updater.Commit()
}

// processEntry handles a single object map entry whose object ID changed from
// oldSHA to newSHA. The forEach callback, if set, is invoked first and is told
// whether the lookup table has references recorded for oldSHA; an error from
// the callback is returned before anything is deleted. Afterwards every
// reference recorded for oldSHA that has not been queued for deletion yet is
// deleted via updater. Entries without recorded references are a no-op.
func (c *cleaner) processEntry(ctx context.Context, updater *updateref.Updater, oldSHA, newSHA string) error {
	staleRefs, hasRefs := c.table[oldSHA]

	if callback := c.forEach; callback != nil {
		if err := callback(ctx, oldSHA, newSHA, hasRefs); err != nil {
			return err
		}
	}

	if !hasRefs {
		return nil
	}

	logFields := log.Fields{
		"sha":  oldSHA,
		"refs": staleRefs,
	}
	c.logger.WithFields(logFields).InfoContext(ctx, "removing internal references")

	// Queue the deletion of every internal ref recorded for oldSHA, skipping
	// the ones that an earlier entry has already taken care of.
	for _, name := range staleRefs {
		if _, alreadyQueued := c.removedRefs[name]; !alreadyQueued {
			c.removedRefs[name] = struct{}{}

			if err := updater.Delete(name); err != nil {
				return err
			}
		}
	}

	return nil
}

// buildLookupTable runs git-for-each-ref(1) over the internal reference
// prefixes and returns an in-memory map of object ID -> reference names.
//
// Each output line is split on whitespace: the last column is the reference
// name and every column before it is an object ID, namely the object name
// followed by whatever %(parent) expands to. The reference is recorded under
// each of those IDs, so one ID may map to several references and one reference
// may be listed under several IDs.
//
// The table allows the cleaner to cheaply look up which references are
// affected by an object map entry. Git itself resolves ref -> object ID, while
// the cleaner needs the reverse direction.
func buildLookupTable(ctx context.Context, logger log.Logger, repo git.RepositoryExecutor) (map[string][]git.ReferenceName, error) {
	objectHash, err := repo.ObjectHash(ctx)
	if err != nil {
		return nil, fmt.Errorf("detecting object hash: %w", err)
	}

	prefixes := make([]string, 0, len(git.InternalRefPrefixes))
	for prefix := range git.InternalRefPrefixes {
		prefixes = append(prefixes, prefix)
	}

	listRefs := git.Command{
		Name: "for-each-ref",
		Flags: []git.Option{
			git.ValueFlag{Name: "--format", Value: "%(objectname) %(parent) %(refname)"},
		},
		Args: prefixes,
	}

	cmd, err := repo.Exec(ctx, listRefs, git.WithSetupStdout())
	if err != nil {
		return nil, err
	}

	refsByOID := make(map[string][]git.ReferenceName)

	lines := bufio.NewScanner(cmd)
	for lines.Scan() {
		raw := lines.Text()

		columns := strings.Fields(raw)
		if len(columns) < 2 {
			logger.WithFields(log.Fields{"line": raw}).WarnContext(ctx, "failed to parse git refs")
			return nil, fmt.Errorf("failed to parse git refs")
		}

		// The reference name comes last; everything in front of it is an
		// object ID that the reference gets recorded under.
		last := len(columns) - 1
		ref := git.ReferenceName(columns[last])

		for _, oid := range columns[:last] {
			if err := objectHash.ValidateHex(oid); err != nil {
				return nil, fmt.Errorf("failed to parse object name: %w", err)
			}

			refsByOID[oid] = append(refsByOID[oid], ref)
		}
	}

	if err := cmd.Wait(); err != nil {
		return nil, err
	}

	if err := lines.Err(); err != nil {
		return nil, err
	}

	return refsByOID, nil
}
