mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-16 16:15:22 +03:00
ensure tokenizer and feature extractor are correctly loaded
This commit is contained in:
parent
896dad5cc9
commit
dd6b0c9bcd
@ -119,21 +119,27 @@ class GenericONNXEmbedding:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _load_tokenizer(self):
|
def _load_tokenizer(self):
|
||||||
tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
|
if not self.tokenizer:
|
||||||
return AutoTokenizer.from_pretrained(
|
tokenizer_path = os.path.join(
|
||||||
self.model_name,
|
f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer"
|
||||||
cache_dir=tokenizer_path,
|
)
|
||||||
trust_remote_code=True,
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
clean_up_tokenization_spaces=True,
|
self.model_name,
|
||||||
)
|
cache_dir=tokenizer_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
clean_up_tokenization_spaces=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_feature_extractor(self):
|
def _load_feature_extractor(self):
|
||||||
feature_extractor_path = os.path.join(
|
if not self.feature_extractor:
|
||||||
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
|
feature_extractor_path = os.path.join(
|
||||||
)
|
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
|
||||||
return AutoFeatureExtractor.from_pretrained(
|
)
|
||||||
self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path
|
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
)
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
cache_dir=feature_extractor_path,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_model(self, path: str, providers: List[str]):
|
def _load_model(self, path: str, providers: List[str]):
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user