# Data migration script: copy all application tables from the legacy SQLite
# database into the new PostgreSQL database.
import asyncio
import logging

from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select

# 导入所有模型
from app.models import Base, Company, Report, ExtractedContent, AnalysisResult, TaskLog
from app.config import settings
||
# Logging configuration: INFO-level to stderr, one logger per module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Old database (SQLite) — the migration source, read from the working directory.
SQLITE_URL = "sqlite+aiosqlite:///./qingbao.db"

# New database (PostgreSQL) — the migration target, read from app settings.
PG_URL = settings.DATABASE_URL
# If the value above is wrong, uncomment the line below and fill in your own
# connection details.
# NOTE(review): the original comment here embedded real-looking credentials;
# they have been redacted. Never commit credentials to source control.
# PG_URL = "postgresql+asyncpg://USER:PASSWORD@HOST:PORT/dbname"
|
||
|
||
async def _copy_table(sqlite_conn, pg_conn, model, bool_fields=(),
                      batch_size=None, log_progress=False):
    """Copy every row of *model* from the SQLite connection into PostgreSQL.

    SQLite stores booleans as 0/1 integers, so the columns named in
    *bool_fields* are coerced back to real bools before insertion.  Rows are
    inserted with ON CONFLICT (id) DO NOTHING so the migration can be re-run
    safely over a partially-populated target.

    Args:
        sqlite_conn: open async connection to the source (SQLite) database.
        pg_conn: open async connection to the target (PostgreSQL) database.
        model: ORM model class describing the table to copy.
        bool_fields: column names to coerce with bool().
        batch_size: optional rows-per-INSERT; default is one statement
            for the whole table.
        log_progress: emit a progress line after each batch.

    Returns:
        The number of rows read from SQLite (0 if the table was empty).
    """
    result = await sqlite_conn.execute(select(model))
    rows = result.all()

    values = []
    for row in rows:
        row_dict = dict(row._mapping)
        for field in bool_fields:
            if field in row_dict:
                row_dict[field] = bool(row_dict[field])
        values.append(row_dict)

    if not values:
        return 0

    step = batch_size or len(values)
    for i in range(0, len(values), step):
        batch = values[i:i + step]
        # PG has no INSERT OR IGNORE; ON CONFLICT DO NOTHING is the equivalent.
        stmt = pg_insert(model).values(batch).on_conflict_do_nothing(
            index_elements=['id'])
        await pg_conn.execute(stmt)
        if log_progress:
            logger.info(f" - 已插入 {min(i + step, len(values))}/{len(values)}")
    return len(values)


async def _reset_id_sequence(pg_conn, table_name):
    """Re-sync *table_name*'s id sequence with MAX(id) after a bulk copy.

    COALESCE guards the empty-table case: setval(seq, NULL) would raise.
    *table_name* is always one of our own literal table names, never
    user-supplied input, so f-string interpolation here is safe.
    """
    await pg_conn.execute(text(
        f"SELECT setval(pg_get_serial_sequence('{table_name}', 'id'), "
        f"COALESCE((SELECT MAX(id) FROM {table_name}), 1))"
    ))


async def migrate():
    """Migrate all application tables from SQLite to PostgreSQL.

    Tables are copied in FK-dependency order (Company -> Report ->
    ExtractedContent / AnalysisResult, then TaskLog), and each table's id
    sequence is re-synced afterwards so new inserts don't collide with
    migrated ids.  The whole migration is committed in a single transaction
    on the PostgreSQL side.
    """
    logger.info("开始数据迁移...")

    sqlite_engine = create_async_engine(SQLITE_URL)
    pg_engine = create_async_engine(PG_URL)

    try:
        async with sqlite_engine.connect() as sqlite_conn, pg_engine.connect() as pg_conn:
            logger.info("正在连接数据库...")

            # 1. Companies
            logger.info("正在迁移 Companies...")
            count = await _copy_table(sqlite_conn, pg_conn, Company,
                                      bool_fields=('is_active',))
            logger.info(f"已迁移 {count} 家公司")
            await _reset_id_sequence(pg_conn, 'companies')

            # 2. Reports
            logger.info("正在迁移 Reports...")
            count = await _copy_table(
                sqlite_conn, pg_conn, Report,
                bool_fields=('is_downloaded', 'is_extracted', 'is_analyzed'))
            logger.info(f"已迁移 {count} 份报告")
            await _reset_id_sequence(pg_conn, 'reports')

            # 3. ExtractedContent — rows carry large text blobs, so insert in
            # small batches to keep each statement's size bounded.
            logger.info("正在迁移 ExtractedContent...")
            await _copy_table(sqlite_conn, pg_conn, ExtractedContent,
                              batch_size=50, log_progress=True)
            await _reset_id_sequence(pg_conn, 'extracted_contents')

            # 4. AnalysisResult
            logger.info("正在迁移 AnalysisResult...")
            count = await _copy_table(sqlite_conn, pg_conn, AnalysisResult,
                                      bool_fields=('is_final',))
            logger.info(f"已迁移 {count} 条分析结果")
            await _reset_id_sequence(pg_conn, 'analysis_results')

            # 5. TaskLog — its schema may differ between environments, so a
            # failure here is logged and skipped instead of aborting the
            # migration.  The sequence reset lives inside the same try so a
            # missing table cannot raise uncaught.
            logger.info("正在迁移 TaskLog...")
            try:
                count = await _copy_table(sqlite_conn, pg_conn, TaskLog)
                logger.info(f"已迁移 {count} 条日志")
                await _reset_id_sequence(pg_conn, 'task_logs')
            except Exception as e:
                logger.warning(f"TaskLog 迁移跳过(可能表结构不一致): {e}")

            await pg_conn.commit()
    finally:
        # Always release connection pools, even if the migration failed
        # part-way through.
        await sqlite_engine.dispose()
        await pg_engine.dispose()

    logger.info("数据迁移完成!")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: drive the async migration to completion.
    asyncio.run(migrate())
|