본문 바로가기

Deep Learning/TF Object Detection API

2.2. Mask R-CNN을 Retrain 시켜보자

320x100
320x100



create_tf_record.py를 통해 TFRecord 파일을 만들었으니 이제 Retrain 하는 방법에 대해 알아볼 차례다.


Retrain 방법에 대해 구글링해본 사람이라면 한 번쯤은 이런 생각을 해본 적 있지 않을까 싶다. (뭐.. 나만 삐뚫어진 걸 수도 있지만..ㅎㅎ)


'뭐야, 이게 무슨 API야? 사용성이 너무 떨어지는데?'


그래서 기존에 알려진 방식보다 더 사용성을 높여 보려고 이것 저것 추가해서 작성해봤다.


순서는 다음과 같다.

  1. [pretrained_model 폴더]에 mask_rcnn_resnet101_coco_2018_01_28 모델 다운로드
    모델은 detection_model_zoo에서 다운받으면 된다.

  2. [model_configs 폴더] 생성 후
    ./object_detection/samples/configs 경로에 있는
    mask_rcnn_resnet101_atrous_coco.config 파일 복사해 [model_configs 폴더]로 붙여넣기

  3. [retrained_model 폴더], [exported_model 폴더] 생성

  4. mask_rcnn_resnet101_atrous_coco.config 파일 수정

  5. Utils.py 작성

  6. main.py 작성 후 Retrain 시작

한 개의 script 파일로 Retrain과 Export를 수행하는 코드를 main.py에 구현했고,

main.py 실행에 필요한 함수들을 Utils.py 에 따로 정의해놨다. 전체 코드는 Github에서 다운받을 수 있다.


이제 3번까지는 진행됐다 가정하고 4번 config 파일 수정부터 자세히 알아보도록 하자.





1. mask_rcnn_resnet101_atrous_coco.config 파일 수정


필수적으로 변경해야 하는 부분은 다음과 같다.

  • num_classes

  • max_detections_per_class

  • max_total_detections

  • fine_tune_checkpoint

  • num_steps -> main.py 에서 계속 수정되도록 구현됨

  • train_input_reader 의 input_path 및 label_map_path

  • eval_input_reader 의 input_path 및 label_map_path

이들을 Custom Dataset 환경에 맞게 바꿔주면 된다.

여기에 OOM과 같은 에러를 잡고 더 나은 결과를 얻기 위해 다음을 변경 및 추가했다.
  • image_resizer 변경

  • mask_prediction_num_conv_layers 변경

  • second_stage_batch_size 추가

수정을 완료한 config 파일은 여기에서 확인할 수 있다.


2. Utils.py 작성


필요한 기능은 다음과 같다.
  1. Config 파일의 num_steps 을 동적으로 변경
    : main.py는 Retrain 중간 중간에 모델 저장을 하도록 구현되어 있는데 이를 위해선 num_steps 수정이 필요하다.

    1
    2
    3
    4
    5
    def modifyConfig(pipeline_config_path, step):
        for line in fileinput.input(pipeline_config_path, inplace = True):
            if 'num_steps: ' in line:
                line = line.replace(line, '  num_steps: {}\n' .format(str(step)))
            sys.stdout.write(line)
    cs


    fileinput.input( ) 과 sys.stdout.write( ) 을 모르겠다면

    [Python] - python의 fileinput module로 파일 수정하는 방법을 참고하도록 하자.


  2. ./object_detection/legacy/train.py 읽기 및 수정
    : ./object_detection/legacy 밑에 있는 train.py는 원래대로라면 cmd 창을 통해 매개변수를 넘겨주며 실행시키줘야 한다.
     이 번거로운 과정을 없애보고자 getTrainScript( ) 함수를 설계했다.

     getTrainScript( ) 함수는
     우선 train.py 파일을 읽은 뒤 train_dirpipeline_config_path에 입력되어야되는 정보를 직접 변경하도록 설계되었다. 
    1
    2
    3
    4
    5
    6
    7
    8
    def getTrainScript(train_dir, pipeline_config_path):
        script = open('./object_detection/legacy/train.py').read()
        script = script.replace("flags.DEFINE_string('train_dir', ''",
                                "flags.DEFINE_string('train_dir', '{}'" .format(train_dir))
        script = script.replace("flags.DEFINE_string('pipeline_config_path', ''",
                                "flags.DEFINE_string('pipeline_config_path', '{}'".format(pipeline_config_path))
     
        return script Colored by Color Scripter

  3. ./object_detection/export_inference_graph.py 읽기 및 수정
    : export_inference_graph.py 역시 원래는 train.py처럼 cmd 창을 통해 매개변수를 넘겨주며 실행시키줘야 한다.
     이 번거로운 과정을 없애보고자 getExportScript( ) 함수를 설계했다.

     getExportScript( ) 함수는
     export_inference_graph.py를 읽은 뒤
     pipeline_config_path, trained_checkpoint_prefix, output_directory에 입력되어야되는 정보를 직접 변경하도록 설계되었다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    def getExportScript(pipeline_config_path, step, base_save_path):
        trained_checkpoint_prefix = './retrained_model/model.ckpt-' + str(step)
        output_directory = '{}/Step_{}' .format(base_save_path, str(step))
        if not os.path.isdir(output_directory):
            os.mkdir(output_directory)
        
        script = open('./object_detection/export_inference_graph.py').read()
        
        script = script.replace("flags.DEFINE_string('pipeline_config_path', None",
                                "flags.DEFINE_string('pipeline_config_path', '{}'".format(pipeline_config_path))
        script = script.replace("flags.DEFINE_string('trained_checkpoint_prefix', None",
                                "flags.DEFINE_string('trained_checkpoint_prefix', '{}'".format(trained_checkpoint_prefix))
        script = script.replace("flags.DEFINE_string('output_directory', None",
                                "flags.DEFINE_string('output_directory', '{}'" .format(output_directory))
        
        return script
    cs

  4. Tensorflow 환경 reset
    : tf.app.flags.FLAGS와 default graph를 중간 중간에 초기화해줘야 중복 선언에 의한 에러없이 원활하게 진행될 수 있다.
    1
    2
    3
    4
    5
    6
    7
    8
    def resetEnv():    
        # retrain 과정에서 초기화 해줘야됨
        keys = list(tf.app.flags.FLAGS._flags().keys())
        for key in keys:
            tf.app.flags.FLAGS.__delattr__(key)
     
        # export 과정에서 초기화 해줘야됨
        tf.reset_default_graph()
    cs

