- 新增图像生成接口,支持试用、积分和自定义API Key模式 - 实现生成图片结果异步上传至MinIO存储,带重试机制 - 优化积分预扣除和异常退还逻辑,保障用户积分准确 - 添加获取生成历史记录接口,支持时间范围和分页 - 提供本地字典配置接口,支持模型、比例、提示模板和尺寸 - 实现图片批量上传接口,支持S3兼容对象存储 feat(admin): 增加管理员角色管理与权限分配接口 - 实现角色列表查询、角色创建、更新及删除功能 - 增加权限列表查询接口 - 实现用户角色分配接口,便于统一管理用户权限 - 增加系统字典增删查改接口,支持分类过滤和排序 - 权限控制全面覆盖管理接口,保证安全访问 feat(auth): 完善用户登录注册及权限相关接口与页面 - 实现手机号验证码发送及校验功能,保障注册安全 - 支持手机号注册、登录及退出接口,集成日志记录 - 增加修改密码功能,验证原密码后更新 - 提供动态导航菜单接口,基于权限展示不同菜单 - 实现管理界面路由及日志、角色、字典管理页面访问权限控制 - 添加系统日志查询接口,支持关键词和等级筛选 feat(app): 初始化Flask应用并配置蓝图与数据库 - 创建应用程序工厂,加载配置,初始化数据库和Redis客户端 - 注册认证、API及管理员蓝图,整合路由 - 根路由渲染主页模板 - 应用上下文中自动创建数据库表,保证运行环境准备完毕 feat(database): 提供数据库创建与迁移支持脚本 - 新增数据库创建脚本,支持自动检测是否已存在 - 添加数据库表初始化脚本,支持创建和删除所有表 - 实现RBAC权限初始化,包含基础权限和角色创建 - 新增字段手动修复脚本,添加用户API Key和积分字段 - 强制迁移脚本支持清理连接和修复表结构,初始化默认数据及角色分配 feat(config): 新增系统配置参数 - 配置数据库、Redis、Session和MinIO相关参数 - 添加AI接口地址及试用Key配置 - 集成阿里云短信服务配置及开发模式相关参数 feat(extensions): 初始化数据库、Redis和MinIO客户端 - 创建全局SQLAlchemy数据库实例和Redis客户端 - 配置基于boto3的MinIO兼容S3客户端 chore(logs): 添加示例系统日志文件 - 记录用户请求、验证码发送成功与失败的日志信息
777 lines
24 KiB
Python
777 lines
24 KiB
Python
# testing/suite/test_dialect.py
|
|
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
|
# <see AUTHORS file>
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
# mypy: ignore-errors
|
|
|
|
|
|
import importlib
|
|
|
|
from . import testing
|
|
from .. import assert_raises
|
|
from .. import config
|
|
from .. import engines
|
|
from .. import eq_
|
|
from .. import fixtures
|
|
from .. import is_not_none
|
|
from .. import is_true
|
|
from .. import mock
|
|
from .. import ne_
|
|
from .. import provide_metadata
|
|
from ..assertions import expect_raises
|
|
from ..assertions import expect_raises_message
|
|
from ..config import requirements
|
|
from ..provision import set_default_schema_on_connection
|
|
from ..schema import Column
|
|
from ..schema import Table
|
|
from ... import bindparam
|
|
from ... import dialects
|
|
from ... import event
|
|
from ... import exc
|
|
from ... import Integer
|
|
from ... import literal_column
|
|
from ... import select
|
|
from ... import String
|
|
from ...sql.compiler import Compiled
|
|
from ...util import inspect_getfullargspec
|
|
|
|
|
|
class PingTest(fixtures.TestBase):
    """Check that the dialect-level ``do_ping()`` hook reports a live
    connection as healthy."""

    __backend__ = True

    def test_do_ping(self):
        # hand a real DBAPI connection straight to the dialect's ping
        # routine; a working backend must answer True
        with testing.db.connect() as conn:
            raw_dbapi_conn = conn.connection.dbapi_connection
            ping_result = testing.db.dialect.do_ping(raw_dbapi_conn)
            is_true(ping_result)
