본문 바로가기

Deep Learning/TF Object Detection API

1.2. Pretrained model 다운로드 및 압축 풀기

320x100
320x100




object_detection_tutorial.ipynb을 살펴보면 전체적인 흐름이 다음과 같다.


  1. 사용하고자 하는 pretrained model의 tar.gz 파일 다운로드

  2. tar.gz 파일 압축 풀기

  3. pretrained model의 계산 그래프 로드

  4. 모델 실행 및 결과 출력

이번 글에서는 모델을 다운받고 압축을 푸는 방법까지 알아보겠다.



1. pretrained model 다운로드


Tensorflow에서 지원하는 object detection model 종류와 다운로드 주소는 detection_model_zoo에서 확인할 수 있다.

detection_model_zoo에서 바로 모델 다운받아서 알집으로 압축 풀어주면 되기는 한다. 가장 쉽고 간단한 방법이다.

근데 이를 코드로 작성해서 사용하는 방법에 대해 알아보도록 하자.

나는 faster_rcnn_resnet50 을 사용해보도록 하겠다.

먼저 코드를 봐보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
BASE_URL = 'http://download.tensorflow.org/models/object_detection/' 
MODEL_NAME = 'faster_rcnn_resnet50_coco_2018_01_28'    
MODEL_FILE = MODEL_NAME + '.tar.gz'                       
                                                           
DOWNLOAD_URL = BASE_URL + MODEL_FILE
DOWNLOAD_PATH = './pretrained_model'
 
import os
if not os.path.isdir(DOWNLOAD_PATH):
    os.mkdir(DOWNLOAD_PATH)
    
import six.moves.urllib as urllib
 
opener = urllib.request.URLopener()
opener.addheader('User-Agent''ballentain')
opener.retrieve(DOWNLOAD_URL, filename = DOWNLOAD_PATH + '/' + MODEL_FILE)
print('모델 다운로드 완료...')                                                          Colored by Color Scripter
cs

  • Line 1 ~ 3
    : 다운받고자하는 모델의 이름을 설정한다. 실제 다운 경로가 coco_2018_01_28까지 잡혀져있기 때문에 모델명 뒤에 붙여줘야한다.

  • Line 5
    : 최종 URL 주소에 해당된다.  
      여기서는 http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz 이다.

  • Line 6 ~ 10
    : tar.gz 파일이 저장될 경로를 설정한 뒤, 지정 경로에 폴더가 없으면 폴더를 새로 만들어주는 과정이다.

  • Line 14 ~ 16
    : 설정한 URL 로부터 사용하고자 하는 모델의 tar.gz 파일을 다운받는 과정이다.

      Line 15의 경우 
      
      urllib.error.HTTPError: HTTP Error 403: Forbidden

      
    에러가 뜨는 현상을 해결하기 위해 추가했다.



다운이 완료되면 ./TF_Object_Detection_API/pretrained_model/faster_rcnn_resnet50_coco_2018_01_28 에 


faster_rcnn_resnet50_coco_2018_01_28.tar.gz 가 다운받아진 것을 확인할 수 있다.





2. tar.gz 파일 압축 풀기


다운받은 tar.gz 파일 안에는 다음과 같은 파일들이 존재한다.


frozen_inference_graph.pb는 pretrained model의 계산 그래프와 가중치를 메모리에 올려서 여러 operator들을 실행하는데에 필요하고,

( = object detection model 실행을 위해 필요하다 )


model.ckpt는 retrain을 위해 필요하다.



압축 푸는 코드를 봐보자.


21
22
23
24
25
import tarfile
tar_file = tarfile.open(DOWNLOAD_PATH + '/' + MODEL_FILE)
tar_file.extrackall(DOWNLOAD_PATH)
      
print('압축 해제 완료...')                   Colored by Color Scripter
cs


  • Line 1
    : tar.gz 파일은 tarfile 모듈을 통해 압축을 풀어줄 수 있다.

  • Line 2 ~ 3
    : tar.gz 파일 압축을 풀어주는 과정이다.
      압축 해제 대상 파일의 경로를 설정해 tarfile.open( )함수로 열어준 후,
      extrackall( )함수로 압축을 풀어준다.









다음은 pretrained model을 다운받고 압축 해제까지는 하는 코드 전체이다.


[ pretrained_model_download.py ]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
BASE_URL = 'http://download.tensorflow.org/models/object_detection/' 
MODEL_NAME = 'faster_rcnn_resnet50_coco_2018_01_28'                 
MODEL_FILE = MODEL_NAME + '.tar.gz'                         
                                                           
 
DOWNLOAD_URL = BASE_URL + MODEL_FILE
DOWNLOAD_PATH = './pretrained_model'
 
import os
if not os.path.isdir(DOWNLOAD_PATH):
    os.mkdir(DOWNLOAD_PATH)
    
import six.moves.urllib as urllib
 
opener = urllib.request.URLopener()
opener.addheader('User-Agent''ballentain')
opener.retrieve(DOWNLOAD_URL, filename = DOWNLOAD_PATH + '/' + MODEL_FILE)
print('모델 다운로드 완료...')
 
 
import tarfile
tar_file = tarfile.open(DOWNLOAD_PATH + '/' + MODEL_FILE)
tar_file.extractall(DOWNLOAD_PATH)
 
print('압축 해제 완료...')Colored by Color Scripter
cs



[Deep Learning/TF Object Detection API] - 1. Tensorflow Object Detection API 다운로드 및 환경설정부터 따라해왔다면,


현재 [ TF_Object_Detection_API 폴더 ]는 다음과 같이 구성되어 있을 것이다.









지금까지 pretrained model 다운로드 및 압축 푸는 방법에 대해 알아보았다.


다음 글에서는 frozen_inference_graph.pb 의 계산 그래프와 가중치값들을 메모리에 올리는 방법에 대해 알아보도록 하겠다.