Utils.py에 설계된 함수들은 여기에서 확인할 수 있다.


3. main.py 작성


main.py 에서 중요한 포인트는 sys.path.append( )exec( )subprocess.run( ) 이다.
  1. 내 블로그 TF Object Detection API 시리즈를 처음부터 따라왔다 가정하면
    TF_Object_Detection_API 폴더의 경로는 'D:/TF_Object_Detection' 가 될 것이다.
    이 경로를 main.py를 실행하는 동안 임시적으로 시스템 변수에 추가해줘야 경로 에러 없이 원활하게 실행할 수 있다.

    1
    2
    sys.path.append('D:/TF_Object_Detection_API')
    sys.path.append('D:/TF_Object_Detection_API/slim')
    cs


  2. getTrainScript( ) 또는 getExportScript( )를 통해 읽은 script data는 파이썬 내장함수인 exec( )를 통해 실행시켜줄 수 있다.
    사용방법은 다음과 같다.

    1
    2
    train_py_script = getTrainScript( )
    exec(train_py_script)
    cs

  3. 그런데 단순히 exec(train_py_script) 만을 해주면 main.py를 실행시키던 프로세스가 종료되어 버린다는 문제가 발생한다.
    exec(train_py_script) 이후의 작업들도 수행되어야 하는데
    exec(train_py_script) 에서 프로세스가 종료되어 다음 작업을 이어나갈 수 없게 된다.
    이 문제를 해결하고자 subprocess.run( ) 함수를 사용했다.

    1
    2
    train_py_script = getTrainScript( )
    subprocess.run([exec(train_py_script])
    cs


    정확하게는 모르겠지만
    subprocess.run( )을 통해 exec( )를 실행시켜주게 되면 자식 프로세스가 생성돼 exec( )을 실행시키는 것 같다.
    그렇게 되면 exec( ) 실행이 끝난 뒤에 자식 프로세스가 종료될 뿐
    main.py을 실행하는 부모 프로세스에는 영향이 없어 이어지는 직업들을 계속해서 수행할 수 있게 되는 것 같다.


아직 해결하지 못한 문제점이 있는데

subprocess.run( )을 해주면 이상하게 에러가 발생한다는 점이다.

근데 이 에러가 내가 원하는 API 단순화 작업에 영향을 주는 심각한 에러는 아닌듯 해서 일단은 try & except을 사용해 잡아줬다.

추후에 발생 원인 및 해결 방법을 알게된다면 업데이트하도록 하겠다.


설계된 main.py 는 여기에서 확인할 수 있다.


4. main.py 실행


모든 준비가 끝났으니 이제 main.py 를 실행해 Retrain을 시작하면 된다!


다음과 같이 실행해주면 retrain 및 export model 이 이뤄지는 것을 확인할 수 있다.