|
|
|
|
|
|
class ArgSignatureTest(fixtures.TestBase):
    """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
    ``**kw``, for #8988.

    This test uses runtime code inspection. Does not need to be a
    ``__backend__`` test as it only needs to run once provided all target
    dialects have been imported.

    For third party dialects, the suite would be run with that third
    party as a "--dburi", which means its compiler classes will have been
    imported by the time this test runs.

    """

    def _all_subclasses():  # type: ignore  # noqa
        # import every built-in dialect first so their Compiled
        # subclasses are registered before walking the subclass tree
        for d in dialects.__all__:
            if not d.startswith("_"):
                importlib.import_module("sqlalchemy.dialects.%s" % d)

        # breadth-first walk over Compiled and all of its subclasses
        stack = [Compiled]

        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
            yield cls

    # NOTE: params is materialized at class-definition time, i.e. after
    # the dialect imports in _all_subclasses() have run
    @testing.fixture(params=list(_all_subclasses()))
    def all_subclasses(self, request):
        yield request.param

    def test_all_visit_methods_accept_kw(self, all_subclasses):
        cls = all_subclasses

        # only methods defined directly on this class; inherited
        # visit_ methods get checked on their defining class
        for k in cls.__dict__:
            if k.startswith("visit_"):
                meth = getattr(cls, k)

                insp = inspect_getfullargspec(meth)
                is_not_none(
                    insp.varkw,
                    f"Compiler visit method {cls.__name__}.{k}() does "
                    "not accommodate for **kw in its argument signature",
                )
|
|
|
|
|
|
class ExceptionTest(fixtures.TablesTest):
    """Basic exception-wrapping round trips.

    DBAPI exception behavior varies widely, so these assertions stay
    conservative about which wrapped exceptions to expect from real
    round trips.

    """

    run_deletes = "each"

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )

    @requirements.duplicate_key_raises_integrity_error
    def test_integrity_error(self):
        # inserting the same explicit primary key twice must surface as
        # a wrapped IntegrityError
        with config.db.connect() as conn:
            xact = conn.begin()
            conn.execute(
                self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
            )

            assert_raises(
                exc.IntegrityError,
                conn.execute,
                self.tables.manual_pk.insert(),
                {"id": 1, "data": "d1"},
            )

            xact.rollback()

    def test_exception_with_non_ascii(self):
        with config.db.connect() as conn:
            try:
                # provoke an error message likely to contain non-ascii
                # characters in the DBAPI's message string.  there's no
                # way to force this with some drivers (mysqlclient,
                # pymysql), but it does produce a non-ascii error
                # message at least for cx_oracle and psycopg2
                conn.execute(select(literal_column("méil")))
                assert False
            except exc.DBAPIError as dbapi_err:
                message = str(dbapi_err)

                # the original DBAPI error's text is embedded in the
                # wrapped exception's string form
                assert str(dbapi_err.orig) in message

                assert isinstance(message, str)
