"""公司管理API — company management API.

CRUD endpoints for ``Company`` records. Creating a company only requires a
stock code: the company name is looked up through the cninfo (巨潮) search
API, falling back to user-supplied values when the lookup fails or finds no
exact match.
"""
from typing import Any, Dict, List, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func

from app.database import get_db
from app.models import Company, Report
from app.schemas import CompanyCreate, CompanyUpdate, CompanyResponse, MessageResponse

router = APIRouter(prefix="/api/companies", tags=["公司管理"])

# cninfo search endpoint used to resolve a stock code to company metadata.
_CNINFO_SEARCH_URL = "https://www.cninfo.com.cn/new/information/topSearch/query"
_CNINFO_HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
    "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8",
}
_CNINFO_TIMEOUT_SECONDS = 10.0
_CNINFO_MAX_RESULTS = 5


async def _get_company_or_404(db: AsyncSession, company_id: int) -> Company:
    """Load the company with primary key *company_id* or raise HTTP 404."""
    result = await db.execute(select(Company).where(Company.id == company_id))
    company = result.scalar_one_or_none()
    if company is None:
        raise HTTPException(status_code=404, detail="公司不存在")
    return company


async def _count_reports(db: AsyncSession, company_id: int) -> int:
    """Return how many ``Report`` rows reference *company_id* (0 if none)."""
    count = await db.scalar(
        select(func.count(Report.id)).where(Report.company_id == company_id)
    )
    return count or 0


@router.get("", response_model=List[CompanyResponse])
async def list_companies(
    industry: Optional[str] = None,
    is_active: Optional[bool] = None,
    db: AsyncSession = Depends(get_db),
) -> List[CompanyResponse]:
    """List companies (获取公司列表), newest first, with their report counts.

    Args:
        industry: When non-empty, keep only companies with exactly this industry.
        is_active: When given, keep only companies whose ``is_active`` matches.
    """
    # Aggregate report counts once per company in a subquery and outer-join it,
    # instead of running one COUNT query per listed company (N+1 queries).
    report_counts = (
        select(
            Report.company_id.label("company_id"),
            func.count(Report.id).label("report_count"),
        )
        .group_by(Report.company_id)
        .subquery()
    )
    stmt = select(
        Company,
        # Companies without any report get NULL from the outer join -> 0.
        func.coalesce(report_counts.c.report_count, 0),
    ).outerjoin(report_counts, report_counts.c.company_id == Company.id)

    if industry:
        stmt = stmt.where(Company.industry == industry)
    if is_active is not None:
        stmt = stmt.where(Company.is_active == is_active)
    stmt = stmt.order_by(Company.created_at.desc())

    result = await db.execute(stmt)
    companies: List[CompanyResponse] = []
    for company, report_count in result.all():
        company_data = CompanyResponse.model_validate(company)
        company_data.report_count = report_count
        companies.append(company_data)
    return companies


def _match_cninfo_result(results: Any, stock_code: str) -> Optional[Dict[str, Any]]:
    """Return the first cninfo search hit whose ``code`` equals *stock_code*.

    Returns None when nothing matches or the payload is not a list of dicts,
    so an unexpected response shape is treated as "not found" instead of
    crashing the request.
    """
    if not isinstance(results, list):
        return None
    for item in results:
        if isinstance(item, dict) and item.get("code") == stock_code:
            return item
    return None


async def _search_cninfo(stock_code: str) -> Optional[Dict[str, Any]]:
    """Look *stock_code* up via the cninfo search API.

    Returns:
        The matching search hit, or None when cninfo has no exact match.

    Raises:
        httpx.RequestError: on network-level failures (caller falls back to
            user-supplied data).
        HTTPException: status 500 when cninfo answers with a non-200 status
            or a body that is not valid JSON.
    """
    search_data = {"keyWord": stock_code, "maxNum": _CNINFO_MAX_RESULTS}
    # NOTE(review): verify=False disables TLS certificate verification, so a
    # man-in-the-middle could inject fake company data. Kept to preserve the
    # existing behaviour — confirm whether cninfo's certificate chain really
    # requires it.
    async with httpx.AsyncClient(timeout=_CNINFO_TIMEOUT_SECONDS, verify=False) as client:
        http_response = await client.post(
            _CNINFO_SEARCH_URL, data=search_data, headers=_CNINFO_HEADERS
        )
    if http_response.status_code != 200:
        raise HTTPException(status_code=500, detail="巨潮API请求失败")
    try:
        results = http_response.json()
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(status_code=500, detail="巨潮API请求失败") from exc
    return _match_cninfo_result(results, stock_code)


