• 亮色
  • 深色
  • 自动
  • RSS 订阅

    Python 常用工具

    2025-04-28

    函数

    路径

    列出指定目录及所有子目录下的文件夹

    import os
    
    # Raw string: in a plain literal "\V" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the resulting value is the same.
    videos_path = r'D:\Videos'
    dirs: list[str] = []  # filled by joindir() with every directory found
    
    def joindir(path: str):
        os.chdir(path)
        for p in [os.path.join(path, dir) for dir in os.listdir() if os.path.isdir(os.path.join(path, dir))]:
            dirs.append(p)
            joindir(p)
    
    # Collect every directory below videos_path into `dirs`, then print the list.
    joindir(videos_path)
    print(dirs)
    

    列出指定目录及所有子目录下的文件

    import os
    
    # Raw string: "\V" in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
    videos_path = r'D:\Videos'
    
    def file(filepath: str):
        """Handle one file found by joindir(); this version just prints its path."""
        print(filepath)
    
    def joindir(path: str, on_file=None) -> None:
        """Recursively pass every regular file below *path* to *on_file*.

        *on_file* defaults to the module-level ``file`` function (looked up at
        call time). Paths are built as ``os.path.join(path, name)``; the working
        directory is never changed, so relative paths work too.

        Raises:
            FileNotFoundError: for an entry that is neither a directory nor a
                regular file (e.g. a dangling symlink).
        """
        if on_file is None:
            on_file = file
        for name in os.listdir(path):
            joinpath = os.path.join(path, name)
            if os.path.isdir(joinpath):
                joindir(joinpath, on_file)
            elif os.path.isfile(joinpath):
                on_file(joinpath)
            else:
                raise FileNotFoundError(joinpath)
    
    # Pass every file below videos_path to file(), which prints its path.
    joindir(videos_path)
    

    单例模式

    class Singleton:
        """Base class whose (sub)classes each have at most one instance.

        Every ``Cls(...)`` call returns the same object for a given class. The
        instance is cached in the class's own namespace, so a subclass gets its
        own instance instead of inheriting its parent's. ``__init__`` (if any)
        still runs on every call.
        """

        def __new__(cls, *args, **kwargs):
            # Look only in this class's own __dict__: hasattr()/getattr() would
            # also find a parent class's cached instance through inheritance.
            if "_singleton" not in cls.__dict__:
                # object.__new__ rejects extra arguments once __new__ is
                # overridden, so they are accepted here (for __init__) but not
                # forwarded.
                cls._singleton = super().__new__(cls)
            return cls.__dict__["_singleton"]
    
    

    雪花算法 ID 生成器

    import asyncio
    import threading
    from datetime import UTC, datetime
    from logging import getLogger
    from time import sleep
    
    logger = getLogger()
    
    class SnowflakeID:
        @staticmethod
        def __current_timestamp():
            """获取毫秒时间戳"""
            return int(datetime.now(UTC).timestamp() * 1000)
    
        def __init__(
            self,
            worker_id: int,
            worker_id_bits: int = 10,
            sequence_bits: int = 12,
            start_timestamp: int = int(
                datetime(2024, 1, 1, tzinfo=UTC).timestamp() * 1e3
            ),
        ) -> None:
            message = f"Created Snowflake ID generator with worker ID {worker_id}."
            logger.info(message)
    
            self._worker_id_bits = worker_id_bits
            self._max_worker_id = 2**worker_id_bits - 1
    
            if self.max_worker_id < worker_id:
                msg = "Exceeds maximum value"
                raise ValueError(msg)
    
            if worker_id < 1:
                msg = "Below minimum value"
                raise ValueError(msg)
    
            self._worker_id = worker_id - 1
    
            # 同一毫秒内生成多个 ID 的序号
            self._sequence_bits = sequence_bits
            # 同一毫秒内最多生成多个 ID的数量
            self._max_millisecond_count = 2**sequence_bits - 1
    
            self._start_timestamp = start_timestamp
    
            self._last_timestamp = -1  # 上次生成 ID 的时间戳
            self._times = 0
            self._lock = threading.Lock()
            self._alock = asyncio.Lock()
    
        @property
        def worker_id_bits(self) -> int:
            return self._worker_id_bits
    
        @property
        def max_worker_id(self) -> int:
            return self._max_worker_id + 1
    
        @property
        def sequence_bits(self) -> int:
            return self._sequence_bits
    
        @property
        def max_millisecond_count(self) -> int:
            return self._max_millisecond_count + 1
    
        @property
        def start_timestamp(self) -> int:
            return self._start_timestamp
    
        def __get(self, passed_timestamp: int):
            result = passed_timestamp << (
                self._worker_id_bits + self._sequence_bits
            )
    
            result |= self._worker_id << self._sequence_bits
    
            result |= self._times
    
            return result
    
        def gen(self) -> int:
            with self._lock:
                assert self._worker_id <= self.max_worker_id
    
                timestamp = self.__current_timestamp()
    
                # 如果时间回拨, 等待修正时间
                if timestamp < self._last_timestamp:
                    sleep(self._last_timestamp - timestamp)
                    self._times = 0
    
                # 同一毫秒内生成
                if timestamp == self._last_timestamp:
                    self._times = (self._times + 1) & self._max_millisecond_count
                    # 溢出, 阻塞到下一秒
                    if self._times == 0:
                        sleep(0.001)
                        timestamp = self.__current_timestamp()
                # 非同一毫秒
                else:
                    self._times = 0
    
                self._last_timestamp = timestamp
    
                return self.__get(timestamp - self._start_timestamp)
    
        async def agen(self):
            async with self._alock:
                worker_id = self._worker_id
                assert worker_id <= self.max_worker_id
    
                timestamp = self.__current_timestamp()
    
                if timestamp < self._last_timestamp:
                    await asyncio.sleep(self._last_timestamp - timestamp)
                    self._times = 0
    
                if timestamp == self._last_timestamp:
                    self._times = (self._times + 1) & self._max_millisecond_count
                    if self._times == 0:
                        await asyncio.sleep(0.001)
                        timestamp = self.__current_timestamp()
                else:
                    self._times = 0
    
                self._last_timestamp = timestamp
    
                return self.__get(timestamp - self._start_timestamp)
    
        def decode(self, id: int):  # noqa: A002
            sequence_mask = (1 << self._sequence_bits) - 1
            worker_id_mask = (1 << self._worker_id_bits) - 1
    
            sequence = id & sequence_mask
            worker_id = (id >> self._sequence_bits) & worker_id_mask
            passed_timestamp = id >> (self._worker_id_bits + self._sequence_bits)
            timestamp = passed_timestamp + self._start_timestamp
    
            return {
                "timestamp": timestamp,
                "datetime": datetime.fromtimestamp(timestamp / 1000, tz=UTC),
                "worker_id": worker_id + 1,
                "sequence": sequence,
                "id": id,
            }
    
    

    模块

    loguru 拦截 Uvicorn 和 FastAPI 日志

    import datetime
    import logging
    import sys
    from typing import TYPE_CHECKING
    
    from loguru import logger
    
    if TYPE_CHECKING:
        from loguru import Record
    
    # Remove every sink loguru has registered so far (including its default
    # stderr sink); this module adds its own sink with logger.add() further down.
    logger.remove()
    
    # ref: https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
    class InterceptHandler(logging.Handler):
        """Forward records from the standard ``logging`` module to loguru.

        The loguru level with the same name is used when it exists, otherwise the
        record's numeric level. The reported location is the code that called
        ``logging``, not this handler.
        """

        def emit(self, record: logging.LogRecord) -> None:
            # On any failure loguru logs the error and the process exits with 1.
            # Used as a context manager rather than a decorator so it adds no
            # wrapper frame that the caller search below would stop at.
            with logger.catch(onerror=lambda _: sys.exit(1)):
                level: str | int
                try:
                    level = logger.level(record.levelname).name
                except ValueError:
                    level = record.levelno

                # Find caller from where originated the logged message: start at
                # this frame (depth 0) and skip logging's own frames and frozen
                # importlib bootstrap frames (logging during imports).
                frame, depth = sys._getframe(), 0
                while frame is not None:
                    filename = frame.f_code.co_filename
                    is_logging = filename == logging.__file__
                    is_frozen = "importlib" in filename and "_bootstrap" in filename
                    if depth > 0 and not (is_logging or is_frozen):
                        break
                    frame = frame.f_back
                    depth += 1

                logger.opt(
                    depth=depth,
                    exception=record.exc_info,
                ).log(level, record.getMessage())
    
    # Replace every handler of the root logger with InterceptHandler and let
    # records of all levels (level=0) through.
    logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
    
    # https://github.com/encode/uvicorn/issues/562
    # These loggers are not served by the root logger's handler, so they are
    # intercepted explicitly. propagate=False stops a record from also bubbling
    # up to the root logger, whose InterceptHandler would emit it a second time
    # (stdlib loggers propagate unless configured otherwise, e.g. "fastapi").
    for log_name in ["uvicorn", "uvicorn.access", "fastapi"]:
        _logger = logging.getLogger(log_name)
        _logger.handlers = [InterceptHandler()]
        _logger.propagate = False
    
    # Console layout: level | timestamp | module:function:line | message.
    FORMAT = (
        "<level>{level: <8}</level> | "
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<cyan>{name}</cyan>:"
        "<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
        "<level>{message}</level>"
    )
    
    # NOTE(review): `settings` is neither defined nor imported in this snippet;
    # presumably the project's config object providing LOG_LEVEL — confirm.
    logger.add(sys.stdout, level=settings.LOG_LEVEL, format=FORMAT)
    

    FastAPI

    依赖:获取当前请求 IP

    from typing import Annotated
    
    from fastapi import Depends, Request
    
    # Headers that proxies/CDNs use to carry the original client IP; request_ip()
    # checks them in this order and returns the first one present.
    # NOTE(review): these are client-supplied unless a trusted proxy overwrites
    # them — do not rely on the result for access control.
    IP_HEADERS = [
        "X-Forwarded-For",  # the most common one, carries the client IP
        "X-Real-IP",  # used by some proxy servers, e.g. Nginx
        "Proxy-Client-IP",  # used by some proxy servers or applications
        "WL-Proxy-Client-IP",  # used by the WebLogic server proxy
        "HTTP_CLIENT_IP",  # custom header in some environments (CGI-style name)
        "HTTP_X_FORWARDED_FOR",  # custom header in some environments (CGI-style name)
        "CF-Connecting-IP",  # used by Cloudflare
        "True-Client-IP",  # used by Akamai
        "X-Cluster-Client-IP",  # used by some load balancers
        "Fastly-Client-IP",  # used by the Fastly CDN
        "Forwarded",  # standardized header, defined in RFC 7239
    ]
    
    def _first_ip_from_header(name: str, value: str | None) -> str | None:
        """Extract the originating client address from one header value.

        X-Forwarded-For-style values are a comma-separated chain
        ``client, proxy1, proxy2``; the left-most entry is the client. RFC 7239
        ``Forwarded`` values hold ``for=<node>`` pairs, where the node may be
        quoted, bracketed (IPv6) and/or carry a port.
        """
        if not value:
            return None
        first = value.split(",", 1)[0].strip()
        if name.lower() != "forwarded":
            return first or None
        for pair in first.split(";"):
            key, _, node = pair.strip().partition("=")
            if key.strip().lower() != "for":
                continue
            node = node.strip().strip('"')
            if node.startswith("["):  # "[IPv6]" or "[IPv6]:port"
                return node[1:].split("]", 1)[0] or None
            if node.count(":") == 1:  # "IPv4:port"
                node = node.split(":", 1)[0]
            return node or None
        return None
    
    def request_ip(request: Request) -> str | None:
        """Best-effort client IP address of *request*.

        Checks the headers in IP_HEADERS in order and returns the first address
        that can be extracted from one of them; otherwise falls back to the
        socket peer (``request.client.host``), then ``None``.

        NOTE(review): header values are client-controllable unless a trusted
        proxy rewrites them; do not use the result for authorization.
        """
        headers = request.headers
    
        for name in IP_HEADERS:
            ip = _first_ip_from_header(name, headers.get(name))
            if ip:
                return ip
    
        if request.client:
            return request.client.host
    
        return None
    
    # FastAPI dependency alias: a path-operation parameter annotated as RequestIP
    # is filled with request_ip(request) (a string, or None if nothing is known).
    RequestIP = Annotated[str | None, Depends(request_ip)]
    
    

    Factory boy 配置 SQLAlchemy Async

    已通过 aiosqlite 和 asyncpg 测试,依赖 factory-boy>=3.3.0 与 sqlalchemy>=2

    from typing import Generic, TypeVar
    
    from factory import Factory
    from factory.alchemy import (
        SESSION_PERSISTENCE_COMMIT,
        SESSION_PERSISTENCE_FLUSH,
        VALID_SESSION_PERSISTENCE_TYPES,
    )
    from factory.base import FactoryOptions, OptionDefault
    from factory.errors import FactoryError
    from sqlalchemy import select
    from sqlalchemy.exc import IntegrityError, NoResultFound
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
    
    # Type variable for the SQLAlchemy model class a factory produces.
    _MT = TypeVar('_MT')
    
    class SQLAlchemyAsyncOptions(FactoryOptions):
        """Meta options understood by SQLAlchemyAsyncModelFactory.

        Adds (all inherited by sub-factories):
            sqlalchemy_get_or_create: field names used to look up an existing row
                before a new one is created.
            sqlalchemy_async_session: the AsyncSession objects are saved with.
            sqlalchemy_async_session_factory: callable returning a new session;
                may not be combined with sqlalchemy_async_session.
            sqlalchemy_async_session_persistence: must be one of factory_boy's
                VALID_SESSION_PERSISTENCE_TYPES.
        """

        def _check_sqlalchemy_async_session_persistence(self, meta, value):
            """Raise TypeError when *value* is not a valid persistence type."""
            if value not in VALID_SESSION_PERSISTENCE_TYPES:
                msg = (
                    f'{meta}.sqlalchemy_async_session_persistence must be one of '
                    f'{VALID_SESSION_PERSISTENCE_TYPES}, got {value}'
                )
                raise TypeError(msg)
    
        @staticmethod
        def _check_has_sqlalchemy_async_session_set(meta, value):
            """Raise RuntimeError when both a session and a session factory are set."""
            if value and hasattr(meta, 'sqlalchemy_async_session'):
                raise RuntimeError(
                    'Provide either a sqlalchemy_async_session '
                    'or a sqlalchemy_async_session_factory,  not both'
                )
    
        def _build_default_options(self):
            """Extend factory_boy's default options with the async SQLAlchemy ones."""
            return super()._build_default_options() + [
                OptionDefault('sqlalchemy_get_or_create', (), inherit=True),
                OptionDefault('sqlalchemy_async_session', None, inherit=True),
                OptionDefault(
                    'sqlalchemy_async_session_factory',
                    None,
                    inherit=True,
                    checker=self._check_has_sqlalchemy_async_session_set,
                ),
                OptionDefault(
                    'sqlalchemy_async_session_persistence',
                    None,
                    inherit=True,
                    checker=self._check_sqlalchemy_async_session_persistence,
                ),
            ]
    
    class SQLAlchemyAsyncModelFactory(Generic[_MT], Factory):
        """Factory for SQLAlchemy async models.

        ``_create`` is a coroutine, so ``create()`` has to be awaited
        (``obj = await MyFactory.create()``), and ``create_batch()`` is itself a
        coroutine. Configure it through the Meta options declared by
        SQLAlchemyAsyncOptions.
        """
    
        _options_class = SQLAlchemyAsyncOptions
        # Params of the latest _generate() call; _get_or_create() uses them to
        # look the row up again after an IntegrityError.
        _original_params = None
    
        class Meta:  # type: ignore
            abstract = True
    
        @classmethod
        def _generate(cls: type['SQLAlchemyAsyncModelFactory'], strategy, params):
            """Remember *params* for _get_or_create, then generate as usual."""
            # Original params are used in _get_or_create if it cannot build an
            # object initially due to an IntegrityError being raised
            cls._original_params = params
            return super()._generate(strategy, params)
    
        @classmethod
        async def _save(
            cls: type['SQLAlchemyAsyncModelFactory'],
            model_class: type[_MT],
            session: AsyncSession,
            args,
            kwargs,
        ):
            """Build ``model_class(*args, **kwargs)``, add it to *session*, persist it.

            Flushes or commits according to Meta.sqlalchemy_async_session_persistence;
            for any other value the object is only added to the session.
            """
            session_persistence = cls._meta.sqlalchemy_async_session_persistence  # type: ignore
    
            obj = model_class(*args, **kwargs)
            session.add(obj)
            if session_persistence == SESSION_PERSISTENCE_FLUSH:
                await session.flush()
            elif session_persistence == SESSION_PERSISTENCE_COMMIT:
                await session.commit()
            return obj
    
        @classmethod
        async def _get_or_create(
            cls: type['SQLAlchemyAsyncModelFactory'],
            model_class: type[_MT],
            session: AsyncSession,
            args,
            kwargs,
        ) -> _MT:
            """Return the row matching the Meta.sqlalchemy_get_or_create fields,
            creating it when none exists.

            Raises:
                FactoryError: a get-or-create field has no value in *kwargs*.
                IntegrityError: creation failed and no row can be found with the
                    originally passed get-or-create params.
            """
            key_fields = {}
            for f in cls._meta.sqlalchemy_get_or_create:  # type: ignore
                if f not in kwargs:
                    raise FactoryError(
                        f'sqlalchemy_get_or_create - Unable to find '
                        f"initialization value for '{f}' in factory {cls.__name__}"
                    )
                key_fields[f] = kwargs.pop(f)
    
            # filter_by() takes keyword criteria only; *args are positional
            # constructor arguments for the model, not lookup criteria.
            stmt = select(model_class).filter_by(**key_fields)
            obj = await session.scalar(stmt)
    
            if obj is None:
                try:
                    obj = await cls._save(model_class, session, args, {**key_fields, **kwargs})
                except IntegrityError as e:
                    await session.rollback()
    
                    if cls._original_params is None:
                        raise e
    
                    get_or_create_params = {
                        lookup: value
                        for lookup, value in cls._original_params.items()
                        if lookup in cls._meta.sqlalchemy_get_or_create  # type: ignore
                    }
                    if get_or_create_params:
                        stmt = select(model_class).filter_by(**get_or_create_params)
    
                        try:
                            exec_result = await session.execute(stmt)
                            obj = exec_result.scalar_one()
                        except NoResultFound:
                            # Original params are not a valid lookup and triggered
                            # a create(), that resulted in an IntegrityError.
                            raise e  # noqa: B904
                    else:
                        raise e
    
            return obj
    
        @classmethod
        async def _create(
            cls: type['SQLAlchemyAsyncModelFactory'],
            model_class: type[_MT],
            *args,
            **kwargs,
        ) -> _MT:
            """Create an instance of the model, and save it to the database.

            When Meta.sqlalchemy_async_session_factory is set, a new session is
            created from it and stored in Meta.sqlalchemy_async_session on every
            call.

            Raises:
                RuntimeError: no session is configured.
            """
            async_session_factory: async_sessionmaker[AsyncSession] | None = (
                cls._meta.sqlalchemy_async_session_factory  # type: ignore
            )
    
            if async_session_factory:
                # NOTE(review): the previously stored session is replaced, not closed.
                cls._meta.sqlalchemy_async_session = async_session_factory()  # type: ignore
    
            session = cls._meta.sqlalchemy_async_session  # type: ignore
    
            if session is None:
                raise RuntimeError('No session provided.')
            if cls._meta.sqlalchemy_get_or_create:  # type: ignore
                return await cls._get_or_create(model_class, session, args, kwargs)
            return await cls._save(model_class, session, args, kwargs)
    
        @classmethod
        async def create_batch(cls, size, **kwargs):  # type: ignore
            """Create *size* instances, awaiting each create() in turn."""
            return [await cls.create(**kwargs) for _ in range(size)]
    
    

    测试用例

    from datetime import UTC, datetime, timezone
    
    import factory
    from factory import fuzzy
    from faker import Faker
    
    from src.database import database_session_maker, models
    from tests.utils import SQLAlchemyAsyncModelFactory
    
    # Shared Faker instance used by the factories below.
    fake = Faker()
    
    class ModelFactory(SQLAlchemyAsyncModelFactory):
        """Abstract base for the project's model factories.

        Sessions come from database_session_maker, and the 'commit' persistence
        setting makes every create() commit the new object.
        """

        class Meta:
            abstract = True
            sqlalchemy_async_session_factory = database_session_maker
            sqlalchemy_async_session_persistence = 'commit'
    
        # Primary key: a fresh fake.uuid4() value for every built instance.
        id = factory.LazyFunction(fake.uuid4)
    
    class TimeStampMixinFactory:
        """Mixin adding random created_at / updated_at values to a factory.

        NOTE(review): this is a plain class, not a Factory; confirm factory_boy
        picks up declarations from such mixins (otherwise these are ignored).
        NOTE(review): both values are drawn independently, so updated_at may be
        earlier than created_at.
        """

        # Random timezone-aware datetimes from 2020-01-01 (UTC); the upper bound is
        # FuzzyDateTime's default end (presumably "now" — confirm).
        created_at = fuzzy.FuzzyDateTime(datetime(2020, 1, 1, tzinfo=UTC))
        updated_at = fuzzy.FuzzyDateTime(datetime(2020, 1, 1, tzinfo=UTC))
    
    class WalletFactory(TimeStampMixinFactory, ModelFactory):
        """Factory for models.Wallets rows."""

        class Meta:
            model = models.Wallets
    
        # 34 random characters; NOTE(review): not a validated address format.
        address = fuzzy.FuzzyText(length=34)
        # A fake cryptocurrency code, regenerated for every instance.
        currency = factory.LazyFunction(fake.cryptocurrency_code)