Deep Learning/TF Object Detection API
1.5. draw_bounding_boxes 함수 설계 및 inference 결과 확인
ballentain
2019. 7. 30. 21:49
pretrained model을 실행시켜서 결과 받아보는 과정까지 알아보았다.
이제는 detecting 결과를 이미지에 표현해주기만 하면 된다.
튜토리얼 코드를 봐보면
./object_detection/utils/visualization_utils.py 를 사용해서 그려주는 걸 볼 수 있는데
그냥 경험삼아 직접 한 번 구현해보고 싶어서 draw_bounding_boxes 함수를 만들었다.
1. class_info.txt 만들기
코드 설명에 앞서 class_info.txt를 만들어줄 필요가 있다.
draw_bounding_boxes 함수는 class에 대한 정보를 text가 아닌 int로 반환해준다.
따라서 각 숫자가 어떤 class를 뜻하지는 지에 대한 정보가 필요하다.
visualization_utils.py 에서는
./object_detection/data/mscoco_label_map.pbtxt 를 기반으로 class info를 읽어들여 draw해주고 있어
나 또한 mscoco_label_map.pbtxt 를 참고해 class_info.txt를 만들어주었다.
2. draw bounding boxes
코드를 살펴보자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | def draw_bounding_boxes(img, output_dict, class_info): height, width, _ = img.shape obj_index = output_dict['detection_scores'] > 0.5 scores = output_dict['detection_scores'][obj_index] boxes = output_dict['detection_boxes'][obj_index] classes = output_dict['detection_classes'][obj_index] for box, cls, score in zip(boxes, classes, scores): # draw bounding box img = cv2.rectangle(img, (int(box[1] * width), int(box[0] * height)), (int(box[3] * width), int(box[2] * height)), class_info[cls][1], 8) # put class name & percentage object_info = class_info[cls][0] + ': ' + str(int(score * 100)) + '%' text_size, _ = cv2.getTextSize(text = object_info, fontFace = cv2.FONT_HERSHEY_SIMPLEX, fontScale = 0.9, thickness = 2) img = cv2.rectangle(img, (int(box[1] * width), int(box[0] * height) - 25), (int(box[1] * width) + text_size[0], int(box[0] * height)), class_info[cls][1], -1) img = cv2.putText(img, object_info, (int(box[1] * width), int(box[0] * height)), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 0), 2) return img class_info = {} f = open('class_info.txt', 'r') for line in f: info = line.split(', ') class_index = int(info[0]) class_name = info[1] color = (int(info[2][1:]), int(info[3]), int(info[4].strip()[:-1])) class_info[class_index] = [class_name, color] f.close() Colored by Color Scripter | cs |
- Line 3
: detection_score가 0.5 이상인 것만 그리기 위해 threshold를 걸어주었다. - Line 11 ~ 13
: bounding box 정보를 원본이미지에 대한 비율값으로 주기 때문에
결과값을 height, width에 곱해줘서 bounding box를 그려주면 된다.
또한 class_info를 정수가 입력되면 관련 정보가 반환되도록 딕셔너리 형태로 만들었기 때문에
class_info[cls]와 같이 사용해주면 된다. - Line 17 ~ 28
: class name을 같이 표현해주기 위해 추가해줬다.
다음은 코드 전체와 detecting 결과 이미지이다.
pretrained object detection model을 사용해서 원하는 이미지를 detecting하는 방법에 대해 알아보았다.
다음 글에서는 retrain하는 방법에 대해 알아보도록 하겠다.