def _company_from_user_input(
    company: CompanyCreate, stock_code: str, missing_name_detail: str
) -> Company:
    """Build a ``Company`` purely from the request payload.

    Used when cninfo cannot supply the data. A company name is then mandatory;
    the short name defaults to the first 4 characters of the company name.

    Raises:
        HTTPException: status 400 with *missing_name_detail* when the request
            carries no company name.
    """
    if not company.company_name:
        raise HTTPException(status_code=400, detail=missing_name_detail)
    return Company(
        stock_code=stock_code,
        company_name=company.company_name,
        short_name=company.short_name or company.company_name[:4],
        industry=company.industry,
    )


@router.post("", response_model=CompanyResponse)
async def create_company(
    company: CompanyCreate,
    db: AsyncSession = Depends(get_db),
) -> CompanyResponse:
    """Create a company (创建公司) from its stock code.

    Company info is fetched automatically from cninfo (仅需股票代码,自动从巨潮获取
    公司信息); user-supplied fields are the fallback.

    Raises:
        HTTPException: 400 if the stock code already exists, or if cninfo
            cannot supply a name and the request has none; 500 if cninfo
            responds with an error status or invalid JSON.
    """
    stock_code = company.stock_code.strip()

    # Reject duplicates up front (检查股票代码是否已存在).
    existing_result = await db.execute(
        select(Company).where(Company.stock_code == stock_code)
    )
    if existing_result.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="股票代码已存在")

    try:
        company_info = await _search_cninfo(stock_code)
    except httpx.RequestError:
        # Network failure: fall back to the user's data (网络错误时,使用用户提供的信息).
        db_company = _company_from_user_input(
            company, stock_code, "网络错误,请手动填写公司名称"
        )
    else:
        if company_info:
            # "zwjc" is the Chinese short name; use it as both the company name
            # and the short name. `or` (not a .get default) also covers a key
            # that is present but empty/None.
            cninfo_name = company_info.get("zwjc")
            db_company = Company(
                stock_code=stock_code,
                company_name=cninfo_name or company.company_name or stock_code,
                short_name=cninfo_name or company.short_name,
                industry=company.industry or "未知",
                org_id=company_info.get("orgId"),  # kept for later cninfo queries
            )
        else:
            # Not found on cninfo: rely on user-provided info.
            db_company = _company_from_user_input(
                company,
                stock_code,
                f"未找到股票代码 {stock_code} 的公司信息,请手动填写公司名称",
            )

    db.add(db_company)
    await db.commit()
    await db.refresh(db_company)
    return CompanyResponse.model_validate(db_company)


@router.get("/{company_id}", response_model=CompanyResponse)
async def get_company(company_id: int, db: AsyncSession = Depends(get_db)) -> CompanyResponse:
    """Return one company (获取公司详情), including its report count.

    Raises:
        HTTPException: 404 if no company has this id.
    """
    company = await _get_company_or_404(db, company_id)
    company_data = CompanyResponse.model_validate(company)
    # Fill report_count the same way list_companies does.
    company_data.report_count = await _count_reports(db, company_id)
    return company_data


@router.put("/{company_id}", response_model=CompanyResponse)
async def update_company(
    company_id: int,
    update: CompanyUpdate,
    db: AsyncSession = Depends(get_db),
) -> CompanyResponse:
    """Partially update a company (更新公司信息).

    Only fields explicitly sent in the request body are written
    (``exclude_unset=True``).

    Raises:
        HTTPException: 404 if no company has this id.
    """
    company = await _get_company_or_404(db, company_id)
    for field_name, value in update.model_dump(exclude_unset=True).items():
        setattr(company, field_name, value)
    await db.commit()
    await db.refresh(company)
    company_data = CompanyResponse.model_validate(company)
    company_data.report_count = await _count_reports(db, company_id)
    return company_data


@router.delete("/{company_id}", response_model=MessageResponse)
async def delete_company(company_id: int, db: AsyncSession = Depends(get_db)) -> MessageResponse:
    """Delete a company (删除公司).

    NOTE(review): what happens to the company's ``Report`` rows depends on the
    model's FK/cascade configuration, which is not visible here — confirm.

    Raises:
        HTTPException: 404 if no company has this id.
    """
    company = await _get_company_or_404(db, company_id)
    await db.delete(company)
    await db.commit()
    return MessageResponse(success=True, message="删除成功")