# -*- coding: utf-8 -*-
from __future__ import division, print_function

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

from reid.model_3dta import HUANG_3DTA
from utils.logger import LOGGER


def load_3dta_data(test_dir, batchsize, folders):
    """Create one ImageFolder dataset and one DataLoader per sub-folder.

    Args:
        test_dir: Root directory that contains the sub-folders.
        batchsize: Batch size handed to every DataLoader.
        folders: Sub-folder names under ``test_dir``; each name becomes a key
            of both returned dictionaries.

    Returns:
        A tuple ``(image_datasets, dataloaders, use_gpu)`` where the first two
        are dicts keyed by folder name and ``use_gpu`` is the result of
        ``torch.cuda.is_available()``.
    """
    # The network input size is fixed here: change the Resize target to
    # feed images of a different size.
    preprocess = transforms.Compose([
        transforms.Resize((384, 192), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image_datasets = {}
    for sub_dir in folders:
        image_datasets[sub_dir] = ImageFolder(os.path.join(test_dir, sub_dir), preprocess)

    # shuffle=False keeps batches in dataset order, so extracted features can
    # be matched to ``dataset.imgs`` by position.
    dataloaders = {}
    for sub_dir in folders:
        dataloaders[sub_dir] = DataLoader(image_datasets[sub_dir], batch_size=batchsize,
                                          shuffle=False, num_workers=8)

    return image_datasets, dataloaders, torch.cuda.is_available()


def load_3dta(which_epoch, name, batchsize):
    """Build a HUANG_3DTA network and load saved weights into it.

    The checkpoint is read from ``<name>/net_<which_epoch>.pth``.

    Args:
        which_epoch: Epoch tag used in the checkpoint filename; any value
            that can be formatted with ``str.format`` (e.g. ``'last'`` or 59).
        name: Directory that holds the checkpoint file.
        batchsize: Forwarded as the third constructor argument of HUANG_3DTA.

    Returns:
        The model with the checkpoint's state dict loaded. The model is not
        moved to another device and its train/eval mode is not changed here.
    """
    # NOTE(review): 767 * iglc and picnum are passed straight to the
    # HUANG_3DTA constructor; presumably they must match the values used when
    # the checkpoint was trained -- confirm against the training script.
    iglc = 10
    picnum = 2
    model_structure = HUANG_3DTA(767 * iglc, picnum, batchsize)
    # '{}' instead of '{:s}' so an integer epoch does not raise ValueError.
    save_path = os.path.join(name, 'net_{}.pth'.format(which_epoch))
    # Keep every storage on the CPU while deserialising: a checkpoint saved
    # from a GPU then also loads on a CPU-only host, and no GPU memory is
    # taken for the temporary copy. load_state_dict copies the values into
    # the model's own parameters afterwards.
    state_dict = torch.load(save_path, map_location=lambda storage, loc: storage)
    model_structure.load_state_dict(state_dict)
    return model_structure


def fliplr(img):
    """Flip a batch of images horizontally.

    Args:
        img: 4-D tensor laid out as N x C x H x W.

    Returns:
        A new tensor of the same shape whose last (width) dimension is
        reversed. The input is not modified.
    """
    # Build the reversed column index on the same device as ``img``;
    # index_select requires the index and the input to share a device, so a
    # CPU-only index would fail for CUDA inputs.
    inv_idx = torch.arange(img.size(3) - 1, -1, -1,
                           dtype=torch.long, device=img.device)
    return img.index_select(3, inv_idx)


def _model_device(model):
    """Return the device of ``model``'s first parameter, else CUDA if available."""
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def extract_feature(model, dataloader):
    """Extract flip-augmented, L2-normalised features for every image.

    For each batch the model is run on the images and on their horizontal
    mirror; the two outputs are summed, normalised along dimension 1 (with an
    extra ``sqrt(6)`` factor for the six parts) and flattened to one row per
    image.

    Args:
        model: Callable network whose output for a batch of ``n`` images can
            be added to an ``n x 2048 x 6`` tensor. It is used as-is: this
            function does not switch it to eval mode.
        dataloader: Iterable yielding ``(images, labels)`` batches with images
            laid out as N x C x H x W; labels are ignored.

    Returns:
        A CPU FloatTensor with one row per image in iteration order, or an
        empty FloatTensor when the dataloader yields nothing.
    """
    # Feed inputs to wherever the model lives instead of unconditionally
    # calling .cuda(), which fails on hosts without a GPU.
    device = _model_device(model)

    per_batch = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for img, _ in dataloader:
            ff = torch.FloatTensor(img.size(0), 2048, 6).zero_()  # we have six parts
            for flipped in (False, True):
                batch = fliplr(img) if flipped else img
                outputs = model(batch.to(device))
                ff = ff + outputs.detach().cpu()
            # norm feature
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6)
            ff = ff.div(fnorm.expand_as(ff))
            per_batch.append(ff.view(ff.size(0), -1))

    if not per_batch:
        return torch.FloatTensor()
    # Concatenate once at the end; growing the result with torch.cat inside
    # the loop re-copies all previous features on every batch.
    return torch.cat(per_batch, 0)


def get_id(img_path):
    """Parse the identity label and camera number from image filenames.

    Args:
        img_path: Iterable of ``(path, value)`` pairs; only the path's
            basename is inspected.

    Returns:
        A tuple ``(camera_id, labels)`` of two lists aligned with the input.
        The camera id is the first character after the first ``'c'`` in the
        filename, as an int. The label is ``-1`` when the filename starts
        with ``'-1'``, otherwise the first four characters as an int.
    """
    labels = []
    camera_id = []
    for path, _ in img_path:
        name = os.path.basename(path)
        # Everything after the first 'c' starts with the camera digit.
        after_c = name.split('c')[1]
        labels.append(-1 if name.startswith('-1') else int(name[0:4]))
        camera_id.append(int(after_c[0]))
    return camera_id, labels


def imshow(path, title=None):
    """Read the image at ``path`` and draw it on the current matplotlib axes.

    Args:
        path: Image file passed to ``plt.imread``.
        title: Optional title; nothing is set when it is ``None``.
    """
    plt.imshow(plt.imread(path))
    if title is not None:
        plt.title(title)
    # A very short pause gives the GUI event loop a chance to redraw.
    plt.pause(0.001)


def sort_img(qf, ql, qc, gf, gl, gc):
    """Rank gallery entries for one query and drop junk entries.

    The score of a gallery entry is the dot product of its feature row with
    the query feature. An entry is junk, and removed from the ranking, when
    its label is ``-1`` or when it has both the query's label and the query's
    camera.

    Args:
        qf: Query feature tensor; flattened to a column vector.
        ql: Query label.
        qc: Query camera id.
        gf: Gallery feature matrix, one row per gallery entry.
        gl: Gallery labels, one per gallery row (list or array).
        gc: Gallery camera ids, one per gallery row (list or array).

    Returns:
        A tuple ``(index, score)``: ``index`` holds the non-junk gallery row
        indices ordered from highest to lowest score, and ``score`` is the
        NumPy array of scores for every gallery row (junk included), in
        gallery order.
    """
    query = qf.view(-1, 1)
    score = torch.mm(gf, query).squeeze(1).cpu().numpy()
    # argsort is ascending; reverse it for best-match-first order.
    ranking = np.argsort(score)[::-1]

    # Coerce to arrays so the comparisons are element-wise. With plain Python
    # lists ``gl == ql`` is a single False and nothing would be filtered.
    gl = np.asarray(gl)
    gc = np.asarray(gc)
    is_junk = ((gl == ql) & (gc == qc)) | (gl == -1)

    index = ranking[~is_junk[ranking]]
    return index, score


def output_batch(image_datasets, dataloaders, gallery_feature, model, query_data='query'):
    """Score every query image against the gallery, grouped by gallery id.

    Args:
        image_datasets: Dict of datasets exposing ``.imgs`` (a list of
            ``(path, value)`` pairs); must contain ``'gallery'`` and
            ``query_data``.
        dataloaders: Dict of dataloaders; ``dataloaders[query_data]`` is used
            to extract the query features.
        gallery_feature: Pre-computed gallery feature matrix, one row per
            entry of ``image_datasets['gallery'].imgs``.
        model: Network handed to ``extract_feature``.
        query_data: Key of the query split in both dictionaries.

    Returns:
        ``{query_id: {gallery_id: best_score}}``. ``query_id`` is the query
        filename without its first ``'_'``-separated field and without
        anything from the first ``'.'`` on (e.g. ``'c1_1579491797'``).
        ``gallery_id`` is the first ``'_'``-separated field of a gallery
        filename, and ``best_score`` is the highest score among the ranked
        gallery images sharing that id. An empty dict is returned (and a
        warning logged) when ``gallery_feature`` is empty.
    """
    scores = dict()
    if len(gallery_feature) == 0:
        LOGGER.warning('gallery feature is empty')
        return scores
    gallery_path = image_datasets['gallery'].imgs
    query_path = image_datasets[query_data].imgs
    gallery_cam, gallery_label = get_id(gallery_path)
    query_cam, query_label = get_id(query_path)
    query_feature = extract_feature(model, dataloaders[query_data])

    # The gallery id of each image does not depend on the query, so parse the
    # filenames once instead of once per (query, gallery) pair.
    gallery_ids = [entry[0].split(os.sep)[-1].split('_')[0] for entry in gallery_path]

    for i, q in enumerate(query_path):
        index, score = sort_img(query_feature[i], query_label[i], query_cam[i], gallery_feature,
                                gallery_label, gallery_cam)
        # name_parts looks like ['c1', '1579491797.jpg'] for a file named
        # '0001_c1_1579491797.jpg'.
        name_parts = q[0].split(os.sep)[-1].split('_')[1:]
        # Unique id of the query image, e.g. 'c1_1579491797'.
        query_id = '_'.join(name_parts).split('.')[0]
        best = scores[query_id] = dict()
        for gallery_idx in index:
            gallery_id = gallery_ids[gallery_idx]
            candidate = score[gallery_idx]
            # Keep only the highest score seen for each gallery id.
            if gallery_id not in best or candidate > best[gallery_id]:
                best[gallery_id] = candidate
    return scores
