Python使用余弦相似度比对特征,numpy余弦相似度比对特征

先介绍下背景:我们有个项目是基于深度神经网络的菜品识别类项目,简易流程可以理解为:

  1. 通过深度学习模型,提取菜品图像的菜品特征。
  2. 将特征存储到特征库中。
  3. 获取摄像头下方的菜品图像,提取特征。
  4. 拿到摄像头下发图像的特征,通过余弦相似度对比,得出相似度。

先来一段特征注册的代码逻辑:

import numpy as np
import pickle

def featRegister(tag):
    '''
    repo 为字典类型, 特征库
    {
        '包子': array([[2.9554849,..,0.45501372],[1.6043618,0.,...,0.]], dtype=float32), 
        '鱼头': array([[1.6043618,..., 0.]], dtype=float32), 
        '米饭': array([[0., 0., 2.9554849 , ..., 0., 0.,0.45501372]], dtype=float32)
    }
    '''
    # 1. 组建特征库
    repo = {}
    for dishFeaId, featureSrc in tag.items():
        if not featureSrc:
            continue
        feat = pickle.loads(featureSrc.encode('latin1'))
        repo[dishFeaId] = np.array([feat])
        
    # 2. 计算 active_repo_feats all_repo_feats repo_count
    repo_concat = np.concatenate([repo[key] for key in repo], 0)
    all_repo_feats = np.array([x / np.linalg.norm(x) for x in repo_concat])
    active_repo_ids = []
    for id in repo:
        active_repo_ids += [id] * len(repo[id])
    repo_count = len(repo)
    
    return all_repo_feats, active_repo_ids, repo_count

组成了一个repo的dict,并使用了pickle进行序列化,通过numpy进行拼接。

接下来,我们来模拟一个需要比对的特征值:

feats = []
for dishFeaId, featureSrc in featsN.items():
	feat_index = pickle.loads(featureSrc.encode('latin1'))
	feats.append(np.array(feat_index))

然后就是通过与余弦的相似度进行特征比对的逻辑:

import uuid
import numpy as np

def featsCompare(feats, feat_boxes, all_repo_feats, active_repo_ids, repo_count):
    if repo_count <= 0:
        return None
    
    dishId_feaIds_ids = []
    dishId_feaId_scores = []
    feats = [x / np.linalg.norm(x) for x in feats]
    cos_distance = np.matmul(feats, all_repo_feats.transpose((1, 0)))
    ids_score_sorted = np.sort(cos_distance)
    ids_score_sorted = [x[::-1] for x in ids_score_sorted]
    ids_indices_sorted = np.argsort(cos_distance)
    ids_indices_sorted = [x[::-1] for x in ids_indices_sorted]
    for i in range(len(ids_indices_sorted)):
        food_ids, food_scores = [], []
        all_repo_ids_sorted = np.array(active_repo_ids)[list(ids_indices_sorted[i])]
        for j in range(len(all_repo_ids_sorted)):
            if all_repo_ids_sorted[j] not in food_ids:
                food_ids.append(all_repo_ids_sorted[j])
                food_scores.append(ids_score_sorted[i][j])
        for x in food_ids:
            dishId_feaIds_ids.append(x)
        for x in food_scores:
            dishId_feaId_scores.append(round(float(100 * x), 2))
            
    resultCount = len(feat_boxes)
    resultList = []

    for i in range(resultCount):
        resultN = feat_boxes[i]
        resultN['dishes'] = []
        idList, scoreList = [], []
        for j in range(repo_count):
            k = j + i * repo_count
            idList.append(dishId_feaIds_ids[k])
            scoreList.append(dishId_feaId_scores[k])

        # 筛选dishId与feaId、score对应关系
        dishId_feaId_score = {}
        for m in range(len(idList)):
            dishId_feaId = idList[m]
            dishId_feaIds = dishId_feaId.split('_')
            if len(dishId_feaIds) == 1:
                dishId = dishId_feaIds[0]
                feaId = dishId
            elif len(dishId_feaIds) == 2:
                dishId = dishId_feaIds[0]
                feaId = dishId_feaIds[1]
            else:
                uuId = str(uuid.uuid1())
                dishId = uuId.replace('-', '')
                uuId = str(uuid.uuid1())
                feaId = uuId.replace('-', '')
            if dishId not in dishId_feaId_score:
                # 排名最高的dishIds,及其对应的feaId和score 的汇总
                dishId_feaId_score[dishId] = [feaId, scoreList[m]]
            if len(dishId_feaId_score) >= 10:
                break
            
        dishes = []
        for dishId in dishId_feaId_score:
            feaId = dishId_feaId_score[dishId][0]
            score = dishId_feaId_score[dishId][1]
            topN = {
                'dishId': dishId,
                'score': score,
            }
            dishes.append(topN)
        resultN['dishes'] = dishes
        resultList.append(resultN)
        
    return resultList

以上有一些是为了测试的业务逻辑,但整体逻辑和主要代码是没什么问题的。

以上,仅供参考。

展开阅读全文

页面更新:2024-06-20

标签:余弦   特征值   特征   神经网络   摄像头   深度   逻辑   图像   代码   项目

1 2 3 4 5

上滑加载更多 ↓
推荐阅读:
友情链接:
更多:

本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828  

© CopyRight 2020-2024 All Rights Reserved. Powered By 71396.com 闽ICP备11008920号-4
闽公网安备35020302034903号

Top