177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
|
|
"""
|
||
|
|
公司管理API
|
||
|
|
"""
|
||
|
|
from typing import List, Optional
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import select, func
|
||
|
|
|
||
|
|
from app.database import get_db
|
||
|
|
from app.models import Company, Report
|
||
|
|
from app.schemas import CompanyCreate, CompanyUpdate, CompanyResponse, MessageResponse
|
||
|
|
|
||
|
|
# Router for the company-management endpoints defined in this module; every
# route is mounted under /api/companies and tagged "公司管理" (company management).
router = APIRouter(prefix="/api/companies", tags=["公司管理"])
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("", response_model=List[CompanyResponse])
async def list_companies(
    industry: Optional[str] = None,
    is_active: Optional[bool] = None,
    db: AsyncSession = Depends(get_db)
):
    """List companies, newest first, each annotated with its report count.

    Args:
        industry: When truthy, keep only companies whose ``industry`` equals it
            (an empty string is treated as "no filter").
        is_active: When not None, keep only companies whose ``is_active`` equals it.
        db: Async database session (injected).

    Returns:
        ``CompanyResponse`` objects ordered by ``created_at`` descending, each with
        ``report_count`` set to the number of ``Report`` rows whose
        ``company_id`` points at that company.
    """
    # Count reports with a correlated scalar subquery so the whole listing is a
    # single SELECT, instead of one extra COUNT query per company (N+1).
    report_count = (
        select(func.count(Report.id))
        .where(Report.company_id == Company.id)
        .correlate(Company)
        .scalar_subquery()
    )
    stmt = select(Company, report_count.label("report_count"))

    if industry:
        stmt = stmt.where(Company.industry == industry)
    if is_active is not None:
        stmt = stmt.where(Company.is_active == is_active)

    stmt = stmt.order_by(Company.created_at.desc())
    result = await db.execute(stmt)

    response = []
    for company, count in result.all():
        company_data = CompanyResponse.model_validate(company)
        company_data.report_count = count
        response.append(company_data)

    return response
|
||
|
|
|
||
|
|
|
||
|
|
# cninfo search endpoint and request headers used to resolve a stock code into
# company metadata.
_CNINFO_SEARCH_URL = "https://www.cninfo.com.cn/new/information/topSearch/query"
_CNINFO_HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
    "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8",
}
_CNINFO_FAILURE_DETAIL = "巨潮API请求失败"


async def _search_cninfo(stock_code: str) -> Optional[dict]:
    """Query the cninfo search API and return the entry whose ``code`` matches exactly.

    Returns:
        The matching result dict, or None when cninfo answered successfully but
        no entry has ``code == stock_code``.

    Raises:
        httpx.RequestError: On network/transport failure (the caller falls back
            to user-supplied data).
        HTTPException: 500 when cninfo answers with a non-200 status, or with a
            body that is not valid JSON or not a JSON list.
    """
    import httpx

    search_data = {"keyWord": stock_code, "maxNum": 5}
    # NOTE(review): TLS certificate verification is disabled (verify=False), so a
    # man-in-the-middle could spoof company data. Kept for backward compatibility;
    # confirm whether it is still required and enable verification if possible.
    async with httpx.AsyncClient(timeout=10.0, verify=False) as client:
        response = await client.post(_CNINFO_SEARCH_URL, data=search_data, headers=_CNINFO_HEADERS)

    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=_CNINFO_FAILURE_DETAIL)

    # A malformed body previously escaped as an unhandled exception; report it
    # the same way as any other upstream failure.
    try:
        results = response.json()
    except ValueError:
        raise HTTPException(status_code=500, detail=_CNINFO_FAILURE_DETAIL) from None
    if not isinstance(results, list):
        raise HTTPException(status_code=500, detail=_CNINFO_FAILURE_DETAIL)

    for item in results:
        if isinstance(item, dict) and item.get("code") == stock_code:
            return item
    return None