|
|
|
|
|
|
class IsolationLevelTest(fixtures.TestBase):
    """Round-trip tests for dialect isolation-level get/set support.

    Relies on ``requirements.get_isolation_levels()`` to describe the
    backend's default and supported levels.
    """

    __backend__ = True

    __requires__ = ("isolation_level",)

    def _get_non_default_isolation_level(self):
        # pick any supported level that is neither AUTOCOMMIT nor the
        # backend default; skip the test when none exists
        levels = requirements.get_isolation_levels(config)

        default = levels["default"]
        supported = levels["supported"]

        s = set(supported).difference(["AUTOCOMMIT", default])
        if s:
            return s.pop()
        else:
            config.skip_test("no non-default isolation level available")

    def test_default_isolation_level(self):
        eq_(
            config.db.dialect.default_isolation_level,
            requirements.get_isolation_levels(config)["default"],
        )

    def test_non_default_isolation_level(self):
        non_default = self._get_non_default_isolation_level()

        with config.db.connect() as conn:
            existing = conn.get_isolation_level()

            ne_(existing, non_default)

            conn.execution_options(isolation_level=non_default)

            eq_(conn.get_isolation_level(), non_default)

            # reset_isolation_level() should restore the connection to
            # the level it had before the execution option was applied
            conn.dialect.reset_isolation_level(
                conn.connection.dbapi_connection
            )

            eq_(conn.get_isolation_level(), existing)

    def test_all_levels(self):
        levels = requirements.get_isolation_levels(config)

        all_levels = levels["supported"]

        for level in set(all_levels).difference(["AUTOCOMMIT"]):
            with config.db.connect() as conn:
                conn.execution_options(isolation_level=level)

                eq_(conn.get_isolation_level(), level)

                trans = conn.begin()
                trans.rollback()

                # the level must survive a transaction rollback
                eq_(conn.get_isolation_level(), level)

            with config.db.connect() as conn:
                # a fresh connection reverts to the backend default
                eq_(
                    conn.get_isolation_level(),
                    levels["default"],
                )

    @testing.requires.get_isolation_level_values
    def test_invalid_level_execution_option(self, connection_no_trans):
        """test for the new get_isolation_level_values() method"""

        connection = connection_no_trans
        with expect_raises_message(
            exc.ArgumentError,
            "Invalid value '%s' for isolation_level. "
            "Valid isolation levels for '%s' are %s"
            % (
                "FOO",
                connection.dialect.name,
                ", ".join(
                    requirements.get_isolation_levels(config)["supported"]
                ),
            ),
        ):
            connection.execution_options(isolation_level="FOO")

    @testing.requires.get_isolation_level_values
    @testing.requires.dialect_level_isolation_level_param
    def test_invalid_level_engine_param(self, testing_engine):
        """test for the new get_isolation_level_values() method
        and support for the dialect-level 'isolation_level' parameter.

        """

        eng = testing_engine(options=dict(isolation_level="FOO"))
        with expect_raises_message(
            exc.ArgumentError,
            "Invalid value '%s' for isolation_level. "
            "Valid isolation levels for '%s' are %s"
            % (
                "FOO",
                eng.dialect.name,
                ", ".join(
                    requirements.get_isolation_levels(config)["supported"]
                ),
            ),
        ):
            # the engine-level parameter is validated lazily, on first
            # connect, not at engine construction time
            eng.connect()

    @testing.requires.independent_readonly_connections
    def test_dialect_user_setting_is_restored(self, testing_engine):
        # an engine-wide isolation_level setting must come back after a
        # per-connection override is released to the pool
        levels = requirements.get_isolation_levels(config)
        default = levels["default"]
        supported = (
            sorted(
                set(levels["supported"]).difference([default, "AUTOCOMMIT"])
            )
        )[0]

        e = testing_engine(options={"isolation_level": supported})

        with e.connect() as conn:
            eq_(conn.get_isolation_level(), supported)

        with e.connect() as conn:
            # per-connection override back to the default...
            conn.execution_options(isolation_level=default)
            eq_(conn.get_isolation_level(), default)

        with e.connect() as conn:
            # ...must not leak into subsequent pooled connections
            eq_(conn.get_isolation_level(), supported)
