320x100
320x100
object_detection_tutorial.ipynb을 살펴보면 전체적인 흐름이 다음과 같다.
- 사용하고자 하는 pretrained model의 tar.gz 파일 다운로드
- tar.gz 파일 압축 풀기
- pretrained model의 계산 그래프 로드
- 모델 실행 및 결과 출력
이번 글에서는 모델을 다운받고 압축을 푸는 방법까지 알아보겠다.
1. pretrained model 다운로드
Tensorflow에서 지원하는 object detection model 종류와 다운로드 주소는 detection_model_zoo에서 확인할 수 있다.
detection_model_zoo에서 바로 모델 다운받아서 알집으로 압축 풀어주면 되기는 한다. 가장 쉽고 간단한 방법이다.
근데 이를 코드로 작성해서 사용하는 방법에 대해 알아보도록 하자.
나는 faster_rcnn_resnet50 을 사용해보도록 하겠다.
먼저 코드를 봐보자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | BASE_URL = 'http://download.tensorflow.org/models/object_detection/' MODEL_NAME = 'faster_rcnn_resnet50_coco_2018_01_28' MODEL_FILE = MODEL_NAME + '.tar.gz' DOWNLOAD_URL = BASE_URL + MODEL_FILE DOWNLOAD_PATH = './pretrained_model' import os if not os.path.isdir(DOWNLOAD_PATH): os.mkdir(DOWNLOAD_PATH) import six.moves.urllib as urllib opener = urllib.request.URLopener() opener.addheader('User-Agent', 'ballentain') opener.retrieve(DOWNLOAD_URL, filename = DOWNLOAD_PATH + '/' + MODEL_FILE) | cs |
- Line 1 ~ 3
: 다운받고자하는 모델의 이름을 설정한다. 실제 다운 경로가 coco_2018_01_28까지 잡혀져있기 때문에 모델명 뒤에 붙여줘야한다. - Line 5
: 최종 URL 주소에 해당된다.
여기서는 http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz 이다. - Line 6 ~ 10
: tar.gz 파일이 저장될 경로를 설정한 뒤, 지정 경로에 폴더가 없으면 폴더를 새로 만들어주는 과정이다. - Line 14 ~ 16
: 설정한 URL 로부터 사용하고자 하는 모델의 tar.gz 파일을 다운받는 과정이다.
Line 15의 경우
urllib.error.HTTPError: HTTP Error 403: Forbidden
에러가 뜨는 현상을 해결하기 위해 추가했다.
다운이 완료되면 ./TF_Object_Detection_API/pretrained_model/faster_rcnn_resnet50_coco_2018_01_28 에
faster_rcnn_resnet50_coco_2018_01_28.tar.gz 가 다운받아진 것을 확인할 수 있다.
2. tar.gz 파일 압축 풀기
다운받은 tar.gz 파일 안에는 다음과 같은 파일들이 존재한다.
frozen_inference_graph.pb는 pretrained model의 계산 그래프와 가중치를 메모리에 올려서 여러 operator들을 실행하는데에 필요하고,
( = object detection model 실행을 위해 필요하다 )
model.ckpt는 retrain을 위해 필요하다.
압축 푸는 코드를 봐보자.
21 22 23 24 25 | import tarfile tar_file = tarfile.open(DOWNLOAD_PATH + '/' + MODEL_FILE) tar_file.extrackall(DOWNLOAD_PATH) | cs |
- Line 1
: tar.gz 파일은 tarfile 모듈을 통해 압축을 풀어줄 수 있다. - Line 2 ~ 3
: tar.gz 파일 압축을 풀어주는 과정이다.
압축 해제 대상 파일의 경로를 설정해 tarfile.open( )함수로 열어준 후,
extrackall( )함수로 압축을 풀어준다.
다음은 pretrained model을 다운받고 압축 해제까지는 하는 코드 전체이다.
[ pretrained_model_download.py ]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | BASE_URL = 'http://download.tensorflow.org/models/object_detection/' MODEL_NAME = 'faster_rcnn_resnet50_coco_2018_01_28' MODEL_FILE = MODEL_NAME + '.tar.gz' DOWNLOAD_URL = BASE_URL + MODEL_FILE DOWNLOAD_PATH = './pretrained_model' import os if not os.path.isdir(DOWNLOAD_PATH): os.mkdir(DOWNLOAD_PATH) import six.moves.urllib as urllib opener = urllib.request.URLopener() opener.addheader('User-Agent', 'ballentain') opener.retrieve(DOWNLOAD_URL, filename = DOWNLOAD_PATH + '/' + MODEL_FILE) print('모델 다운로드 완료...') import tarfile tar_file = tarfile.open(DOWNLOAD_PATH + '/' + MODEL_FILE) tar_file.extractall(DOWNLOAD_PATH) | cs |
[Deep Learning/TF Object Detection API] - 1. Tensorflow Object Detection API 다운로드 및 환경설정부터 따라해왔다면,
현재 [ TF_Object_Detection_API 폴더 ]는 다음과 같이 구성되어 있을 것이다.
지금까지 pretrained model 다운로드 및 압축 푸는 방법에 대해 알아보았다.
다음 글에서는 frozen_inference_graph.pb 의 계산 그래프와 가중치값들을 메모리에 올리는 방법에 대해 알아보도록 하겠다.
'Deep Learning > TF Object Detection API' 카테고리의 다른 글
2.1. Custom Dataset으로 TFRecord 파일 만들기 (0) | 2020.03.19 |
---|---|
2. Tensorflow Object Detection Model Retrain 방법에 대해 알아보자 (0) | 2020.03.13 |
1.5. draw_bounding_boxes 함수 설계 및 inference 결과 확인 (5) | 2019.07.30 |
1.4. run_inference_for_single_image 함수 설계 (0) | 2019.07.05 |
1.3. Pretrained model의 계산 그래프 로드 (0) | 2019.07.03 |