Functions
Paths
List all folders under a directory and its subdirectories
import os
videos_path = r'D:\Videos'  # raw string so the backslash is not treated as an escape
dirs = []
def joindir(path: str):
    # Collect every subdirectory of `path`, then recurse into it
    for p in [os.path.join(path, d) for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]:
        dirs.append(p)
        joindir(p)
joindir(videos_path)
print(dirs)
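The same listing can be done without explicit recursion; a minimal sketch using the standard library's os.walk (the path is the same placeholder as above):
import os
dirs = []
for root, subdirs, _files in os.walk(r'D:\Videos'):
    dirs.extend(os.path.join(root, d) for d in subdirs)
print(dirs)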
List all files under a directory and its subdirectories
import os
videos_path = r'D:\Videos'  # raw string so the backslash is not treated as an escape
def file(filepath):
    print(filepath)
def joindir(path: str):
    for p in os.listdir(path):
        joinpath = os.path.join(path, p)
        if os.path.isdir(joinpath):
            joindir(joinpath)
        elif os.path.isfile(joinpath):
            file(joinpath)
        else:
            # Neither a regular file nor a directory (e.g. a broken symlink)
            raise FileNotFoundError(joinpath)
joindir(videos_path)
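For comparison, pathlib can walk the tree in a couple of lines; a minimal sketch under the same placeholder path:
from pathlib import Path
for p in Path(r'D:\Videos').rglob('*'):
    if p.is_file():
        print(p)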
Classes
Singleton pattern
class Singleton:
    def __new__(cls):
        # Cache the first instance on the class and return it for every later call
        if not hasattr(cls, "_singleton"):
            cls._singleton = super().__new__(cls)
        return cls._singleton
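A quick check that repeated instantiation returns the same object:
a = Singleton()
b = Singleton()
assert a is b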
Snowflake ID generator
import asyncio
import threading
from datetime import UTC, datetime
from logging import getLogger
from time import sleep
logger = getLogger()
class SnowflakeID:
    @staticmethod
    def __current_timestamp():
        """Current timestamp in milliseconds."""
        return int(datetime.now(UTC).timestamp() * 1000)

    def __init__(
        self,
        worker_id: int,
        worker_id_bits: int = 10,
        sequence_bits: int = 12,
        start_timestamp: int = int(
            datetime(2024, 1, 1, tzinfo=UTC).timestamp() * 1e3
        ),
    ) -> None:
        message = f"Created Snowflake ID generator with worker ID {worker_id}."
        logger.info(message)
        self._worker_id_bits = worker_id_bits
        self._max_worker_id = 2**worker_id_bits - 1
        if self.max_worker_id < worker_id:
            msg = "Exceeds maximum value"
            raise ValueError(msg)
        if worker_id < 1:
            msg = "Below minimum value"
            raise ValueError(msg)
        self._worker_id = worker_id - 1
        # Sequence counter for IDs generated within the same millisecond
        self._sequence_bits = sequence_bits
        # Maximum number of IDs that can be generated within one millisecond
        self._max_millisecond_count = 2**sequence_bits - 1
        self._start_timestamp = start_timestamp
        self._last_timestamp = -1  # timestamp of the last generated ID
        self._times = 0
        self._lock = threading.Lock()
        self._alock = asyncio.Lock()
    @property
    def worker_id_bits(self) -> int:
        return self._worker_id_bits

    @property
    def max_worker_id(self) -> int:
        return self._max_worker_id + 1

    @property
    def sequence_bits(self) -> int:
        return self._sequence_bits

    @property
    def max_millisecond_count(self) -> int:
        return self._max_millisecond_count + 1

    @property
    def start_timestamp(self) -> int:
        return self._start_timestamp

    def __get(self, passed_timestamp: int):
        # Pack elapsed timestamp | worker ID | sequence into a single integer
        result = passed_timestamp << (self._worker_id_bits + self._sequence_bits)
        result |= self._worker_id << self._sequence_bits
        result |= self._times
        return result
    def gen(self) -> int:
        with self._lock:
            assert self._worker_id <= self.max_worker_id
            timestamp = self.__current_timestamp()
            # Clock moved backwards: wait until it catches up (sleep takes seconds, timestamps are ms)
            if timestamp < self._last_timestamp:
                sleep((self._last_timestamp - timestamp) / 1000)
                timestamp = self.__current_timestamp()
                self._times = 0
            # Same millisecond as the previous ID
            if timestamp == self._last_timestamp:
                self._times = (self._times + 1) & self._max_millisecond_count
                # Sequence overflow: block until the next millisecond
                if self._times == 0:
                    sleep(0.001)
                    timestamp = self.__current_timestamp()
            # Different millisecond
            else:
                self._times = 0
            self._last_timestamp = timestamp
            return self.__get(timestamp - self._start_timestamp)
    async def agen(self):
        async with self._alock:
            worker_id = self._worker_id
            assert worker_id <= self.max_worker_id
            timestamp = self.__current_timestamp()
            # Clock moved backwards: wait until it catches up (sleep takes seconds, timestamps are ms)
            if timestamp < self._last_timestamp:
                await asyncio.sleep((self._last_timestamp - timestamp) / 1000)
                timestamp = self.__current_timestamp()
                self._times = 0
            # Same millisecond as the previous ID
            if timestamp == self._last_timestamp:
                self._times = (self._times + 1) & self._max_millisecond_count
                # Sequence overflow: block until the next millisecond
                if self._times == 0:
                    await asyncio.sleep(0.001)
                    timestamp = self.__current_timestamp()
            # Different millisecond
            else:
                self._times = 0
            self._last_timestamp = timestamp
            return self.__get(timestamp - self._start_timestamp)
    def decode(self, id: int):  # noqa: A002
        # Unpack sequence, worker ID and elapsed milliseconds from the packed ID
        sequence_mask = (1 << self._sequence_bits) - 1
        worker_id_mask = (1 << self._worker_id_bits) - 1
        sequence = id & sequence_mask
        worker_id = (id >> self._sequence_bits) & worker_id_mask
        passed_timestamp = id >> (self._worker_id_bits + self._sequence_bits)
        timestamp = passed_timestamp + self._start_timestamp
        return {
            "timestamp": timestamp,
            "datetime": datetime.fromtimestamp(timestamp / 1000, tz=UTC),
            "worker_id": worker_id + 1,
            "sequence": sequence,
            "id": id,
        }
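A short usage sketch (the worker ID 1 is arbitrary): generate an ID, then decode it back into its parts.
generator = SnowflakeID(worker_id=1)
new_id = generator.gen()
print(new_id)
print(generator.decode(new_id))  # timestamp, datetime, worker_id, sequence, id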
Modules
Intercepting Uvicorn and FastAPI logs with loguru
import logging
import sys
from typing import TYPE_CHECKING

from loguru import logger

if TYPE_CHECKING:
    from loguru import Record

logger.remove()

# ref: https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
class InterceptHandler(logging.Handler):
    @logger.catch(default=True, onerror=lambda _: sys.exit(1))
    def emit(self, record: logging.LogRecord) -> None:
        level: str | int
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
        # Find caller from where originated the logged message.
        frame, depth = logging.currentframe(), 0
        while frame and (
            depth == 0 or frame.f_code.co_filename == logging.__file__
        ):
            frame = frame.f_back
            depth += 1
        logger.opt(
            depth=depth,
            exception=record.exc_info,
        ).log(level, record.getMessage())

# Replace the root logger's handlers at every log level
logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
# https://github.com/encode/uvicorn/issues/562
# These loggers do not log through the root logger, so they have to be intercepted manually
for log_name in ["uvicorn", "uvicorn.access", "fastapi"]:
    _logger = logging.getLogger(log_name)
    _logger.handlers = [InterceptHandler()]

FORMAT = (
    "<level>{level: <8}</level> | "
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<cyan>{name}</cyan>:"
    "<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
    "<level>{message}</level>"
)
# settings.LOG_LEVEL is expected to come from the application's own config module
logger.add(sys.stdout, level=settings.LOG_LEVEL, format=FORMAT)
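Once this runs at startup, anything emitted through the standard logging module is rendered by loguru; a quick sanity check:
import logging
logging.getLogger("uvicorn").warning("hello from stdlib logging")  # formatted by loguru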
FastAPI
Dependency: get the client IP of the current request
from typing import Annotated

from fastapi import Depends, Request

IP_HEADERS = [
    "X-Forwarded-For",  # the most common header for passing the client IP
    "X-Real-IP",  # used by some proxies, e.g. Nginx
    "Proxy-Client-IP",  # used by some proxies or applications
    "WL-Proxy-Client-IP",  # used by WebLogic proxies
    "HTTP_CLIENT_IP",  # custom header in some environments
    "HTTP_X_FORWARDED_FOR",  # custom header in some environments
    "CF-Connecting-IP",  # used by Cloudflare
    "True-Client-IP",  # used by Akamai
    "X-Cluster-Client-IP",  # used by some load balancers
    "Fastly-Client-IP",  # used by the Fastly CDN
    "Forwarded",  # standardized header defined in RFC 7239
]

def request_ip(request: Request):
    headers = request.headers
    for i in IP_HEADERS:
        result = headers.get(i)
        if result:
            return result
    if request.client:
        return request.client.host
    return None

RequestIP = Annotated[str | None, Depends(request_ip)]
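A hypothetical route showing how the dependency is consumed (the app and path are placeholders):
from fastapi import FastAPI
app = FastAPI()

@app.get("/whoami")
async def whoami(ip: RequestIP):
    return {"client_ip": ip}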
Configuring factory_boy for async SQLAlchemy
Tested with aiosqlite and asyncpg; requires factory-boy>=3.3.0 and sqlalchemy>=2.
from typing import Generic, TypeVar

from factory import Factory
from factory.alchemy import (
    SESSION_PERSISTENCE_COMMIT,
    SESSION_PERSISTENCE_FLUSH,
    VALID_SESSION_PERSISTENCE_TYPES,
)
from factory.base import FactoryOptions, OptionDefault
from factory.errors import FactoryError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

_MT = TypeVar('_MT')

class SQLAlchemyAsyncOptions(FactoryOptions):
    def _check_sqlalchemy_async_session_persistence(self, meta, value):
        if value not in VALID_SESSION_PERSISTENCE_TYPES:
            msg = (
                f'{meta}.sqlalchemy_async_session_persistence must be one of '
                f'{VALID_SESSION_PERSISTENCE_TYPES}, got {value}'
            )
            raise TypeError(msg)

    @staticmethod
    def _check_has_sqlalchemy_async_session_set(meta, value):
        if value and hasattr(meta, 'sqlalchemy_async_session'):
            raise RuntimeError(
                'Provide either a sqlalchemy_async_session '
                'or a sqlalchemy_async_session_factory, not both'
            )

    def _build_default_options(self):
        return super()._build_default_options() + [
            OptionDefault('sqlalchemy_get_or_create', (), inherit=True),
            OptionDefault('sqlalchemy_async_session', None, inherit=True),
            OptionDefault(
                'sqlalchemy_async_session_factory',
                None,
                inherit=True,
                checker=self._check_has_sqlalchemy_async_session_set,
            ),
            OptionDefault(
                'sqlalchemy_async_session_persistence',
                None,
                inherit=True,
                checker=self._check_sqlalchemy_async_session_persistence,
            ),
        ]
class SQLAlchemyAsyncModelFactory(Generic[_MT], Factory):
    """Factory for SQLAlchemy async models."""

    _options_class = SQLAlchemyAsyncOptions
    _original_params = None

    class Meta:  # type: ignore
        abstract = True

    @classmethod
    def _generate(cls: type['SQLAlchemyAsyncModelFactory'], strategy, params):
        # Original params are used in _get_or_create if it cannot build an
        # object initially due to an IntegrityError being raised
        cls._original_params = params
        return super()._generate(strategy, params)

    @classmethod
    async def _save(
        cls: type['SQLAlchemyAsyncModelFactory'],
        model_class: type[_MT],
        session: AsyncSession,
        args,
        kwargs,
    ):
        session_persistence = cls._meta.sqlalchemy_async_session_persistence  # type: ignore
        obj = model_class(*args, **kwargs)
        session.add(obj)
        if session_persistence == SESSION_PERSISTENCE_FLUSH:
            await session.flush()
        elif session_persistence == SESSION_PERSISTENCE_COMMIT:
            await session.commit()
        return obj

    @classmethod
    async def _get_or_create(
        cls: type['SQLAlchemyAsyncModelFactory'],
        model_class: type[_MT],
        session: AsyncSession,
        args,
        kwargs,
    ) -> _MT:
        key_fields = {}
        for f in cls._meta.sqlalchemy_get_or_create:  # type: ignore
            if f not in kwargs:
                raise FactoryError(
                    f'sqlalchemy_get_or_create - Unable to find '
                    f"initialization value for '{f}' in factory {cls.__name__}"
                )
            key_fields[f] = kwargs.pop(f)
        stmt = select(model_class).filter_by(*args, **key_fields)
        obj = await session.scalar(stmt)
        if not obj:
            try:
                obj = await cls._save(model_class, session, args, {**key_fields, **kwargs})
            except IntegrityError as e:
                await session.rollback()
                if cls._original_params is None:
                    raise e
                get_or_create_params = {
                    lookup: value
                    for lookup, value in cls._original_params.items()
                    if lookup in cls._meta.sqlalchemy_get_or_create  # type: ignore
                }
                if get_or_create_params:
                    stmt = select(model_class).filter_by(**get_or_create_params)
                    try:
                        exec_result = await session.execute(stmt)
                        obj = exec_result.scalar_one()
                    except NoResultFound:
                        # Original params are not a valid lookup and triggered
                        # a create(), that resulted in an IntegrityError.
                        raise e  # noqa: B904
                else:
                    raise e
        return obj

    @classmethod
    async def _create(
        cls: type['SQLAlchemyAsyncModelFactory'],
        model_class: type[_MT],
        *args,
        **kwargs,
    ) -> _MT:
        """Create an instance of the model, and save it to the database."""
        async_session_factory: async_sessionmaker[AsyncSession] | None = None
        async_session_factory = cls._meta.sqlalchemy_async_session_factory  # type: ignore
        if async_session_factory:
            cls._meta.sqlalchemy_async_session = async_session_factory()  # type: ignore
        session = cls._meta.sqlalchemy_async_session  # type: ignore
        if session is None:
            raise RuntimeError('No session provided.')
        if cls._meta.sqlalchemy_get_or_create:  # type: ignore
            return await cls._get_or_create(model_class, session, args, kwargs)
        return await cls._save(model_class, session, args, kwargs)

    @classmethod
    async def create_batch(cls, size, **kwargs):  # type: ignore
        return [await cls.create(**kwargs) for _ in range(size)]
Test cases
from datetime import UTC, datetime

import factory
from factory import fuzzy
from faker import Faker

from src.database import database_session_maker, models
from tests.utils import SQLAlchemyAsyncModelFactory

fake = Faker()

class ModelFactory(SQLAlchemyAsyncModelFactory):
    class Meta:
        abstract = True
        sqlalchemy_async_session_factory = database_session_maker
        sqlalchemy_async_session_persistence = 'commit'

    id = factory.LazyFunction(fake.uuid4)

class TimeStampMixinFactory:
    created_at = fuzzy.FuzzyDateTime(datetime(2020, 1, 1, tzinfo=UTC))
    updated_at = fuzzy.FuzzyDateTime(datetime(2020, 1, 1, tzinfo=UTC))

class WalletFactory(TimeStampMixinFactory, ModelFactory):
    class Meta:
        model = models.Wallets

    address = fuzzy.FuzzyText(length=34)
    currency = factory.LazyFunction(fake.cryptocurrency_code)
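A minimal async test sketch, assuming pytest with pytest-asyncio is configured; the assertions only illustrate the factory contract.
import pytest

@pytest.mark.asyncio
async def test_wallet_factory():
    wallet = await WalletFactory.create()  # persisted via the configured async session
    assert wallet.id is not None
    wallets = await WalletFactory.create_batch(3)
    assert len(wallets) == 3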