|
|
|
|
|
|
class AutocommitIsolationTest(fixtures.TablesTest):
    """Tests for the AUTOCOMMIT isolation level execution option."""

    run_deletes = "each"

    __requires__ = ("autocommit",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            test_needs_acid=True,
        )

    def _test_conn_autocommits(self, conn, autocommit, ensure_table=False):
        # insert a row, then roll back the transaction: with autocommit
        # in effect the row persists anyway, without it the rollback
        # discards the insert
        if ensure_table:
            self.tables.some_table.create(conn, checkfirst=True)
            conn.commit()

        trans = conn.begin()
        conn.execute(
            self.tables.some_table.insert(), {"id": 1, "data": "some data"}
        )
        trans.rollback()

        eq_(
            conn.scalar(select(self.tables.some_table.c.id)),
            1 if autocommit else None,
        )
        conn.rollback()

        # clean up for the next invocation
        with conn.begin():
            conn.execute(self.tables.some_table.delete())

    def test_autocommit_on(self, connection_no_trans):
        conn = connection_no_trans
        c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(c2, True)

        c2.dialect.reset_isolation_level(c2.connection.dbapi_connection)

        # after resetting, the same connection is transactional again
        self._test_conn_autocommits(conn, False)

    def test_autocommit_off(self, connection_no_trans):
        conn = connection_no_trans
        self._test_conn_autocommits(conn, False)

    def test_turn_autocommit_off_via_default_iso_level(
        self, connection_no_trans
    ):
        conn = connection_no_trans
        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(conn, True)

        # switching back to the backend's default level must turn
        # autocommit off again
        conn.execution_options(
            isolation_level=requirements.get_isolation_levels(config)[
                "default"
            ]
        )
        self._test_conn_autocommits(conn, False)

    @testing.requires.skip_autocommit_rollback
    @testing.variation("autocommit_setting", ["false", "engine", "option"])
    @testing.variation("block_rollback", [True, False])
    def test_autocommit_block(
        self, testing_engine, autocommit_setting, block_rollback
    ):
        # verify the pool's reset-on-return rollback is skipped when
        # skip_autocommit_rollback is set and autocommit is in effect
        kw = {}
        if bool(block_rollback):
            kw["skip_autocommit_rollback"] = True
        if autocommit_setting.engine:
            kw["isolation_level"] = "AUTOCOMMIT"

        engine = testing_engine(options=kw)

        conn = engine.connect()
        if autocommit_setting.option:
            conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(
            conn,
            autocommit_setting.engine or autocommit_setting.option,
            ensure_table=True,
        )
        # wrap the raw connection's rollback so we can observe whether
        # close() (pool return) invokes it
        with mock.patch.object(
            conn.connection, "rollback", wraps=conn.connection.rollback
        ) as check_rollback:
            conn.close()
        if autocommit_setting.false or not block_rollback:
            eq_(check_rollback.mock_calls, [mock.call()])
        else:
            eq_(check_rollback.mock_calls, [])

    @testing.requires.independent_readonly_connections
    @testing.variation("use_dialect_setting", [True, False])
    def test_dialect_autocommit_is_restored(
        self, testing_engine, use_dialect_setting
    ):
        """test #10147"""

        if use_dialect_setting:
            e = testing_engine(options={"isolation_level": "AUTOCOMMIT"})
        else:
            e = testing_engine().execution_options(
                isolation_level="AUTOCOMMIT"
            )

        levels = requirements.get_isolation_levels(config)

        default = levels["default"]

        with e.connect() as conn:
            self._test_conn_autocommits(conn, True)

        with e.connect() as conn:
            # per-connection override back to the default level...
            conn.execution_options(isolation_level=default)
            self._test_conn_autocommits(conn, False)

        with e.connect() as conn:
            # ...must not persist onto the next pooled connection
            self._test_conn_autocommits(conn, True)
|
|
|
|
|
|
class EscapingTest(fixtures.TestBase):
    """Round-trip checks for percent signs in bound values."""

    @provide_metadata
    def test_percent_sign_round_trip(self):
        """Verify the DBAPI handles escaped / non-escaped percent signs
        consistently with the way the compiler renders them.
        """
        tbl = Table("t", self.metadata, Column("data", String(50)))
        tbl.create(config.db)
        with config.db.begin() as conn:
            # one value with a single percent sign, one with a doubled
            # (escape-style) percent sign
            for value in ("some % value", "some %% other value"):
                conn.execute(tbl.insert(), {"data": value})

            def fetch(quoted_literal):
                # select the row back using a raw SQL string literal so
                # no bound-parameter escaping is involved on the WHERE
                return conn.scalar(
                    select(tbl.c.data).where(
                        tbl.c.data == literal_column(quoted_literal)
                    )
                )

            eq_(fetch("'some % value'"), "some % value")
            eq_(fetch("'some %% other value'"), "some %% other value")
