import asyncio
import logging

from sqlalchemy import text
# Hoisted to module level: the original imported pg_insert inside the
# `if companies:` branch, causing a NameError for every later table
# whenever the companies table happened to be empty.
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select

# Import all models participating in the migration.
from app.models import Base, Company, Report, ExtractedContent, AnalysisResult, TaskLog
from app.config import settings

# Logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Old database (SQLite).
SQLITE_URL = "sqlite+aiosqlite:///./qingbao.db"

# New database (PostgreSQL) - read from configuration.
PG_URL = settings.DATABASE_URL
# If the setting above is wrong, uncomment the line below and fill in the
# real values. SECURITY NOTE: never commit real credentials here.
# PG_URL = "postgresql+asyncpg://<user>:<password>@<host>:<port>/<dbname>"


def _rows_to_dicts(rows, bool_fields=()):
    """Convert SQLAlchemy Row objects into plain dicts.

    SQLite stores booleans as 0/1 integers; the fields named in
    *bool_fields* are coerced to real ``bool`` so asyncpg accepts them.
    """
    values = []
    for row in rows:
        row_dict = dict(row._mapping)
        for field in bool_fields:
            if field in row_dict:
                row_dict[field] = bool(row_dict[field])
        values.append(row_dict)
    return values


async def _insert_batch(pg_conn, model, values, batch_size=50, log_progress=False):
    """Insert row dicts into PostgreSQL in batches.

    PostgreSQL has no ``INSERT OR IGNORE``, so ``ON CONFLICT (id) DO
    NOTHING`` is used instead, which also makes re-runs idempotent.
    Batching keeps individual statements small for large tables.
    """
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        stmt = (
            pg_insert(model)
            .values(batch)
            .on_conflict_do_nothing(index_elements=["id"])
        )
        await pg_conn.execute(stmt)
        if log_progress:
            logger.info(f" - 已插入 {min(start + batch_size, len(values))}/{len(values)}")


async def _reset_sequence(pg_conn, table):
    """Re-sync the serial sequence of *table* with its current MAX(id).

    ``COALESCE(MAX(id), 1)`` guards the empty-table case: the original
    bare ``MAX(id)`` yields NULL there and ``setval(seq, NULL)`` raises.
    """
    await pg_conn.execute(text(
        f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), "
        f"(SELECT COALESCE(MAX(id), 1) FROM {table}))"
    ))


async def migrate():
    """Copy all rows from the legacy SQLite database into PostgreSQL.

    Tables are migrated in foreign-key dependency order:
    Company -> Report -> ExtractedContent -> AnalysisResult -> TaskLog.
    Inserts use ``ON CONFLICT DO NOTHING`` so the script can be re-run
    safely, and each serial sequence is re-synced afterwards so that new
    inserts receive fresh ids.
    """
    logger.info("开始数据迁移...")

    sqlite_engine = create_async_engine(SQLITE_URL)
    pg_engine = create_async_engine(PG_URL)

    try:
        async with sqlite_engine.connect() as sqlite_conn, pg_engine.connect() as pg_conn:
            logger.info("正在连接数据库...")

            # 1. Companies
            logger.info("正在迁移 Companies...")
            result = await sqlite_conn.execute(select(Company))
            values = _rows_to_dicts(result.all(), bool_fields=("is_active",))
            if values:
                await _insert_batch(pg_conn, Company, values)
                logger.info(f"已迁移 {len(values)} 家公司")
            await _reset_sequence(pg_conn, "companies")

            # 2. Reports
            logger.info("正在迁移 Reports...")
            result = await sqlite_conn.execute(select(Report))
            values = _rows_to_dicts(
                result.all(),
                bool_fields=("is_downloaded", "is_extracted", "is_analyzed"),
            )
            if values:
                await _insert_batch(pg_conn, Report, values)
                logger.info(f"已迁移 {len(values)} 份报告")
            await _reset_sequence(pg_conn, "reports")

            # 3. ExtractedContent (rows can be large, so show progress)
            logger.info("正在迁移 ExtractedContent...")
            result = await sqlite_conn.execute(select(ExtractedContent))
            values = _rows_to_dicts(result.all())
            if values:
                await _insert_batch(pg_conn, ExtractedContent, values, log_progress=True)
            await _reset_sequence(pg_conn, "extracted_contents")

            # 4. AnalysisResult
            logger.info("正在迁移 AnalysisResult...")
            result = await sqlite_conn.execute(select(AnalysisResult))
            values = _rows_to_dicts(result.all(), bool_fields=("is_final",))
            if values:
                await _insert_batch(pg_conn, AnalysisResult, values)
                logger.info(f"已迁移 {len(values)} 条分析结果")
            await _reset_sequence(pg_conn, "analysis_results")

            # 5. TaskLog — the legacy table's schema may not match the
            # current ORM model, so failures here are logged and skipped
            # rather than aborting the whole migration.
            logger.info("正在迁移 TaskLog...")
            try:
                result = await sqlite_conn.execute(select(TaskLog))
                values = _rows_to_dicts(result.all())
                if values:
                    await _insert_batch(pg_conn, TaskLog, values)
                    logger.info(f"已迁移 {len(values)} 条日志")
            except Exception as e:
                logger.warning(f"TaskLog 迁移跳过(可能表结构不一致): {e}")
            await _reset_sequence(pg_conn, "task_logs")

            # Single commit: the whole migration is one transaction on
            # the PG side, so a mid-run failure leaves nothing partial.
            await pg_conn.commit()
    finally:
        # Dispose engines even when the migration fails, so connection
        # pools are not leaked (the original only cleaned up on success).
        await sqlite_engine.dispose()
        await pg_engine.dispose()

    logger.info("数据迁移完成!")


if __name__ == "__main__":
    asyncio.run(migrate())