import copy
import multiprocessing
import os
import queue
from pprint import pformat
from time import sleep, time

import cv2
import torch
from reid.model import Model3DTA

from config import config as cfg
from utils.drawing import drawing
from utils.hist import comp
from utils.capture import fetch, start_streams
from utils.folder import check_folder, clear_folder, has_subfolder
from utils.initialize import initialize
from utils.logger import LOGGER
from utils.position import is_contain
from yolov3.yolo_detect import set_yolo
from zone import get_zone


def if_need_to_update(existing_ids, gallery_path):
    """Scan the gallery and decide whether the ReID gallery must be reloaded.

    Every file found (recursively) under ``gallery_path`` contributes one id:
    the last ``_``-separated token of its file name with the extension
    stripped (e.g. ``0001_c1_42.jpg`` -> ``'42'``).

    Args:
        existing_ids: mapping whose keys are the ids seen on the previous scan
            (typically the first element returned by the previous call).
        gallery_path: root directory of the gallery to scan.

    Returns:
        ``(current_ids, need_to_reload)``: ``current_ids`` maps every id
        currently on disk to ``1`` so it can be passed back in as
        ``existing_ids``; ``need_to_reload`` is ``1`` when the set of ids on
        disk differs from ``existing_ids`` (ids added OR removed), else ``0``.
    """
    current_ids = {}
    for _, _, files in os.walk(gallery_path):
        for file_name in files:
            # splitext instead of [:-4] so extensions of any length are handled.
            person_id = os.path.splitext(file_name.split('_')[-1])[0]
            current_ids[person_id] = 1
    # Compare id sets: the old "last index vs. count" test missed removals when
    # several files shared one id or when the gallery became empty.
    need_to_reload = int(current_ids.keys() != existing_ids.keys())
    return current_ids, need_to_reload


def _register_new_ids(person_gallery, region, id_and_time, id_and_color):
    """Create bookkeeping entries for gallery ids not seen before.

    Each unseen id gets a per-region time counter initialised to 0 in
    ``id_and_time`` and a colour tuple in ``id_and_color`` derived from the
    id's position in ``person_gallery``.

    Returns:
        The number of ids that were added.
    """
    total = len(person_gallery)
    added = 0
    for position, person_id in enumerate(person_gallery):
        if person_id in id_and_time:
            continue
        shade = int(255 * (position / total))
        id_and_color[person_id] = (shade, shade, 255 - shade)
        id_and_time[person_id] = {r: 0 for r in region}
        added += 1
    return added


def _save_drawn_results(cfg, frame_to_draw, id_and_time, vid_writers, frame_index):
    """Render annotated frames and persist them as video and/or JPEG files.

    Args:
        cfg: global config; reads ``REID.SAVE_VIDEO``, ``REID.SAVE_IMG``,
            ``REID.RESULT_DIR``, ``REID.SAVE_VIDEO_FPS`` and ``YOLO.FOURCC``.
        frame_to_draw: camera id -> frame record for the current iteration.
        id_and_time: customer id -> {region id: accumulated time}, forwarded
            to ``drawing``.
        vid_writers: camera id -> ``cv2.VideoWriter``; a writer is created and
            stored here the first time a camera is seen.
        frame_index: number used as the JPEG file name.
    """
    drawing_results = drawing('dict', frame_to_draw, id_and_time)
    # NOTE(review): assumes `drawing` returns frames in the same order as the
    # keys of frame_to_draw -- the pre-existing image-saving code relied on this.
    for cam, frame in zip(frame_to_draw.keys(), drawing_results):
        if cfg.REID.SAVE_VIDEO:
            writer = vid_writers.get(cam)
            if writer is None:
                height, width = frame.shape[0], frame.shape[1]
                writer = cv2.VideoWriter(os.path.join(cfg.REID.RESULT_DIR, '{:}.mp4'.format(cam)),
                                         cv2.VideoWriter_fourcc(*cfg.YOLO.FOURCC),
                                         cfg.REID.SAVE_VIDEO_FPS, (width, height))
                vid_writers[cam] = writer
            # Write this camera's annotated frame to this camera's writer.
            writer.write(frame)
        if cfg.REID.SAVE_IMG:
            cv2.imwrite(os.path.join(cfg.REID.RESULT_DIR, '{:}'.format(cam), '{:d}.jpg'.format(frame_index)),
                        frame)


