mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
add scale_factor and bias to description zscore normalization
This commit is contained in:
parent
5cda95f5bf
commit
e3a81db0bb
@ -73,7 +73,7 @@ class EmbeddingsContext:
|
|||||||
def __init__(self, db: SqliteVecQueueDatabase):
|
def __init__(self, db: SqliteVecQueueDatabase):
|
||||||
self.embeddings = Embeddings(db)
|
self.embeddings = Embeddings(db)
|
||||||
self.thumb_stats = ZScoreNormalization()
|
self.thumb_stats = ZScoreNormalization()
|
||||||
self.desc_stats = ZScoreNormalization()
|
self.desc_stats = ZScoreNormalization(scale_factor=2.5, bias=0.5)
|
||||||
|
|
||||||
# load stats from disk
|
# load stats from disk
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class MiniLMEmbedding:
|
|||||||
elif os.path.basename(path) == self.TOKENIZER_FILE:
|
elif os.path.basename(path) == self.TOKENIZER_FILE:
|
||||||
logger.info("Downloading MiniLM tokenizer")
|
logger.info("Downloading MiniLM tokenizer")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.MODEL_NAME, clean_up_tokenization_spaces=False
|
self.MODEL_NAME, clean_up_tokenization_spaces=True
|
||||||
)
|
)
|
||||||
tokenizer.save_pretrained(path)
|
tokenizer.save_pretrained(path)
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ class MiniLMEmbedding:
|
|||||||
def _load_tokenizer(self):
|
def _load_tokenizer(self):
|
||||||
tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
|
tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
|
||||||
return AutoTokenizer.from_pretrained(
|
return AutoTokenizer.from_pretrained(
|
||||||
tokenizer_path, clean_up_tokenization_spaces=False
|
tokenizer_path, clean_up_tokenization_spaces=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_model(self, path: str, providers: List[str]):
|
def _load_model(self, path: str, providers: List[str]):
|
||||||
|
|||||||
@ -4,12 +4,15 @@ import math
|
|||||||
|
|
||||||
|
|
||||||
class ZScoreNormalization:
|
class ZScoreNormalization:
|
||||||
"""Running Z-score normalization for search distance."""
|
def __init__(self, scale_factor: float = 1.0, bias: float = 0.0):
|
||||||
|
"""Initialize with optional scaling and bias adjustments."""
|
||||||
def __init__(self):
|
"""scale_factor adjusts the magnitude of each score"""
|
||||||
|
"""bias will artificially shift the entire distribution upwards"""
|
||||||
self.n = 0
|
self.n = 0
|
||||||
self.mean = 0
|
self.mean = 0
|
||||||
self.m2 = 0
|
self.m2 = 0
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variance(self):
|
def variance(self):
|
||||||
@ -23,7 +26,10 @@ class ZScoreNormalization:
|
|||||||
self._update(distances)
|
self._update(distances)
|
||||||
if self.stddev == 0:
|
if self.stddev == 0:
|
||||||
return distances
|
return distances
|
||||||
return [(x - self.mean) / self.stddev for x in distances]
|
return [
|
||||||
|
(x - self.mean) / self.stddev * self.scale_factor + self.bias
|
||||||
|
for x in distances
|
||||||
|
]
|
||||||
|
|
||||||
def _update(self, distances: list[float]):
|
def _update(self, distances: list[float]):
|
||||||
for x in distances:
|
for x in distances:
|
||||||
|
|||||||
@ -189,19 +189,9 @@ export default function SearchView({
|
|||||||
|
|
||||||
// confidence score - probably needs tweaking
|
// confidence score - probably needs tweaking
|
||||||
|
|
||||||
const zScoreToConfidence = (score: number, source: string) => {
|
const zScoreToConfidence = (score: number) => {
|
||||||
let midpoint, scale;
|
|
||||||
|
|
||||||
if (source === "thumbnail") {
|
|
||||||
midpoint = 2;
|
|
||||||
scale = 0.5;
|
|
||||||
} else {
|
|
||||||
midpoint = 0.5;
|
|
||||||
scale = 1.5;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inverted sigmoid, 1 / (1 + e^x): lower scores yield higher confidence
|
// Inverted sigmoid, 1 / (1 + e^x): lower scores yield higher confidence
|
||||||
const confidence = 1 / (1 + Math.exp((score - midpoint) * scale));
|
const confidence = 1 / (1 + Math.exp(score));
|
||||||
|
|
||||||
return Math.round(confidence * 100);
|
return Math.round(confidence * 100);
|
||||||
};
|
};
|
||||||
@ -412,21 +402,13 @@ export default function SearchView({
|
|||||||
) : (
|
) : (
|
||||||
<LuText className="mr-1 size-3" />
|
<LuText className="mr-1 size-3" />
|
||||||
)}
|
)}
|
||||||
{zScoreToConfidence(
|
{zScoreToConfidence(value.search_distance)}%
|
||||||
value.search_distance,
|
|
||||||
value.search_source,
|
|
||||||
)}
|
|
||||||
%
|
|
||||||
</Chip>
|
</Chip>
|
||||||
</TooltipTrigger>
|
</TooltipTrigger>
|
||||||
<TooltipPortal>
|
<TooltipPortal>
|
||||||
<TooltipContent>
|
<TooltipContent>
|
||||||
Matched {value.search_source} at{" "}
|
Matched {value.search_source} at{" "}
|
||||||
{zScoreToConfidence(
|
{zScoreToConfidence(value.search_distance)}%
|
||||||
value.search_distance,
|
|
||||||
value.search_source,
|
|
||||||
)}
|
|
||||||
%
|
|
||||||
</TooltipContent>
|
</TooltipContent>
|
||||||
</TooltipPortal>
|
</TooltipPortal>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user