def _company_from_user_input(
    stock_code: str, company: CompanyCreate, missing_name_detail: str
) -> Company:
    """Build a Company purely from the request payload (no cninfo data).

    Raises:
        HTTPException: 400 with ``missing_name_detail`` when no company_name was given.
    """
    if not company.company_name:
        raise HTTPException(status_code=400, detail=missing_name_detail)
    return Company(
        stock_code=stock_code,
        company_name=company.company_name,
        # Default the short name to the first 4 characters of the full name.
        short_name=company.short_name or company.company_name[:4],
        industry=company.industry,
    )


@router.post("", response_model=CompanyResponse)
async def create_company(
    company: CompanyCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a company; only the stock code is required.

    Company details are looked up on cninfo by stock code. If cninfo has no
    exact match, or the network request fails, the name/short name/industry
    supplied in the request are used instead.

    Raises:
        HTTPException: 400 if the stock code already exists, or if a fallback is
            needed but no company_name was supplied; 500 if cninfo responds with
            an error status or an unparseable body.
    """
    import httpx

    stock_code = company.stock_code.strip()

    # Reject duplicates up front.
    stmt = select(Company).where(Company.stock_code == stock_code)
    existing = (await db.execute(stmt)).scalar_one_or_none()
    if existing:
        raise HTTPException(status_code=400, detail="股票代码已存在")

    try:
        company_info = await _search_cninfo(stock_code)
    except httpx.RequestError:
        # Network error: fall back to the user-provided information.
        db_company = _company_from_user_input(
            stock_code, company, "网络错误,请手动填写公司名称"
        )
    else:
        if company_info:
            # Use `or` rather than a .get() default so that a present-but-empty
            # or null "zwjc" (Chinese short name) still falls back.
            zwjc = company_info.get("zwjc")
            db_company = Company(
                stock_code=stock_code,
                company_name=zwjc or company.company_name or stock_code,
                short_name=zwjc or company.short_name,
                industry=company.industry or "未知",
                org_id=company_info.get("orgId"),  # kept for later cninfo lookups
            )
        else:
            # cninfo has no exact match: use the user-provided information.
            db_company = _company_from_user_input(
                stock_code,
                company,
                f"未找到股票代码 {stock_code} 的公司信息,请手动填写公司名称",
            )

    db.add(db_company)
    await db.commit()
    await db.refresh(db_company)

    return CompanyResponse.model_validate(db_company)
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/{company_id}", response_model=CompanyResponse)
async def get_company(company_id: int, db: AsyncSession = Depends(get_db)):
    """Return a single company looked up by its primary key.

    Raises:
        HTTPException: 404 ("公司不存在") when no company has this id.
    """
    query = select(Company).where(Company.id == company_id)
    found = (await db.execute(query)).scalar_one_or_none()

    if found:
        return CompanyResponse.model_validate(found)
    raise HTTPException(status_code=404, detail="公司不存在")
|
||
|
|
|
||
|
|
|
||
|
|
@router.put("/{company_id}", response_model=CompanyResponse)
async def update_company(
    company_id: int,
    update: CompanyUpdate,
    db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to an existing company and return its new state.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``); omitted fields keep their current values.

    Raises:
        HTTPException: 404 ("公司不存在") when no company has this id.
    """
    lookup = await db.execute(select(Company).where(Company.id == company_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="公司不存在")

    changes = update.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    await db.commit()
    # Reload so server-side defaults/onupdate values are reflected in the response.
    await db.refresh(target)
    return CompanyResponse.model_validate(target)
|
||
|
|
|
||
|
|
|
||
|
|
@router.delete("/{company_id}", response_model=MessageResponse)
async def delete_company(company_id: int, db: AsyncSession = Depends(get_db)):
    """Delete a company by id and return a success message.

    NOTE(review): related Report rows are not touched here explicitly; whether they
    cascade or block the delete depends on the model/FK configuration — confirm.

    Raises:
        HTTPException: 404 ("公司不存在") when no company has this id.
    """
    lookup = await db.execute(select(Company).where(Company.id == company_id))
    victim = lookup.scalar_one_or_none()
    if not victim:
        raise HTTPException(status_code=404, detail="公司不存在")

    await db.delete(victim)
    await db.commit()
    return MessageResponse(success=True, message="删除成功")
|