package rocket_m

import (
	"git.hilo.cn/hilo-common/resource/mysql"
	"git.hilo.cn/hilo-common/utils"
	"gorm.io/gorm"
	"time"
)

// RocketContribute is a gorm model in which each row records one diamond
// contribution by a user toward a group's rocket. Rows are identified by
// period / round / stage and are summed per user by SumByStageUser and
// SumByUser.
// NOTE(review): the per-field meanings below are inferred from field names and
// from how this file queries them — confirm against the table schema.
type RocketContribute struct {
	// Embedded base entity. Assumed to supply the id and created_time columns;
	// created_time is referenced in this file's ORDER BY clauses. Confirm.
	mysql.Entity
	// Group the rocket belongs to.
	GroupId   string
	// Period label. Presumably a date string such as a week start; confirm.
	Period    string
	// Round within the period.
	Round     uint16
	// Rocket stage the contribution counted toward.
	Stage     uint16
	// Contributing user.
	UserId    uint64
	// Reference to the originating gift record (presumably a gift-log id; confirm).
	GiftRefId uint64
	// Diamonds contributed by this row; aggregated with SUM(diamond).
	Diamond   uint32
}

// Create inserts rc as a new row through gorm and returns the insert error, if any.
func (rc *RocketContribute) Create(db *gorm.DB) error {
	tx := db.Create(rc)
	return tx.Error
}

// Get returns every RocketContribute row matching the non-zero fields of rc.
// gorm ignores zero-valued fields in struct conditions.
//
// On failure it returns a nil slice and the query error. On success the
// slice is non-nil, though possibly empty.
func (rc *RocketContribute) Get(db *gorm.DB) ([]RocketContribute, error) {
	found := make([]RocketContribute, 0)
	if err := db.Where(rc).Find(&found).Error; err != nil {
		return nil, err
	}
	return found, nil
}

// ContributeType is one user's aggregated contribution: the user id and the
// total diamonds they contributed. SumByStageUser and SumByUser fill Diamond
// from SUM(diamond).
type ContributeType struct {
	UserId  uint64
	Diamond uint32
}

// SumByStageUser totals the diamonds contributed by each user in each stage of
// the given round. Only rows matching the non-zero fields of rc are counted.
//
// The result is keyed by stage. Each slice is ordered by total diamonds
// descending, with ties broken by the user's earliest contribution time in
// that stage. An empty, non-nil map is returned when nothing matches; a query
// error is returned as (nil, err).
//
// NOTE(review): the total is scanned into a uint32 (ContributeType.Diamond),
// so sums above math.MaxUint32 would fail to scan.
func (rc *RocketContribute) SumByStageUser(db *gorm.DB, round uint16) (map[uint16][]ContributeType, error) {
	type rowType struct {
		Stage  uint16
		UserId uint64
		S      uint32
	}
	rows := make([]rowType, 0)
	// created_time is not in the GROUP BY, so ORDER BY must aggregate it.
	// MySQL's ONLY_FULL_GROUP_BY mode rejects a bare column here, and without
	// that mode the per-group value is indeterminate.
	err := db.Model(rc).Select("stage, user_id, SUM(diamond) AS s").
		Where(rc).Where("round = ?", round).
		Group("stage, user_id").Order("s DESC, MIN(created_time) ASC").Find(&rows).Error
	if err != nil {
		return nil, err
	}
	result := make(map[uint16][]ContributeType)
	for _, r := range rows {
		// rows arrive globally sorted, so appending keeps each stage's slice sorted.
		result[r.Stage] = append(result[r.Stage], ContributeType{
			UserId:  r.UserId,
			Diamond: r.S,
		})
	}
	return result, nil
}

// SumByUser totals the diamonds each user contributed in the given round and
// stage. Only rows matching the non-zero fields of rc are counted.
//
// The result is ordered by total diamonds descending, with ties broken by the
// user's earliest contribution time. An empty, non-nil slice is returned when
// nothing matches; a query error is returned as (nil, err).
//
// NOTE(review): the total is scanned into a uint32 (ContributeType.Diamond),
// so sums above math.MaxUint32 would fail to scan.
func (rc *RocketContribute) SumByUser(db *gorm.DB, round, stage uint16) ([]ContributeType, error) {
	type rowType struct {
		UserId uint64
		S      uint32
	}
	rows := make([]rowType, 0)
	// created_time is not in the GROUP BY, so ORDER BY must aggregate it.
	// MySQL's ONLY_FULL_GROUP_BY mode rejects a bare column here, and without
	// that mode the per-group value is indeterminate.
	err := db.Model(rc).Select("user_id, SUM(diamond) AS s").
		Where(rc).Where("round = ? AND stage = ?", round, stage).
		Group("user_id").Order("s DESC, MIN(created_time) ASC").Find(&rows).Error
	if err != nil {
		return nil, err
	}
	result := make([]ContributeType, 0, len(rows))
	for _, r := range rows {
		result = append(result, ContributeType{
			UserId:  r.UserId,
			Diamond: r.S,
		})
	}
	return result, nil
}

// RocketResult is a gorm model with one row per rocket stage outcome for a
// group in a given period and round. IsAwarded records whether the outcome
// has been awarded; UpdateIsAwarded sets it to true.
// NOTE(review): GetByRound and GetByTopRound key results by Stage, which
// assumes at most one row per (group, period, round, stage) — confirm the
// schema enforces this.
type RocketResult struct {
	// Embedded base entity. Assumed to supply the id and created_time columns;
	// GetValid filters on created_time. Confirm.
	mysql.Entity
	GroupId   string
	// Period label. GetMaxStage compares it as a string against a formatted
	// Monday date.
	Period    string
	Round     uint16
	Stage     uint16
	IsAwarded bool
}

