import cv2
import numpy as np
import onnxruntime as ort
import time
from froxiaofeidriver.myvideo import Video

# Traffic-sign detection: recognises go-ahead, go-back, turn-left and turn-right signs.
# Detection algorithm: yolo-lite-master
class TrafficSign(Video):
    
    # Model loading. These statements run once, when the class body is executed
    # (i.e. at import time), and the resulting session is shared by all instances.
    model_pb_path = "/home/pi/AiCar/mylib/trafficsign.onnx"
    so = ort.SessionOptions()
    net = ort.InferenceSession(model_pb_path, so)
    
    width = 320
    height = 240
    dim = (width,height) # Size frames are shrunk to before inference, so large frames do not make the display stutter
    
    # Label dictionary: class index -> sign name returned to callers.
    dic_labels= {0:'turn_left',
                1:'turn_right',
                2:'go_ahead',
                3:'go_back'}
    # State used by traffic_5sign() to confirm a sign over several frames.
    count = 0
    last_sign = 0
    # Model parameters.
    model_h = 320  # network input height in pixels
    model_w = 320  # network input width in pixels
    nl = 3  # number of output levels iterated by cal_outputs()
    na = 3  # number of anchors per grid cell on each level
    stride=[8.,16.,32.]  # per-level divisor: grid size = model size / stride
    anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
    # Anchors regrouped as (level, anchor, (w, h)) for use in cal_outputs().
    anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(nl, -1, 2)
    # NOTE(review): `cap` is not defined in this class although traffic_sign() reads
    # self.cap; presumably the Video base class provides it -- confirm.
