# huibao/backend/app/services/pdf_extractor.py
"""
PDF内容提取服务 - 增强版
提取更多内容,更深度的分析
"""
import re
from typing import List, Dict, Optional, Tuple
import fitz
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.utils.logger import logger
from app.models import Report, ExtractedContent
class PDFExtractor:
    """PDF content extractor (enhanced).

    Pulls the full text out of a PDF with PyMuPDF (``fitz``), then slices
    out keyword-anchored sections with a generous context window so that
    downstream analysis sees complete chapters rather than isolated lines.
    """

    def __init__(self):
        # Keywords come from app config so deployments can tune extraction.
        self.keywords = settings.EXTRACT_KEYWORDS
        # Context window around each keyword hit, in lines.
        self.context_lines_before = 5
        self.context_lines_after = 80  # wide window to capture whole sections
        self.min_content_length = 200  # matches shorter than this are dropped

    def extract_text_from_pdf(self, pdf_path: str) -> Tuple[str, int]:
        """Extract the complete plain text of a PDF.

        Args:
            pdf_path: Filesystem path of the PDF file.

        Returns:
            ``(text, page_count)``; ``("", 0)`` on any extraction failure.
        """
        try:
            # Context manager guarantees the document handle is released
            # even if reading a page raises (the original manual close()
            # leaked the handle on error).
            with fitz.open(pdf_path) as doc:
                page_count = doc.page_count
                # join() is linear in total text size; repeated `+=` in a
                # loop is quadratic on large documents.
                text = "".join(page.get_text() for page in doc)
            logger.info(f"PDF提取完成: {page_count} 页, {len(text)} 字符")
            return text, page_count
        except Exception as e:
            # Best-effort: one broken PDF must not abort the pipeline.
            logger.error(f"PDF提取失败: {e}")
            return "", 0

    def extract_sections_by_keywords(self, text: str) -> List[Dict]:
        """Extract keyword-anchored sections from document text.

        For every configured keyword, each line containing it becomes an
        anchor. A context window is taken around the anchor, trimmed at the
        next section heading, and kept only if it is long enough. Windows
        already claimed by an earlier hit are skipped to avoid duplicates.

        Args:
            text: Full plain text of the document.

        Returns:
            De-duplicated section dicts, sorted by content length (desc).
        """
        sections: List[Dict] = []
        lines = text.split('\n')
        total_lines = len(lines)
        # Line ranges already extracted; later anchors inside them are skipped.
        matched_ranges: List[Tuple[int, int]] = []
        for keyword in self.keywords:
            for i, line in enumerate(lines):
                if keyword not in line:
                    continue
                # Skip anchors that fall inside an already-claimed window.
                if any(start_r <= i <= end_r for start_r, end_r in matched_ranges):
                    continue
                start = max(0, i - self.context_lines_before)
                end = min(total_lines, i + self.context_lines_after)
                # Trim the window at the next section heading so one extract
                # does not bleed into the following chapter. Scanning starts
                # at i+10 so the anchor's own heading area is never mistaken
                # for the next section.
                # NOTE(review): the pattern requires a closing paren, so
                # headings like "第X节" (mentioned in the original comment)
                # are NOT matched — confirm whether that was intended.
                for j in range(i + 10, end):
                    next_line = lines[j].strip()
                    if re.match(r'^[(]?[一二三四五六七八九十\d]+[)]', next_line):
                        if len(next_line) < 50:  # headings are short; long lines are body text
                            end = j
                            break
                content = '\n'.join(l.strip() for l in lines[start:end] if l.strip())
                if len(content) > self.min_content_length:
                    sections.append({
                        "keyword": keyword,
                        "section_name": line.strip()[:100],
                        "content": content,
                        "char_count": len(content),
                        "line_start": start,
                        "line_end": end,
                    })
                    matched_ranges.append((start, end))
        deduped = self._dedupe_sections(sections)
        logger.info(f"提取到 {len(deduped)} 个章节")
        return deduped

    def _dedupe_sections(self, sections: List[Dict]) -> List[Dict]:
        """Drop near-duplicate sections and sort richest-first.

        Two sections count as duplicates when the first 300 characters of
        their content are identical; the first occurrence wins.
        """
        seen: set = set()
        result: List[Dict] = []
        for s in sections:
            key = s["content"][:300]
            if key not in seen:
                seen.add(key)
                result.append(s)
        # Longest content first so consumers see the richest sections.
        result.sort(key=lambda x: x["char_count"], reverse=True)
        return result

    async def extract_and_save(self, db: AsyncSession, report: Report) -> List[ExtractedContent]:
        """Extract sections from a report's PDF and persist them.

        CPU-bound PDF parsing and keyword matching run in worker threads
        (``asyncio.to_thread``) so the event loop is not blocked; each
        section is stored as an ``ExtractedContent`` row and the report is
        flagged as extracted.

        Args:
            db: Async SQLAlchemy session; committed on success.
            report: Report whose ``local_path`` points at the downloaded PDF.

        Returns:
            The persisted ``ExtractedContent`` rows (``[]`` if the report
            has no local file or yields no text).
        """
        import asyncio
        if not report.local_path:
            logger.warning(f"报告无本地路径: {report.title}")
            return []
        # Blocking CPU work: PDF text extraction goes to a thread.
        logger.info(f"开启线程提取 PDF: {report.title}")
        text, page_count = await asyncio.to_thread(self.extract_text_from_pdf, report.local_path)
        if not text:
            return []
        # Blocking CPU work: keyword matching also goes to a thread.
        logger.info(f"开启线程提取章节: {report.title}")
        sections = await asyncio.to_thread(self.extract_sections_by_keywords, text)
        contents = []
        for section in sections:
            content = ExtractedContent(
                report_id=report.id,
                section_name=section["section_name"],
                section_keyword=section["keyword"],
                content=section["content"],
                char_count=section["char_count"]
            )
            db.add(content)
            contents.append(content)
        report.is_extracted = True
        await db.commit()
        logger.info(f"保存 {len(contents)} 个章节到数据库")
        return contents
# Module-level singleton: import this shared instance rather than
# constructing a new PDFExtractor per call.
pdf_extractor = PDFExtractor()