// Create inserts rr as a new row through gorm and reports the insert error, if any.
func (rr *RocketResult) Create(db *gorm.DB) error {
	res := db.Create(rr)
	if res.Error != nil {
		return res.Error
	}
	return nil
}

// Get returns every RocketResult row matching the non-zero fields of rr.
//
// A query failure is returned as (nil, err), matching RocketContribute.Get.
// Previously the error was dropped and callers saw (nil, nil), which cannot
// be told apart from "no rows".
func (rr *RocketResult) Get(db *gorm.DB) ([]RocketResult, error) {
	rows := make([]RocketResult, 0)
	if err := db.Model(&RocketResult{}).Where(rr).Find(&rows).Error; err != nil {
		return nil, err
	}
	return rows, nil
}

// GetByRound returns the RocketResult rows of the given round that match the
// non-zero fields of rr, keyed by stage.
//
// If several rows share a stage, the last one read wins. An empty, non-nil map
// is returned when nothing matches. A query failure is returned as (nil, err)
// instead of being silently reported as (nil, nil).
func (rr *RocketResult) GetByRound(db *gorm.DB, round uint16) (map[uint16]RocketResult, error) {
	rows := make([]RocketResult, 0)
	err := db.Model(&RocketResult{}).Where(rr).Where("round = ?", round).Find(&rows).Error
	if err != nil {
		return nil, err
	}
	result := make(map[uint16]RocketResult, len(rows))
	for _, r := range rows {
		result[r.Stage] = r
	}
	return result, nil
}

// GetByTopRound loads every RocketResult row matching the non-zero fields of rr
// and returns only the rows of the highest round found, keyed by stage.
//
// An empty, non-nil map is returned when nothing matches. A query failure is
// returned as (nil, err) instead of being silently reported as (nil, nil).
//
// NOTE(review): all matching rows across every round are fetched and filtered
// in memory. If a group accumulates many rounds, consider selecting MAX(round)
// in SQL first.
func (rr *RocketResult) GetByTopRound(db *gorm.DB) (map[uint16]RocketResult, error) {
	rows := make([]RocketResult, 0)
	if err := db.Model(&RocketResult{}).Where(rr).Find(&rows).Error; err != nil {
		return nil, err
	}

	// -1 is below every uint16 round, so the first row always raises it.
	topRound := -1
	for _, r := range rows {
		if int(r.Round) > topRound {
			topRound = int(r.Round)
		}
	}
	result := make(map[uint16]RocketResult)
	for _, r := range rows {
		if int(r.Round) == topRound {
			// If several rows share a stage, the last one read wins.
			result[r.Stage] = r
		}
	}
	return result, nil
}

// GetValid returns the RocketResult rows matching the non-zero fields of rr
// whose created_time is strictly after t.
//
// A query failure is returned as (nil, err) instead of being silently
// reported as (nil, nil).
func (rr *RocketResult) GetValid(db *gorm.DB, t time.Time) ([]RocketResult, error) {
	rows := make([]RocketResult, 0)
	err := db.Model(&RocketResult{}).Where(rr).Where("created_time > ?", t).Find(&rows).Error
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// UpdateIsAwarded sets is_awarded = true on every RocketResult row matching the
// non-zero fields of rr and returns the number of rows affected.
//
// A failed update is returned as (0, err) instead of (0, nil), so callers can
// tell "nothing updated" apart from a database failure. Under gorm v2's
// default configuration, an rr with no non-zero fields produces no WHERE
// clause, and gorm refuses such a global update with ErrMissingWhereClause.
// That error now reaches the caller too.
//
// NOTE(review): with the MySQL driver's default settings, RowsAffected counts
// rows actually changed, so rows already awarded are presumably not counted.
// Confirm if callers depend on this.
func (rr *RocketResult) UpdateIsAwarded(db *gorm.DB) (int64, error) {
	result := db.Model(&RocketResult{}).Where(rr).Update("is_awarded", true)
	if result.Error != nil {
		return 0, result.Error
	}
	return result.RowsAffected, nil
}

// GetMaxStage returns, for each group in groupIds, the highest rocket stage
// reached within the current natural week.
//
// Only rows that match the non-zero fields of rr and whose period is on or
// after utils.GetMonday(time.Now()), formatted with utils.DATE_FORMAT, are
// considered. Groups with no such rows are absent from the returned map. A
// query failure is returned as (nil, err) instead of being silently reported
// as (nil, nil).
//
// NOTE(review): period is compared as a string, which assumes the stored format
// sorts chronologically (e.g. YYYY-MM-DD) and that GetMonday returns the start
// of the current week. Confirm both.
func (rr *RocketResult) GetMaxStage(db *gorm.DB, groupIds []string) (map[string]uint16, error) {
	// An empty IN list cannot match any row, so skip the database round trip.
	if len(groupIds) == 0 {
		return make(map[string]uint16), nil
	}
	type maxStage struct {
		GroupId string
		M       uint16
	}
	period := utils.GetMonday(time.Now()).Format(utils.DATE_FORMAT)
	rows := make([]maxStage, 0)
	err := db.Model(&RocketResult{}).Select("group_id, MAX(stage) AS m").Where(rr).
		Where("group_id IN ? AND period >= ?", groupIds, period).Group("group_id").Find(&rows).Error
	if err != nil {
		return nil, err
	}

	result := make(map[string]uint16, len(rows))
	for _, r := range rows {
		result[r.GroupId] = r.M
	}
	return result, nil
}