Use slowapi as the limiter

This commit is contained in:
Rui Alves 2024-09-21 12:00:10 +01:00
parent fb804bac4e
commit 6dadeeb488
4 changed files with 27 additions and 21 deletions

View File

@ -2,6 +2,7 @@ click == 8.1.*
# FastAPI # FastAPI
starlette-context == 0.3.6 starlette-context == 0.3.6
fastapi == 0.115.0 fastapi == 0.115.0
slowapi == 0.1.9
imutils == 0.5.* imutils == 0.5.*
joserfc == 1.0.* joserfc == 1.0.*
markupsafe == 2.1.* markupsafe == 2.1.*

View File

@ -16,6 +16,7 @@ from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from joserfc import jwt from joserfc import jwt
from peewee import DoesNotExist from peewee import DoesNotExist
from slowapi import Limiter
from frigate.api.defs.app_body import ( from frigate.api.defs.app_body import (
AppPostLoginBody, AppPostLoginBody,
@ -74,17 +75,6 @@ def get_remote_addr(request: Request):
return request.remote_addr or "127.0.0.1" return request.remote_addr or "127.0.0.1"
# TODO: Rui
# limiter = Limiter(
# get_remote_addr,
# storage_uri="memory://",
# )
def get_rate_limit(request: Request):
return request.app.frigate_config.auth.failed_login_rate_limit
def get_jwt_secret() -> str: def get_jwt_secret() -> str:
jwt_secret = None jwt_secret = None
# check env var # check env var
@ -306,9 +296,17 @@ def logout(request: Request):
return response return response
limiter = Limiter(key_func=get_remote_addr)
def get_rate_limit(request: Request):
return request.app.frigate_config.auth.failed_login_rate_limit
@router.post("/login") @router.post("/login")
# TODO: Rui Implement limiter for FastAPI # Ideally, this would be a decorator @limiter.limit(limit_value=get_rate_limit) but that way the request object is not passed to the method
# @limiter.limit(get_rate_limit, deduct_when=lambda response: response.status_code == 400) # See: https://github.com/laurentS/slowapi/issues/41
# @limiter.limit(limit_value=get_rate_limit)
def login(request: Request, body: AppPostLoginBody): def login(request: Request, body: AppPostLoginBody):
JWT_COOKIE_NAME = request.app.frigate_config.auth.cookie_name JWT_COOKIE_NAME = request.app.frigate_config.auth.cookie_name
JWT_COOKIE_SECURE = request.app.frigate_config.auth.cookie_secure JWT_COOKIE_SECURE = request.app.frigate_config.auth.cookie_secure
@ -325,7 +323,7 @@ def login(request: Request, body: AppPostLoginBody):
if verify_password(password, password_hash): if verify_password(password, password_hash):
expiration = int(time.time()) + JWT_SESSION_LENGTH expiration = int(time.time()) + JWT_SESSION_LENGTH
encoded_jwt = create_encoded_jwt(user, expiration, request.app.jwt_token) encoded_jwt = create_encoded_jwt(user, expiration, request.app.jwt_token)
response = Response({}, 200) response = Response("", 200)
set_jwt_cookie( set_jwt_cookie(
response, JWT_COOKIE_NAME, encoded_jwt, expiration, JWT_COOKIE_SECURE response, JWT_COOKIE_NAME, encoded_jwt, expiration, JWT_COOKIE_SECURE
) )

View File

@ -4,12 +4,15 @@ from typing import Optional
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from playhouse.sqliteq import SqliteQueueDatabase from playhouse.sqliteq import SqliteQueueDatabase
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from starlette_context import middleware, plugins from starlette_context import middleware, plugins
from starlette_context.plugins import Plugin from starlette_context.plugins import Plugin
from frigate.api import app as main_app from frigate.api import app as main_app
from frigate.api import auth, event, export, media, notification, preview, review from frigate.api import auth, event, export, media, notification, preview, review
from frigate.api.auth import get_jwt_secret from frigate.api.auth import get_jwt_secret, limiter
from frigate.embeddings import EmbeddingsContext from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor from frigate.events.external import ExternalEventProcessor
from frigate.plus import PlusApi from frigate.plus import PlusApi
@ -53,6 +56,7 @@ def create_fastapi_app(
) )
# update the request_address with the x-forwarded-for header from nginx # update the request_address with the x-forwarded-for header from nginx
# https://starlette-context.readthedocs.io/en/latest/plugins.html#forwarded-for
app.add_middleware( app.add_middleware(
middleware.ContextMiddleware, middleware.ContextMiddleware,
plugins=(plugins.ForwardedForPlugin(),), plugins=(plugins.ForwardedForPlugin(),),
@ -75,11 +79,9 @@ def create_fastapi_app(
database.close() database.close()
return response return response
# TODO: Rui app.state.limiter = limiter
# initialize the rate limiter for the login endpoint app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# limiter.init_app(app) app.add_middleware(SlowAPIMiddleware)
# if frigate_config.auth.failed_login_rate_limit is None:
# limiter.enabled = False
# Routes # Routes
app.include_router(main_app.router) app.include_router(main_app.router)

View File

@ -18,8 +18,8 @@ from fastapi import APIRouter, Path, Query, Request, Response
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from peewee import DoesNotExist, fn from peewee import DoesNotExist, fn
from tzlocal import get_localzone_name from tzlocal import get_localzone_name
from werkzeug.utils import secure_filename
from frigate.api.defs.media_query_parameters import MediaLatestFrameQueryParams
from frigate.api.defs.tags import Tags from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import ( from frigate.const import (
@ -39,6 +39,11 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=[Tags.media]) router = APIRouter(tags=[Tags.media])
# TODO: Rui Implement or get from existing 3rd party
def secure_filename(file_name: str):
return file_name
@router.get("/media/camera/{camera_name}") @router.get("/media/camera/{camera_name}")
def mjpeg_feed( def mjpeg_feed(
request: Request, request: Request,