dak ブログ

python、rubyなどのプログラミング、MySQL、サーバーの設定などの備忘録。レゴの写真も。

GoogLeNet で出力層の手前の層を特徴ベクトルとして取得

2022-02-20 23:29:55 | 画像処理
GoogLeNet で出力層の手前の層を特徴ベクトルとして取得する方法のメモ。
register_forward_hook() で出力層の手前の層の出力を取得します。
# -*- coding:utf-8 -*-

import sys
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from googlenet_pytorch import GoogLeNet

feature_vector = None

def get_feature_vector(preproc, model, img_file):
    input_image = Image.open(img_file)
    input_tensor = preproc(input_image)
    input_batch = input_tensor.unsqueeze(0)
    
    logits = model(input_batch)
    preds = torch.topk(logits, k=5).indices.squeeze(0).tolist()
    return feature_vector

def forward_hook(module, inputs, outputs):
    global feature_vector
    feature_vector = outputs.detach().clone()[0].tolist()
    
def init():
    preproc = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    labels_map = json.load(open("labels_map.txt"))
    labels_map = [labels_map[str(i)] for i in range(1000)]

    model = GoogLeNet.from_pretrained("googlenet")
    model.eval()

    layers = list(model.children())
    handle = layers[-2].register_forward_hook(forward_hook)
    return preproc, model

def main():
    preproc, model = init()
    img_file = sys.argv[1]
    fv = get_feature_vector(preproc, model, img_file)
    print(fv)

    return 0

if __name__ == '__main__':
    res = main()
    exit(res)


この記事についてブログを書く
« CentOS8 で yum install での... | トップ | PIL でバイトデータから画像... »

画像処理」カテゴリの最新記事