|
|
|
|
|
|
class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
    """Test that a connect-event handler can switch the connection's
    default schema in a way the dialect picks up.

    The handler must be registered with ``insert=True`` so it runs
    ahead of the dialect's own first-connect initialization.
    """

    __backend__ = True

    __requires__ = ("default_schema_name_switch",)

    def test_control_case(self):
        # with no event handler, the engine reports the backend's
        # plain default schema
        default_schema_name = config.db.dialect.default_schema_name

        eng = engines.testing_engine()
        with eng.connect():
            pass

        eq_(eng.dialect.default_schema_name, default_schema_name)

    def test_wont_work_wo_insert(self):
        # without insert=True the handler runs after the dialect's
        # first-connect initialization, so the dialect-level default
        # schema name stays unchanged even though the connection's
        # actual schema was switched
        default_schema_name = config.db.dialect.default_schema_name

        eng = engines.testing_engine()

        @event.listens_for(eng, "connect")
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, default_schema_name)

    def test_schema_change_on_connect(self):
        eng = engines.testing_engine()

        # insert=True places the handler ahead of the dialect's
        # first-connect handler, so the dialect observes the new schema
        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, config.test_schema)

    def test_schema_change_works_w_transactions(self):
        eng = engines.testing_engine()

        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, *arg):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            trans = conn.begin()
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
            trans.rollback()

            # the schema switch survives a transaction rollback
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, config.test_schema)
|
|
|
|
|
|
class FutureWeCanSetDefaultSchemaWEventsTest(
    fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
):
    # re-runs the schema-switch event tests against a "future"
    # (2.0-style) engine via the mixin; no additional behavior
    pass
|
|
|
|
|
|
class DifficultParametersTest(fixtures.TestBase):
    """Exercise bound-parameter and column names containing characters
    that commonly break naive escaping: percent signs, colons, dots,
    parens, spaces, slashes, brackets, leading digits / underscores.
    """

    __backend__ = True

    # each tuple is one "difficult" name, delivered to the tests as the
    # ``paramname`` argument
    tough_parameters = testing.combinations(
        ("boring",),
        ("per cent",),
        ("per % cent",),
        ("%percent",),
        ("par(ens)",),
        ("percent%(ens)yah",),
        ("col:ons",),
        ("_starts_with_underscore",),
        ("dot.s",),
        ("more :: %colons%",),
        ("_name",),
        ("___name",),
        ("[BracketsAndCase]",),
        ("42numbers",),
        ("percent%signs",),
        ("has spaces",),
        ("/slashes/",),
        ("more/slashes",),
        ("q?marks",),
        ("1param",),
        ("1col:on",),
        argnames="paramname",
    )

    @tough_parameters
    @config.requirements.unusual_column_name_characters
    def test_round_trip_same_named_column(
        self, paramname, connection, metadata
    ):
        # use the difficult name as a *column* name as well, so both
        # identifier quoting and parameter escaping are exercised
        name = paramname

        t = Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True),
            Column(name, String(50), nullable=False),
        )

        # table is created
        t.create(connection)

        # automatic param generated by insert
        connection.execute(t.insert().values({"id": 1, name: "some name"}))

        # automatic param generated by criteria, plus selecting the column
        stmt = select(t.c[name]).where(t.c[name] == "some name")

        eq_(connection.scalar(stmt), "some name")

        # use the name in a param explicitly
        stmt = select(t.c[name]).where(t.c[name] == bindparam(name))

        row = connection.execute(stmt, {name: "some name"}).first()

        # name works as the key from cursor.description
        eq_(row._mapping[name], "some name")

        # use expanding IN
        stmt = select(t.c[name]).where(
            t.c[name].in_(["some name", "some other_name"])
        )

        connection.execute(stmt).first()

    @testing.fixture
    def multirow_fixture(self, metadata, connection):
        # four-row table shared by the standalone bindparam tests below
        mytable = Table(
            "mytable",
            metadata,
            Column("myid", Integer),
            Column("name", String(50)),
            Column("desc", String(50)),
        )

        mytable.create(connection)

        connection.execute(
            mytable.insert(),
            [
                {"myid": 1, "name": "a", "desc": "a_desc"},
                {"myid": 2, "name": "b", "desc": "b_desc"},
                {"myid": 3, "name": "c", "desc": "c_desc"},
                {"myid": 4, "name": "d", "desc": "d_desc"},
            ],
        )
        yield mytable

    @tough_parameters
    def test_standalone_bindparam_escape(
        self, paramname, connection, multirow_fixture
    ):
        # the difficult name is used only as a bind parameter key; the
        # value passed at execute time must override the bindparam default
        tbl1 = multirow_fixture
        stmt = select(tbl1.c.myid).where(
            tbl1.c.name == bindparam(paramname, value="x")
        )
        res = connection.scalar(stmt, {paramname: "c"})
        eq_(res, 3)

    @tough_parameters
    def test_standalone_bindparam_escape_expanding(
        self, paramname, connection, multirow_fixture
    ):
        # same as above, but through an expanding IN parameter
        tbl1 = multirow_fixture
        stmt = (
            select(tbl1.c.myid)
            .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
            .order_by(tbl1.c.myid)
        )

        res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
        eq_(res, [1, 4])
