# huibao/backend/migrate_data.py

import asyncio
import logging
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select
# 导入所有模型
from app.models import Base, Company, Report, ExtractedContent, AnalysisResult, TaskLog
from app.config import settings
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 旧数据库 (SQLite)
SQLITE_URL = "sqlite+aiosqlite:///./qingbao.db"
# 新数据库 (PostgreSQL) - 从配置读取
PG_URL = settings.DATABASE_URL
# 如果上面读取不正确,请手动取消下面的注释并修改 (placeholders only — never commit real credentials)
# PG_URL = "postgresql+asyncpg://<user>:<password>@<host>:<port>/qingbao"
async def migrate():
    """Migrate all data from the local SQLite database into PostgreSQL.

    Tables are copied in foreign-key dependency order:
    Company -> Report -> ExtractedContent -> AnalysisResult -> TaskLog.
    Inserts use PostgreSQL ``ON CONFLICT DO NOTHING`` so re-running the
    script is idempotent, and each table's id sequence is re-synced after
    the copy so future inserts do not collide with migrated ids.

    Raises:
        Propagates any database error except TaskLog schema mismatches,
        which are logged and skipped (best-effort table).
    """
    logger.info("开始数据迁移...")
    # Create one engine per database; connections open lazily on first use.
    sqlite_engine = create_async_engine(SQLITE_URL)
    pg_engine = create_async_engine(PG_URL)
    try:
        async with sqlite_engine.connect() as sqlite_conn, pg_engine.connect() as pg_conn:
            # 0. Confirm connections.
            logger.info("正在连接数据库...")
            # PG has no INSERT OR IGNORE; the PostgreSQL-dialect insert()
            # supports ON CONFLICT DO NOTHING, which keeps the copy idempotent.
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            # 1. Companies
            logger.info("正在迁移 Companies...")
            result = await sqlite_conn.execute(select(Company))
            companies = result.all()
            if companies:
                values = []
                for row in companies:
                    row_dict = dict(row._mapping)
                    # SQLite stores booleans as 0/1; coerce to real bools for PG.
                    if 'is_active' in row_dict:
                        row_dict['is_active'] = bool(row_dict['is_active'])
                    values.append(row_dict)
                stmt = pg_insert(Company).values(values).on_conflict_do_nothing(index_elements=['id'])
                await pg_conn.execute(stmt)
                logger.info(f"已迁移 {len(values)} 家公司")
                # Re-sync the sequence; only safe when the table is non-empty
                # (setval with a NULL MAX(id) would raise).
                await pg_conn.execute(text("SELECT setval(pg_get_serial_sequence('companies', 'id'), (SELECT MAX(id) FROM companies))"))

            # 2. Reports
            logger.info("正在迁移 Reports...")
            result = await sqlite_conn.execute(select(Report))
            reports = result.all()
            if reports:
                values = []
                for row in reports:
                    row_dict = dict(row._mapping)
                    # Coerce SQLite 0/1 integers to booleans.
                    for bool_field in ['is_downloaded', 'is_extracted', 'is_analyzed']:
                        if bool_field in row_dict:
                            row_dict[bool_field] = bool(row_dict[bool_field])
                    values.append(row_dict)
                stmt = pg_insert(Report).values(values).on_conflict_do_nothing(index_elements=['id'])
                await pg_conn.execute(stmt)
                logger.info(f"已迁移 {len(values)} 份报告")
                await pg_conn.execute(text("SELECT setval(pg_get_serial_sequence('reports', 'id'), (SELECT MAX(id) FROM reports))"))

            # 3. ExtractedContent
            logger.info("正在迁移 ExtractedContent...")
            result = await sqlite_conn.execute(select(ExtractedContent))
            contents = result.all()
            if contents:
                values = [dict(row._mapping) for row in contents]
                # Rows carry large text payloads; insert in small batches to
                # keep individual statements a manageable size.
                batch_size = 50
                for i in range(0, len(values), batch_size):
                    batch = values[i:i+batch_size]
                    stmt = pg_insert(ExtractedContent).values(batch).on_conflict_do_nothing(index_elements=['id'])
                    await pg_conn.execute(stmt)
                    logger.info(f" - 已插入 {min(i+batch_size, len(values))}/{len(values)}")
                await pg_conn.execute(text("SELECT setval(pg_get_serial_sequence('extracted_contents', 'id'), (SELECT MAX(id) FROM extracted_contents))"))

            # 4. AnalysisResult
            logger.info("正在迁移 AnalysisResult...")
            result = await sqlite_conn.execute(select(AnalysisResult))
            results = result.all()
            if results:
                values = []
                for row in results:
                    row_dict = dict(row._mapping)
                    if 'is_final' in row_dict:
                        row_dict['is_final'] = bool(row_dict['is_final'])
                    values.append(row_dict)
                stmt = pg_insert(AnalysisResult).values(values).on_conflict_do_nothing(index_elements=['id'])
                await pg_conn.execute(stmt)
                logger.info(f"已迁移 {len(values)} 条分析结果")
                await pg_conn.execute(text("SELECT setval(pg_get_serial_sequence('analysis_results', 'id'), (SELECT MAX(id) FROM analysis_results))"))

            # 5. TaskLog — best-effort: the table schema may not match the ORM
            # model, so every step (read, insert, setval) stays inside the try
            # and a failure only skips this table instead of aborting the run.
            logger.info("正在迁移 TaskLog...")
            try:
                result = await sqlite_conn.execute(select(TaskLog))
                logs = result.all()
                if logs:
                    values = [dict(row._mapping) for row in logs]
                    stmt = pg_insert(TaskLog).values(values).on_conflict_do_nothing(index_elements=['id'])
                    await pg_conn.execute(stmt)
                    logger.info(f"已迁移 {len(values)} 条日志")
                    await pg_conn.execute(text("SELECT setval(pg_get_serial_sequence('task_logs', 'id'), (SELECT MAX(id) FROM task_logs))"))
            except Exception as e:
                logger.warning(f"TaskLog 迁移跳过(可能表结构不一致): {e}")

            await pg_conn.commit()
    finally:
        # Dispose engines even if the migration failed partway through.
        await sqlite_engine.dispose()
        await pg_engine.dispose()
    logger.info("数据迁移完成!")
if __name__ == "__main__":
    # Entry point: run the async migration to completion.
    asyncio.run(migrate())