先介绍一下背景:我们有一个基于深度神经网络的菜品识别项目,简化后的流程可以理解为两步:先把已知菜品的特征注册到特征库,再把待识别的特征与特征库进行比对。
首先是特征注册的代码逻辑:
import numpy as np
import pickle
def featRegister(tag):
    """Build the L2-normalized feature library from serialized dish features.

    Args:
        tag: Mapping of repository id (e.g. ``"dishId"`` or ``"dishId_feaId"``) to a
            pickled feature stored as a latin1-decoded ``str``. Entries whose value
            is empty/falsy are skipped.

    Returns:
        A tuple ``(all_repo_feats, active_repo_ids, repo_count)``:

        * ``all_repo_feats`` -- the features stacked along axis 0, each entry divided
          by its L2 norm (all-zero features are left as zeros instead of NaN).
        * ``active_repo_ids`` -- the repository id for each entry of
          ``all_repo_feats``, in the same order.
        * ``repo_count`` -- the number of distinct repository ids registered.

        When no usable feature is present, returns ``(empty array, [], 0)``.

    Internally each id maps to ``np.array([feat])``, i.e. one leading row per
    registered feature.
    """
    # NOTE(review): pickle.loads can execute arbitrary code for crafted payloads;
    # only feed this function feature blobs produced by a trusted source.
    repo = {}
    for dishFeaId, featureSrc in tag.items():
        if not featureSrc:
            continue
        feat = pickle.loads(featureSrc.encode('latin1'))
        repo[dishFeaId] = np.array([feat])

    repo_count = len(repo)
    if repo_count == 0:
        # np.concatenate raises on an empty list; return an empty library so callers
        # (featsCompare checks repo_count <= 0) can handle it gracefully.
        return np.empty((0, 0), dtype=np.float32), [], 0

    repo_concat = np.concatenate(list(repo.values()), 0)
    # Norm of each entry along axis 0 (flattened, matching x / np.linalg.norm(x)).
    norms = np.linalg.norm(repo_concat.reshape(len(repo_concat), -1), axis=1)
    # Avoid 0/0 = NaN: NaN would sort to the top once the ranking is reversed.
    norms[norms == 0] = 1
    all_repo_feats = repo_concat / norms.reshape((-1,) + (1,) * (repo_concat.ndim - 1))

    # One id per row contributed by that id, in dict (insertion) order.
    active_repo_ids = [
        dish_id for dish_id, dish_feats in repo.items() for _ in range(len(dish_feats))
    ]
    return all_repo_feats, active_repo_ids, repo_count
以上代码把每个菜品的特征(用 pickle 从 latin1 字符串反序列化得到)组装成名为 repo 的字典,再用 numpy 拼接,并对每条特征做 L2 归一化。
接下来,模拟一组需要比对的特征(featsN 与 tag 的结构相同:键为特征 ID,值为 pickle 序列化后以 latin1 解码的字符串):
# Deserialize every query feature blob into a NumPy array, keeping featsN's order.
# NOTE(review): pickle.loads on untrusted data is unsafe; featsN must come from a trusted source.
feats = [np.array(pickle.loads(blob.encode('latin1'))) for blob in featsN.values()]
然后是基于余弦相似度进行特征比对的逻辑(特征经过 L2 归一化后,两者的内积就是余弦相似度):
import uuid
import numpy as np
def featsCompare(feats, feat_boxes, all_repo_feats, active_repo_ids, repo_count, *, top_k=10):
    """Rank repository dishes for each query feature by cosine similarity.

    Args:
        feats: Query feature vectors, one per box; each is assumed to be a 1-D array
            whose length matches a row of ``all_repo_feats``.
        feat_boxes: Per-query dicts; ``feat_boxes[i]`` pairs with ``feats[i]`` and is
            updated in place with a ``'dishes'`` key.
        all_repo_feats: L2-normalized repository features, one row per entry of
            ``active_repo_ids`` (as returned by ``featRegister``).
        active_repo_ids: Repository id of each row. Ids are split on ``'_'``:
            ``"dishId"`` or ``"dishId_feaId"``; any other form gets a random id.
        repo_count: Maximum number of distinct repository ids considered per query
            (``featRegister`` returns the number of distinct ids).
        top_k: Maximum number of distinct dishes reported per box (default 10).

    Returns:
        ``None`` if ``repo_count <= 0``. Otherwise a list holding the ``feat_boxes``
        dicts themselves, each with ``'dishes'`` set to a list of
        ``{'dishId': ..., 'score': ...}`` in descending similarity, where ``score`` is
        ``round(100 * cosine_similarity, 2)``.

    Raises:
        IndexError: if ``feat_boxes`` has more entries than ``feats``.
    """
    if repo_count <= 0:
        return None

    repo_ids = np.asarray(active_repo_ids)
    rankings = []  # rankings[i]: distinct (repo_id, score) pairs for query i, best first
    if len(feats) > 0:
        query = np.asarray(feats)
        norms = np.linalg.norm(query, axis=-1, keepdims=True)
        # Keep all-zero queries at zero (score 0) instead of NaN, which the
        # descending order below would otherwise put first.
        norms[norms == 0] = 1
        similarity = np.matmul(query / norms, all_repo_feats.transpose((1, 0)))

        # Ascending argsort reversed per row = descending order; one sort serves
        # both the ids and the scores.
        order = np.argsort(similarity)[:, ::-1]
        sorted_scores = np.take_along_axis(similarity, order, axis=1)

        for id_row, score_row in zip(repo_ids[order], sorted_scores):
            seen = set()
            ranked = []
            for repo_id, score in zip(id_row, score_row):
                # The same id may own several rows; keep only its best score.
                if repo_id not in seen:
                    seen.add(repo_id)
                    ranked.append((repo_id, score))
            rankings.append(ranked)

    resultList = []
    for i, resultN in enumerate(feat_boxes):
        # Highest-ranked dishIds with their feaId and rounded score.
        dishId_feaId_score = {}
        for dishId_feaId, score in rankings[i][:repo_count]:
            dishId_feaIds = dishId_feaId.split('_')
            if len(dishId_feaIds) == 1:
                dishId = feaId = dishId_feaIds[0]
            elif len(dishId_feaIds) == 2:
                dishId, feaId = dishId_feaIds
            else:
                # NOTE(review): ids with more than one '_' get fresh random uuid1-based
                # ids, so they never merge with other entries -- confirm this is intended.
                dishId = str(uuid.uuid1()).replace('-', '')
                feaId = str(uuid.uuid1()).replace('-', '')
            if dishId not in dishId_feaId_score:
                dishId_feaId_score[dishId] = [feaId, round(float(100 * score), 2)]
            if len(dishId_feaId_score) >= top_k:
                break
        # feaId is collected but not part of the emitted result.
        resultN['dishes'] = [
            {'dishId': dishId, 'score': fea_score[1]}
            for dishId, fea_score in dishId_feaId_score.items()
        ]
        resultList.append(resultN)
    return resultList
以上代码中有一部分业务逻辑是为了测试而写的,但整体流程和主要代码没有什么问题。
以上,仅供参考。
页面更新:2024-06-20
本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828
© CopyRight 2020-2024 All Rights Reserved. Powered By 71396.com 闽ICP备11008920号-4
闽公网安备35020302034903号