def reid_main(cfg):
    """Run the re-identification loop forever.

    Each iteration:
      1. Fetch one frame per ReID camera (``fetch(cfg, 'reid')``).
      2. Send the frames to YOLO through ``cfg.YOLO.REID_RECV`` and read the
         detections back from ``cfg.YOLO.REID_PUT``.
      3. Crop every person whose foot point lies inside a zone and save the
         crop under ``cfg.REID.UNKNOWN``.
      4. Identify the crops, either with ``Model3DTA.predict()`` or, when
         ``cfg.REID.ENABLE_HIST`` is set, by histogram-matching against the
         previous iteration's crops. A full 3DTA pass is still forced at least
         every 10 iterations and whenever matching fails.
      5. Accumulate per-region dwell time in ``id_and_time``: seconds in
         online mode, otherwise frame counts.
      6. Append a snapshot to ``cfg.REID.RESULT`` and optionally save the
         annotated frames.

    The function only leaves the loop through an exception (for example
    KeyboardInterrupt); any open video writers are released when that happens.
    """
    # Initialise variables.
    urls = cfg.CAMERA.REID_URLS
    reid_cams = [uri[0] for uri in urls]
    zones = get_zone()
    region = sorted(list(set([z.region_id for z in zones if z.camera_id in reid_cams])))
    region_number = len(region)
    LOGGER.info('\n\t region: {:}\n\t#region: {:d}'.format(region, region_number))
    LOGGER.debug('REID profile:\n\tOnline mode: {:} \n\tEnable Hist strategy: {:}'.format(bool(cfg.REID.ONLINE),
                                                                                          bool(cfg.REID.ENABLE_HIST)))
    id_and_time = dict()  # customer id -> {region id: accumulated dwell time}
    id_and_color = dict()  # for drawing ids
    vid_writers = dict()  # camera id -> cv2.VideoWriter
    imgs = []
    # State for the histogram-matching shortcut.
    new_target_for_hist_comp, old_target_for_hist_comp = dict(), dict()
    new_target_for_hist_comp_index, old_target_for_hist_comp_index = dict(), dict()
    pnum_every_zone, old_pnum_every_zone = dict(), dict()
    flag = 0  # iterations since the last full 3DTA pass

    frame_count = 0
    need_to_reid = 0  # 1 -> re-run 3DTA on the current frames next iteration
    existing_ids = dict()  # for update gallery
    imgs_path, scores = [], []

    yolo_det_queue = cfg.YOLO.REID_RECV
    yolo_put_queue = cfg.YOLO.REID_PUT

    check_folder(cfg.REID.UNKNOWN, cfg.REID.GALLERY, cfg.YOLO.OUTPUT, cfg.REID.RESULT_DIR)
    for cam in reid_cams:
        check_folder(os.path.join(cfg.REID.RESULT_DIR, cam))
        clear_folder(os.path.join(cfg.REID.RESULT_DIR, cam))

    # Initialise the 3DTA model.
    with torch.no_grad():
        model_3dta = Model3DTA(cfg)

    last = time()
    if has_subfolder(cfg.REID.GALLERY):
        model_3dta.update_gallery()

    existing_ids, _ = if_need_to_update(existing_ids, cfg.REID.GALLERY)
    try:
        while True:
            lapse = time() - last
            last = time()
            clear_folder(cfg.REID.UNKNOWN)
            frame_count += 1
            frame_to_draw = dict()

            if not need_to_reid:
                imgs = fetch(cfg, 'reid')
            else:
                # Re-process the previous frames with 3DTA; don't count a new frame.
                frame_count -= 1
                flag -= 1
            for frame_record in imgs:
                frame_to_draw[frame_record[0]] = copy.deepcopy(frame_record)
            if not cfg.REID.ONLINE:
                LOGGER.info('current frame:{:}'.format(frame_count))

            if not imgs:
                LOGGER.debug('no frame fetch!')
                sleep(1 / cfg.REID.CAMERA_FPS)
                continue
            LOGGER.debug('got {:d} frames.'.format(len(imgs)))

            if not has_subfolder(cfg.REID.GALLERY):
                LOGGER.debug('gallery is empty!')
                # Back off instead of busy-spinning until a gallery appears.
                sleep(1 / cfg.REID.CAMERA_FPS)
                continue

            person_gallery = os.listdir(cfg.REID.GALLERY)
            if _register_new_ids(person_gallery, region, id_and_time, id_and_color):
                LOGGER.debug("num of reid person changed.")

            existing_ids, need_to_reload = if_need_to_update(existing_ids, cfg.REID.GALLERY)
            if need_to_reload:
                model_3dta.update_gallery()

            yolo_det_queue.put(imgs, timeout=10)
            try:
                person_results = yolo_put_queue.get(timeout=10)
            except queue.Empty:
                LOGGER.warning('no yolo results receive.')
                person_results = []

            LOGGER.debug('number of people detected by YOLO: {:d}'.format(len(person_results)))
            if len(person_results) >= 1:
                in_zone_list_and_coord = []
                for c in reid_cams:
                    new_target_for_hist_comp[c] = dict()
                    new_target_for_hist_comp_index[c] = dict()
                    pnum_every_zone[c] = dict()
                    for z in zones:
                        new_target_for_hist_comp[c][z.region_id] = []
                        new_target_for_hist_comp_index[c][z.region_id] = []
                        pnum_every_zone[c][z.region_id] = 0
                for i, p in enumerate(person_results):
                    cam_id, img_id, x1, y1, x2, y2 = p[:6]
                    frame = imgs[img_id][1]
                    person_foot = (int((x1 + x2) / 2), y2)
                    cur_camera_zones = [(z.region_id, z.vertices) for z in zones if z.camera_id == cam_id]
                    for cur_zone_id, cur_zone_vertices in cur_camera_zones:
                        if is_contain(person_foot, cur_zone_vertices, contain_boundary=False):
                            target = frame[y1:y2, x1:x2]
                            cv2.imwrite(
                                os.path.join(cfg.REID.UNKNOWN, '0001_{:s}_{:d}.jpg'.format(
                                    cam_id, int(time()) + i)), target)
                            cv2.rectangle(frame_to_draw[cam_id][1], (x1, y1), (x2, y2), (0, 255, 0), 2)
                            in_zone_list_and_coord.append((cur_zone_id, (x1, y1), cam_id))

                            # Store the crop per camera and zone, and record its index
                            # in in_zone_list_and_coord together with its top-left corner.
                            new_target_for_hist_comp[cam_id][cur_zone_id].append(target)
                            new_target_for_hist_comp_index[cam_id][cur_zone_id].append(
                                (len(in_zone_list_and_coord) - 1, (x1, y1)))
                            pnum_every_zone[cam_id][cur_zone_id] += 1
                            break  # to ensure that one person only stays in one region
                LOGGER.debug("flag:{}".format(flag))
                if in_zone_list_and_coord:
                    # Force a full 3DTA pass at least every 10 iterations.
                    flag = 0 if flag >= 10 else flag
                    if (not cfg.REID.ENABLE_HIST) or (flag == 0) or need_to_reid:
                        flag = 0
                        need_to_reid = 0
                        LOGGER.debug("calling 3dta")
                        imgs_path, scores = model_3dta.predict()
                        LOGGER.debug('imgs_path:{:},sorce:{:}'.format(imgs_path, scores))
                        region_and_customer = {region_id: set() for region_id in region}

                        for cur_zone, img_path, score in zip(in_zone_list_and_coord, imgs_path, scores):
                            if score and score >= cfg.REID.TH:
                                customer_id = img_path.split(os.sep)[-2]
                                region_and_customer[cur_zone[0]].add(customer_id)
                                LOGGER.info('cam{:s}:customer ID {:s} enter zone {:s}'.format(cur_zone[2][1:], customer_id,
                                                                                              cur_zone[0]))
                                temp = drawing(task='str', imgs_dict=[frame_to_draw[cur_zone[2]]],
                                               info="{:}".format(customer_id),
                                               text_color=id_and_color[customer_id],
                                               pt=[cur_zone[1]])
                                frame_to_draw[cur_zone[2]][1] = temp[0]
                            else:
                                LOGGER.info('unknown customer enter zone-{:s}'.format(cur_zone[0]))
                        flag += 1
                    else:
                        if (pnum_every_zone != old_pnum_every_zone) or (len(in_zone_list_and_coord) != len(imgs_path)):
                            need_to_reid = 1
                            LOGGER.info('cannot match person, call 3dta.')
                            continue
                        resort_dict = dict()  # previous global index -> current global index
                        region_and_customer = {region_id: set() for region_id in region}
                        for c in reid_cams:
                            if need_to_reid:
                                break
                            for z in region:
                                hist_result = comp(old_target_for_hist_comp[c][z],
                                                   new_target_for_hist_comp[c][z])
                                if (hist_result is None) or (None in hist_result.values()):
                                    need_to_reid = 1
                                    LOGGER.info('cannot match person, call 3dta.')
                                    break
                                for i, r in hist_result.items():
                                    # i / r index the previous / current crops of this camera
                                    # and zone; _index is the global index into imgs_path and
                                    # scores, so the score must be looked up with _index.
                                    _index = old_target_for_hist_comp_index[c][z][i][0]
                                    customer_id = imgs_path[_index].split(os.sep)[-2]
                                    score = scores[_index]
                                    if score and score >= cfg.REID.TH:
                                        region_and_customer[z].add(customer_id)
                                        temp = drawing(task='str',
                                                       imgs_dict=[frame_to_draw[c]],
                                                       info="{:}".format(customer_id),
                                                       text_color=id_and_color[customer_id],
                                                       pt=[new_target_for_hist_comp_index[c][z][r][1]])
                                        frame_to_draw[c][1] = temp[0]
                                        LOGGER.info(
                                            'cam{:} : customer ID {:s} stays in zone {:s}'.format(c[1:],
                                                                                                  customer_id,
                                                                                                  z))
                                    resort_dict[_index] = new_target_for_hist_comp_index[c][z][r][0]
                        if not need_to_reid:
                            if set(resort_dict) != set(range(len(imgs_path))):
                                # Incomplete match: fall back to 3DTA rather than crash.
                                need_to_reid = 1
                                LOGGER.info('cannot match person, call 3dta.')
                            else:
                                # Reorder the last 3DTA results into this frame's detection
                                # order; scores must move together with imgs_path.
                                # NOTE(review): assumes len(scores) == len(imgs_path).
                                reordered_paths, reordered_scores = list(imgs_path), list(scores)
                                for old_index, new_index in resort_dict.items():
                                    reordered_paths[new_index] = imgs_path[old_index]
                                    reordered_scores[new_index] = scores[old_index]
                                imgs_path, scores = reordered_paths, reordered_scores
                        flag += 1
                    if not need_to_reid:
                        for region_id, customers in region_and_customer.items():
                            for customer_id in sorted(customers):
                                store_time = id_and_time[customer_id][region_id]
                                id_and_time[customer_id][region_id] = round(store_time + lapse, 3) if cfg.REID.ONLINE \
                                    else (store_time + 1)

                old_target_for_hist_comp = copy.deepcopy(new_target_for_hist_comp)
                old_target_for_hist_comp_index = copy.deepcopy(new_target_for_hist_comp_index)
                old_pnum_every_zone = copy.deepcopy(pnum_every_zone)

                print({k: {kk: vv for kk, vv in v.items() if vv != 0} for k, v in id_and_time.items()})

                with open(cfg.REID.RESULT, 'a') as f:
                    f.write(pformat(id_and_time, indent=2))
                    f.write('\n')

            else:
                LOGGER.debug('no people detected.')

            if cfg.REID.SAVE_VIDEO or cfg.REID.SAVE_IMG:  # draw results
                _save_drawn_results(cfg, frame_to_draw, id_and_time, vid_writers, frame_count)

            LOGGER.debug('frame delay: {:}'.format(time() - last))
    finally:
        # Release writers so the video files are finalised on shutdown.
        for writer in vid_writers.values():
            writer.release()


if __name__ == '__main__':
    # Set before any worker process is started so that processes launched
    # below inherit it (presumably to silence the semaphore_tracker warning).
    os.environ.update(PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning')
    multiprocessing.set_start_method('forkserver', force=True)
    initialize(cfg)
    # Bring up the camera streams and the YOLO detector for the 'reid' task,
    # then hand control to the (never-returning) ReID loop.
    reid_tasks = ('reid',)
    start_streams(cfg, reid_tasks)
    set_yolo(cfg, reid_tasks)
    reid_main(cfg)
