39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
|
from typing import List
|
||
|
from vision.model import Image
|
||
|
|
||
|
|
||
|
class PrecisionRecall:
    """Bundle of ranked-retrieval evaluation results.

    The three values are stored as-is on the instance; nothing is copied,
    validated, or derived here.
    """

    def __init__(self, precision, recall, ap):
        # ``get_pr`` passes per-rank precision and recall lists and a single
        # average-precision value; other callers may pass anything.
        self.precision, self.recall, self.ap = precision, recall, ap
|
||
|
|
||
|
|
||
|
def get_precision(images: List[Image], test):
    """Return the fraction of ``images`` whose ``category`` equals ``test``.

    Args:
        images: Retrieved images; each must expose a ``category`` attribute.
        test: The category value counted as relevant.

    Returns:
        ``matches / len(images)`` as a float.

    Raises:
        ZeroDivisionError: if ``images`` is empty.
    """
    matches = sum(1 for image in images if image.category == test)
    return matches / len(images)
|
||
|
|
||
|
|
||
|
def get_recall(images: List[Image], test, category_count):
    """Return how many of ``images`` have ``category == test``, over ``category_count``.

    Args:
        images: Retrieved images; each must expose a ``category`` attribute.
        test: The category value counted as relevant.
        category_count: Total number of relevant items (the recall
            denominator), supplied by the caller.

    Returns:
        ``found / category_count`` as a float.

    Raises:
        ZeroDivisionError: if ``category_count`` is 0.
    """
    found = 0
    for image in images:
        if image.category == test:
            found += 1
    return found / category_count
|
||
|
|
||
|
|
||
|
def get_pr(images: List[Image], query: Image):
    """Compute the precision/recall curve and average precision for one query.

    Candidates are ranked by ascending ``distance`` (closest first). A
    candidate is relevant when its ``category`` equals ``query.category``.
    For every cut-off k = 1..len(images), the precision and recall of the
    top-k ranked candidates are recorded.

    Args:
        images: Candidate images; each must expose ``category`` and
            ``distance`` attributes. The caller's list is not modified
            (a sorted copy is ranked).
        query: The query image; only its ``category`` is read.

    Returns:
        PrecisionRecall with these fields:

        - ``precision``: a list with one value per rank.
        - ``recall``: a list with one value per rank.
        - ``ap``: the average precision, i.e. the sum of the precision
          values at the ranks of relevant candidates divided by the number
          of relevant candidates.

        If no candidate is relevant (including when ``images`` is empty),
        recall and AP are undefined. They are then reported as 0.0
        instead of raising ``ZeroDivisionError``.
    """
    # NOTE(review): if ``images`` contains the query itself it will usually
    # rank first and count as relevant — confirm callers exclude it if needed.
    ranked = sorted(images, key=lambda x: x.distance)
    target = query.category

    # Evaluate relevance once per candidate instead of re-filtering every prefix.
    relevant = [bool(image.category == target) for image in ranked]
    total_relevant = sum(relevant)

    p = []
    r = []
    ap_terms = []
    hits = 0
    # Running hit count gives precision/recall at each rank in O(n) overall.
    # The floats are identical to re-counting each prefix from scratch.
    for rank, is_relevant in enumerate(relevant, start=1):
        if is_relevant:
            hits += 1
        precision = hits / rank
        p.append(precision)
        r.append(hits / total_relevant if total_relevant else 0.0)
        if is_relevant:
            # AP averages precision only at the ranks where a relevant item appears.
            ap_terms.append(precision)

    ap = sum(ap_terms) / total_relevant if total_relevant else 0.0

    return PrecisionRecall(precision=p, recall=r, ap=ap)
|