|
|
|
|
|
|
class ReturningGuardsTest(fixtures.TablesTest):
    """test that the various 'returning' flags are set appropriately"""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )

    @testing.fixture
    def run_stmt(self, connection):
        # returns a callable that adds RETURNING to *stmt* and executes
        # it, asserting success or the expected failure mode depending
        # on the dialect capability flag passed as expect_success
        t = self.tables.t

        def go(stmt, executemany, id_param_name, expect_success):
            stmt = stmt.returning(t.c.id)

            if executemany:
                if not expect_success:
                    # for RETURNING executemany(), we raise our own
                    # error as this is independent of general RETURNING
                    # support
                    with expect_raises_message(
                        exc.StatementError,
                        rf"Dialect {connection.dialect.name}\+"
                        f"{connection.dialect.driver} with "
                        f"current server capabilities does not support "
                        f".*RETURNING when executemany is used",
                    ):
                        connection.execute(
                            stmt,
                            [
                                {id_param_name: 1, "data": "d1"},
                                {id_param_name: 2, "data": "d2"},
                                {id_param_name: 3, "data": "d3"},
                            ],
                        )
                else:
                    result = connection.execute(
                        stmt,
                        [
                            {id_param_name: 1, "data": "d1"},
                            {id_param_name: 2, "data": "d2"},
                            {id_param_name: 3, "data": "d3"},
                        ],
                    )
                    eq_(result.all(), [(1,), (2,), (3,)])
            else:
                if not expect_success:
                    # for RETURNING execute(), we pass all the way to the DB
                    # and let it fail
                    with expect_raises(exc.DBAPIError):
                        connection.execute(
                            stmt, {id_param_name: 1, "data": "d1"}
                        )
                else:
                    result = connection.execute(
                        stmt, {id_param_name: 1, "data": "d1"}
                    )
                    eq_(result.all(), [(1,)])

        return go

    def test_insert_single(self, connection, run_stmt):
        t = self.tables.t

        stmt = t.insert()

        run_stmt(stmt, False, "id", connection.dialect.insert_returning)

    def test_insert_many(self, connection, run_stmt):
        t = self.tables.t

        stmt = t.insert()

        run_stmt(
            stmt, True, "id", connection.dialect.insert_executemany_returning
        )

    def test_update_single(self, connection, run_stmt):
        t = self.tables.t

        # seed rows so the UPDATE matches something
        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.update().where(t.c.id == bindparam("b_id"))

        run_stmt(stmt, False, "b_id", connection.dialect.update_returning)

    def test_update_many(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.update().where(t.c.id == bindparam("b_id"))

        run_stmt(
            stmt, True, "b_id", connection.dialect.update_executemany_returning
        )

    def test_delete_single(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.delete().where(t.c.id == bindparam("b_id"))

        run_stmt(stmt, False, "b_id", connection.dialect.delete_returning)

    def test_delete_many(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.delete().where(t.c.id == bindparam("b_id"))

        run_stmt(
            stmt, True, "b_id", connection.dialect.delete_executemany_returning
        )
|