"""huibao/backend/migrate_data.py

One-off data migration: copy all rows from the legacy SQLite database
(qingbao.db) into the PostgreSQL database configured in app.config.
"""
import asyncio
import logging
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select
# 导入所有模型
from app.models import Base, Company, Report, ExtractedContent, AnalysisResult, TaskLog
from app.config import settings
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 旧数据库 (SQLite)
SQLITE_URL = "sqlite+aiosqlite:///./qingbao.db"
# 新数据库 (PostgreSQL) - 从配置读取
PG_URL = settings.DATABASE_URL
# If the setting above does not resolve correctly, uncomment the line below
# and fill in your own connection details. NEVER commit real credentials.
# PG_URL = "postgresql+asyncpg://<user>:<password>@<host>:<port>/qingbao"
# Needed unconditionally: the original imported this lazily inside the first
# `if companies:` branch, which raised NameError for later tables whenever the
# companies table was empty.
from sqlalchemy.dialects.postgresql import insert as pg_insert


async def _copy_rows(sqlite_conn, pg_conn, model, bool_fields=(), batch_size=50,
                     progress=False):
    """Copy every row of *model* from SQLite into PostgreSQL.

    Args:
        sqlite_conn: async connection to the source SQLite database.
        pg_conn: async connection to the target PostgreSQL database.
        model: ORM model class describing the table to copy.
        bool_fields: column names to coerce via ``bool()`` — SQLite stores
            booleans as 0/1 integers, PostgreSQL wants real booleans.
        batch_size: rows per INSERT statement (keeps statements bounded for
            tables with large rows, e.g. extracted text).
        progress: when True, log a progress line after each batch.

    Returns:
        The number of rows read from the source table (0 when empty).
    """
    result = await sqlite_conn.execute(select(model))
    rows = result.all()
    if not rows:
        return 0
    values = []
    for row in rows:
        row_dict = dict(row._mapping)
        for field in bool_fields:
            if field in row_dict:
                row_dict[field] = bool(row_dict[field])
        values.append(row_dict)
    total = len(values)
    for start in range(0, total, batch_size):
        batch = values[start:start + batch_size]
        # PG has no INSERT OR IGNORE; ON CONFLICT DO NOTHING keeps the
        # migration idempotent across re-runs.
        stmt = pg_insert(model).values(batch).on_conflict_do_nothing(
            index_elements=['id'])
        await pg_conn.execute(stmt)
        if progress:
            logger.info(f" - 已插入 {min(start + batch_size, total)}/{total}")
    return total


async def _sync_sequence(pg_conn, table):
    """Re-sync *table*'s id sequence to MAX(id) so new inserts don't collide.

    Only call this after at least one row has been migrated: setval() raises
    on a NULL value, which is what MAX(id) yields for an empty table.
    """
    await pg_conn.execute(text(
        f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), "
        f"(SELECT MAX(id) FROM {table}))"
    ))


async def migrate():
    """Migrate all tables from SQLite to PostgreSQL.

    Tables are copied in foreign-key dependency order: Company -> Report ->
    ExtractedContent -> AnalysisResult -> TaskLog. Everything runs on one PG
    connection and is committed once at the end, so a failure leaves the
    target untouched. Engines are always disposed, even on error.
    """
    logger.info("开始数据迁移...")
    sqlite_engine = create_async_engine(SQLITE_URL)
    pg_engine = create_async_engine(PG_URL)
    try:
        async with sqlite_engine.connect() as sqlite_conn, \
                pg_engine.connect() as pg_conn:
            logger.info("正在连接数据库...")

            # 1. Companies
            logger.info("正在迁移 Companies...")
            n = await _copy_rows(sqlite_conn, pg_conn, Company,
                                 bool_fields=("is_active",))
            if n:
                logger.info(f"已迁移 {n} 家公司")
                await _sync_sequence(pg_conn, "companies")

            # 2. Reports
            logger.info("正在迁移 Reports...")
            n = await _copy_rows(
                sqlite_conn, pg_conn, Report,
                bool_fields=("is_downloaded", "is_extracted", "is_analyzed"))
            if n:
                logger.info(f"已迁移 {n} 份报告")
                await _sync_sequence(pg_conn, "reports")

            # 3. ExtractedContent — rows carry full document text, so log
            # per-batch progress.
            logger.info("正在迁移 ExtractedContent...")
            n = await _copy_rows(sqlite_conn, pg_conn, ExtractedContent,
                                 progress=True)
            if n:
                await _sync_sequence(pg_conn, "extracted_contents")

            # 4. AnalysisResult
            logger.info("正在迁移 AnalysisResult...")
            n = await _copy_rows(sqlite_conn, pg_conn, AnalysisResult,
                                 bool_fields=("is_final",))
            if n:
                logger.info(f"已迁移 {n} 条分析结果")
                await _sync_sequence(pg_conn, "analysis_results")

            # 5. TaskLog — its schema may have drifted between app versions,
            # so a mismatch skips this table instead of aborting the run. The
            # sequence sync lives INSIDE the try so it cannot crash the
            # migration when the table is absent or empty.
            logger.info("正在迁移 TaskLog...")
            try:
                n = await _copy_rows(sqlite_conn, pg_conn, TaskLog)
                if n:
                    logger.info(f"已迁移 {n} 条日志")
                    await _sync_sequence(pg_conn, "task_logs")
            except Exception as e:
                logger.warning(f"TaskLog 迁移跳过(可能表结构不一致): {e}")

            await pg_conn.commit()
    finally:
        # Dispose even when the migration fails part-way.
        await sqlite_engine.dispose()
        await pg_engine.dispose()
    logger.info("数据迁移完成!")
# Script entry point: run the full migration once, e.g.
#   python migrate_data.py
if __name__ == "__main__":
    asyncio.run(migrate())