#     cap = cv2.VideoCapture(0)
        
    def __init__(self):
        pass
    
    def plot_one_box(self, x, img, color=None, label=None, line_thickness=None):
        """Draw one bounding box, and optionally a text label, onto ``img`` in place.

        Args:
            x: Box corners ``(x1, y1, x2, y2)``; each value is converted with ``int``.
            img: Image array that is drawn on directly (nothing is returned).
            color: Colour passed to OpenCV. A random colour is chosen when omitted
                or empty.
            label: Optional text, drawn on a filled background whose bottom-left
                corner sits at the box's top-left corner.
            line_thickness: Line thickness in pixels. When omitted it is derived
                from the image size.
        """
        # Imported locally because the module never imports ``random`` at file
        # level; without it the default-colour path raised NameError.
        import random

        # Line / font thickness scales with the mean of image height and width.
        tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
        color = color or [random.randint(0, 255) for _ in range(3)]

        top_left = (int(x[0]), int(x[1]))
        bottom_right = (int(x[2]), int(x[3]))
        cv2.rectangle(img, top_left, bottom_right, color, thickness=tl, lineType=cv2.LINE_AA)
        if not label:
            return

        tf = max(tl - 1, 1)  # font thickness
        font_scale = tl / 3
        text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=tf)[0]
        # Filled background sized to the text, placed just above the box corner.
        background_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
        cv2.rectangle(img, top_left, background_corner, color, -1, cv2.LINE_AA)
        cv2.putText(
            img,
            label,
            (top_left[0], top_left[1] - 2),
            0,
            font_scale,
            [225, 255, 255],
            thickness=tf,
            lineType=cv2.LINE_AA,
        )
       
    def _make_grid(self,nx, ny):
        xv, yv = np.meshgrid(np.arange(ny), np.arange(nx))
        return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32)
    
    def cal_outputs(self,outs,nl,na,model_w,model_h,anchor_grid,stride):
    
        row_ind = 0
        grid = [np.zeros(1)] * nl
        for i in range(nl):
            h, w = int(model_w/ stride[i]), int(model_h / stride[i])
            length = int(na * h * w)
            if grid[i].shape[2:4] != (h, w):
                grid[i] = self._make_grid(w, h)

            outs[row_ind:row_ind + length, 0:2] = (outs[row_ind:row_ind + length, 0:2] * 2. - 0.5 + np.tile(
                grid[i], (na, 1))) * int(stride[i])
            outs[row_ind:row_ind + length, 2:4] = (outs[row_ind:row_ind + length, 2:4] * 2) ** 2 * np.repeat(
                anchor_grid[i], h * w, axis=0)
            row_ind += length
        return outs
    def post_process_opencv(self, outputs, model_h, model_w, img_h, img_w, thred_nms, thred_cond):
        """Rescale decoded predictions to the image and filter them with NMS.

        Args:
            outputs: 2-D array with one prediction per row. Columns 0-3 are
                centre x, centre y, width and height in model-input pixels,
                column 4 is the confidence, columns 5+ are per-class scores.
            model_h: Network input height in pixels.
            model_w: Network input width in pixels.
            img_h: Height of the image the boxes should be mapped onto.
            img_w: Width of the image the boxes should be mapped onto.
            thred_nms: NMS overlap threshold passed to ``cv2.dnn.NMSBoxes``.
            thred_cond: Minimum confidence passed to ``cv2.dnn.NMSBoxes``.

        Returns:
            ``(boxes, confs, class_ids)`` for the kept predictions, where each
            box is ``(x1, y1, x2, y2)`` in image pixels, or three empty lists
            when nothing survives.
        """
        conf = outputs[:, 4].tolist()
        c_x = outputs[:, 0] / model_w * img_w
        c_y = outputs[:, 1] / model_h * img_h
        w = outputs[:, 2] / model_w * img_w
        # Height lives in column 3; reading column 2 here reused the width.
        h = outputs[:, 3] / model_h * img_h
        cls_id = np.argmax(outputs[:, 5:], axis=1)

        x1 = c_x - w / 2
        y1 = c_y - h / 2
        corners = np.stack((x1, y1, c_x + w / 2, c_y + h / 2), axis=-1).astype(np.float64)

        # cv2.dnn.NMSBoxes interprets each box as (x, y, width, height), so it
        # must not be given corner coordinates.
        rects = np.stack((x1, y1, w, h), axis=-1).tolist()
        keep = cv2.dnn.NMSBoxes(rects, conf, thred_cond, thred_nms)
        if len(keep) == 0:
            return [], [], []

        # Some OpenCV versions return the indices as an (N, 1) array; flatten so
        # the results are always indexed per detection.
        keep = np.asarray(keep).reshape(-1)
        return corners[keep], np.asarray(conf)[keep], cls_id[keep]
    def infer_img(self, img0, net, model_h, model_w, nl, na, stride, anchor_grid, thred_nms=0.4, thred_cond=0.5):
        """Run the detector on one BGR image and return the filtered detections.

        Args:
            img0: Source image (3-D array, BGR channel order).
            net: Inference session exposing ``get_inputs()`` and ``run()``.
            model_h: Network input height in pixels.
            model_w: Network input width in pixels.
            nl: Number of output levels (forwarded to ``cal_outputs``).
            na: Anchors per grid cell (forwarded to ``cal_outputs``).
            stride: Per-level strides (forwarded to ``cal_outputs``).
            anchor_grid: Per-level anchors (forwarded to ``cal_outputs``).
            thred_nms: NMS overlap threshold.
            thred_cond: Minimum confidence.

        Returns:
            ``(boxes, confs, ids)`` as produced by ``post_process_opencv``,
            with boxes expressed in ``img0`` pixel coordinates.
        """
        # Pre-processing: resize to the network input, BGR -> RGB, scale to [0, 1],
        # then reorder HWC -> CHW and add a leading batch axis.
        resized = cv2.resize(img0, [model_w, model_h], interpolation=cv2.INTER_AREA)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        scaled = rgb.astype(np.float32) / 255.0
        blob = scaled.transpose(2, 0, 1)[np.newaxis, ...]

        # Inference: take the first output and drop its batch axis.
        input_name = net.get_inputs()[0].name
        raw = net.run(None, {input_name: blob})[0]
        preds = raw.squeeze(axis=0)

        # Decode grid-relative predictions into model-input coordinates.
        decoded = self.cal_outputs(preds, nl, na, model_w, model_h, anchor_grid, stride)

        # Map boxes onto the source image and apply NMS.
        img_h, img_w, _ = np.shape(img0)
        det_boxes, det_confs, det_ids = self.post_process_opencv(
            decoded, model_h, model_w, img_h, img_w, thred_nms, thred_cond
        )
        return det_boxes, det_confs, det_ids
    # Single-frame recognition: fast, but when the video has motion trails the
    # same sign may be reported by several consecutive calls.
    def traffic_sign(self):
        """Grab one frame from ``self.cap`` and detect a traffic sign in it.

        Returns:
            ``(label, score, frame)`` for the first detection, where ``frame`` is
            the resized frame with that detection drawn on it;
            ``(0, 0, frame)`` when nothing is detected;
            ``(0, 0, None)`` when no frame could be read, in which case the
            capture device is released.
        """
        success, img = self.cap.read()
        if not success or img is None:
            # Check the read result before touching the frame: resizing a missing
            # frame raises inside OpenCV. Always return a 3-tuple so callers that
            # unpack the result keep working.
            self.cap.release()
            return 0, 0, None

        # Shrink the frame before inference.
        img0 = cv2.resize(img, self.dim, interpolation=cv2.INTER_AREA)
        det_boxes, scores, ids = self.infer_img(
            img0, self.net, self.model_h, self.model_w, self.nl, self.na,
            self.stride, self.anchor_grid, thred_nms=0.4, thred_cond=0.5,
        )

        # Only the first detection is drawn and reported.
        first = next(zip(det_boxes, scores, ids), None)
        if first is None:
            return 0, 0, img0

        box, score, cls_id = first
        sign = self.dic_labels[cls_id]
        label = '%s:%.2f' % (sign, score)
        self.plot_one_box(box.astype(np.int16), img0, color=(255, 0, 0), label=label, line_thickness=None)
        return sign, score, img0
    # Multi-frame recognition: only report a sign after it has been seen on several
    # frames, so motion trails in the video do not produce repeated results.
    def traffic_5sign(self):
        """Return a sign once it has matched the previously seen sign five times.

        Frames are pulled from ``traffic_sign()`` until either a sign is
        confirmed or a frame yields no detection. ``self.last_sign`` and
        ``self.count`` carry the running streak between calls.

        Returns:
            ``(label, score, frame)`` for a confirmed sign, or ``(0, 0, frame)``
            as soon as a frame contains no detection.
        """
        while True:
            trafficsign, score, img = self.traffic_sign()
            if trafficsign == 0:
                return 0, 0, img

            if trafficsign == self.last_sign:
                self.count += 1
                if self.count == 5:
                    self.count = 0
                    return trafficsign, score, img
            else:
                # A different sign starts a new streak; without this reset the
                # new sign would inherit the matches counted for the old one.
                self.count = 0
                self.last_sign = trafficsign
        
    
    
    # Video display helper. Only use it in interactive Python experiments, not in
    # the button-launched app.py: with a video window the app fails to start.
    def cv2_show(self, name, img):
        """Show ``img`` in the OpenCV window ``name`` and wait 1 ms for GUI events."""
        cv2.imshow(name, img)
        cv2.waitKey(1)
                


