본문 바로가기

Deep Learning/TF Object Detection API

1.5. draw_bounding_boxes 함수 설계 및 inference 결과 확인

320x100
320x100



pretrained model을 실행시켜서 결과 받아보는 과정까지 알아보았다.


이제는 detecting 결과를 이미지에 표현해주기만 하면 된다.



튜토리얼 코드를 봐보면 


./object_detection/utils/visualization_utils.py 를 사용해서 그려주는 걸 볼 수 있는데


그냥 경험삼아 직접 한 번 구현해보고 싶어서 draw_bounding_boxes 함수를 만들었다. 






1. class_info.txt 만들기


코드 설명에 앞서 class_info.txt를 만들어줄 필요가 있다.


draw_bounding_boxes 함수는 class에 대한 정보를 text가 아닌 int로 반환해준다.



따라서 각 숫자가 어떤 class를 뜻하지는 지에 대한 정보가 필요하다.


visualization_utils.py 에서는 


./object_detection/data/mscoco_label_map.pbtxt 를 기반으로 class info를 읽어들여 draw해주고 있어


나 또한 mscoco_label_map.pbtxt 를 참고해 class_info.txt를 만들어주었다.





2. draw bounding boxes


코드를 살펴보자.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def draw_bounding_boxes(img, output_dict, class_info):
height, width, _ = img.shape
 
    obj_index = output_dict['detection_scores'> 0.5
    
    scores = output_dict['detection_scores'][obj_index]
    boxes = output_dict['detection_boxes'][obj_index]
    classes = output_dict['detection_classes'][obj_index]
 
    for box, cls, score in zip(boxes, classes, scores):
        # draw bounding box
        img = cv2.rectangle(img,
                            (int(box[1* width), int(box[0* height)),
                            (int(box[3* width), int(box[2* height)), class_info[cls][1], 8)
 
        # put class name & percentage
        object_info = class_info[cls][0+ ': ' + str(int(score * 100)) + '%'
        text_size, _ = cv2.getTextSize(text = object_info,
                                       fontFace = cv2.FONT_HERSHEY_SIMPLEX,
                                       fontScale = 0.9, thickness = 2)
        img = cv2.rectangle(img,
                            (int(box[1* width), int(box[0* height) - 25),
                            (int(box[1* width) + text_size[0], int(box[0* height)),
                            class_info[cls][1], -1)
        img = cv2.putText(img,
                          object_info,
                          (int(box[1* width), int(box[0* height)),
                          cv2.FONT_HERSHEY_SIMPLEX, 0.9, (000), 2)
        
    return img
 
class_info = {}
= open('class_info.txt''r')
for line in f:
    info = line.split(', ')
 
    class_index = int(info[0])
    class_name = info[1]
    color = (int(info[2][1:]), int(info[3]), int(info[4].strip()[:-1]))    
    
    class_info[class_index] = [class_name, color]
cs


  • Line 3
    : detection_score가 0.5 이상인 것만 그리기 위해 threshold를 걸어주었다.

  • Line 11 ~ 13
    : bounding box 정보를 원본이미지에 대한 비율값으로 주기 때문에
     결과값을 height, width에 곱해줘서 bounding box를 그려주면 된다.

     또한 class_info를 정수가 입력되면 관련 정보가 반환되도록 딕셔너리 형태로 만들었기 때문에
     class_info[cls]와 같이 사용해주면 된다.

  • Line 17 ~ 28
    : class name을 같이 표현해주기 위해 추가해줬다.




다음은 코드 전체와 detecting 결과 이미지이다.









pretrained object detection model을 사용해서 원하는 이미지를 detecting하는 방법에 대해 알아보았다.


다음 글에서는 retrain하는 방법에 대해 알아보도록 하겠다.



  • ㅇㅇ 2020.02.12 17:08

    Traceback (most recent call last):
    File "<pyshell#5>", line 4, in <module>
    class_index = int(info[0])
    ValueError: invalid literal for int() with base 10: '\n'

    이런 오류가 나면 어떡하죵?ㅠ

    • ballentain 2020.02.12 17:32 신고

      ValueError: invalid literal for int() with base 10: '\n'

      이 에러 메세지는
      '\n'를 int형으로 변환하려고해서 뜨는 에러에요.

      결국엔 info[0]에 '\n'에 할당됐다 소리인데, 제가 작성한 코드로는 info[0]에 '\n'이 할당될 일이 없어요.

      info = line.split(', ')
      부분 코드에 문제 없는 지 한 번 확인해보시고, 그래도 해결이 안 되면 작성하신 코드 남겨주세요!

    • fakecan 2020.06.23 19:15 신고

      지금은 해결하셨는지 모르겠지만 위에 그대로 복사하셨으면 각 번호 사이에 한 줄씩 띄워진 부분을 모두 지우고 하시면 될 것 같습니다. 코드 상에서는 , 과 띄어쓰기로 스플릿하는데 \n 처리는 없기에 나타난 것 같네요.

    • stopwater 2020.07.22 11:10

      위에 class_info.txt 만드실때 아마 엔터키 들어가서 줄띄워져있을거에요.

  • stopwater 2020.07.22 11:12

    글보고 너무 많은 도움이 되었습니다. 친절한 설명 다시 한번 감사드립니다!!
    소스 맨밑 부분에
    cv2.waitKey()
    cv2.destroyAllWindows()
    추가 해주시면 좋을거같습니다.