327 lines
13 KiB
Python
327 lines
13 KiB
Python
import feedparser
|
||
import requests
|
||
from datetime import datetime
|
||
import pandas as pd
|
||
import os
|
||
import pickle
|
||
import time
|
||
import sys
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from loguru import logger
|
||
from typing import Dict, List, Optional, Any
|
||
|
||
# Make the parent directory importable so that the sibling ``utils`` package
# (imported just below) resolves when this file is run directly as a script.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    # Insert at the front so the project copy of ``utils`` wins over any
    # same-named package installed elsewhere on the path.
    sys.path.insert(0, parent_dir)
|
||
|
||
from utils.mysql_agent import MySQLAgent
|
||
|
||
# Database connection settings handed to MySQLAgent.
#
# The connection identity (host, port, user, password, schema) can be
# overridden through INTEL_DB_* environment variables so that credentials do
# not have to be committed to source control. The literals are the original
# development defaults and are used whenever a variable is not set.
local_DB_Config = {
    'host': os.environ.get('INTEL_DB_HOST', "localhost"),
    'user': os.environ.get('INTEL_DB_USER', "root"),
    'password': os.environ.get('INTEL_DB_PASSWORD', "123123"),
    'database': os.environ.get('INTEL_DB_NAME', "intelligence_system"),
    # Environment values are strings; the driver setting is kept as an int.
    'port': int(os.environ.get('INTEL_DB_PORT', 3306)),
    'charset': 'utf8mb4',
    'connect_timeout': 10,
    'read_timeout': 30,
    'write_timeout': 30,
    'autocommit': True,
}

# Name of the table that receives the collected RSS entries.
table_name = "collector_rss_subscriptions"
|
||
|
||
|
||
class NewsAPIClient:
    """RSS collector: downloads feeds, normalises entries and stores them in MySQL.

    A run (see ``main``) verifies the target table, loads the timestamp of the
    previous run from a pickle cache, downloads all feeds concurrently, inserts
    the entries that are newer than that timestamp and finally stores the
    newest timestamp seen for the next run.

    Logging goes through a loguru logger. Tracebacks are emitted with
    ``logger.exception``; loguru has no ``exc_info`` keyword.
    """

    # Feeds polled by ``main`` when the caller does not supply its own list.
    DEFAULT_RSS_URLS = (
        "https://www.chinanews.com.cn/rss/finance.xml",
        "https://www.chinanews.com.cn/rss/world.xml",
        "https://www.chinanews.com.cn/rss/china.xml",
        "https://www.chinanews.com.cn/rss/scroll-news.xml",
    )

    def __init__(self, db_config: Optional[Dict[str, Any]] = None):
        """Create the bound logger and the database agent.

        Args:
            db_config: Connection settings passed to ``MySQLAgent``. When
                omitted, the module-level ``local_DB_Config`` is used.
        """
        self.logger = logger.bind(module="NewsAPIClient")
        self.db_agent = MySQLAgent(db_config if db_config is not None else local_DB_Config)
        self.logger.info("新闻API客户端初始化完成,已连接到数据库")

    def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
        """Return the uniform ``{'success', 'message', 'data'}`` result dict."""
        return {
            'success': bool(success),
            'message': str(message),
            'data': data,
        }

    def verify_database(self) -> bool:
        """Check that the target table exists and has every required column.

        Returns:
            True when the table is present and complete. False when it is
            missing, lacks columns, or the check itself raised (the error is
            logged, not propagated).
        """
        required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
                            '来源URL', '创建时间', '更新时间']
        try:
            # An empty result means no table matched the name.
            existing = self.db_agent.execute_sql(
                f"SHOW TABLES LIKE '{table_name}'",
                fetch=True,
            )
            if not existing:
                self.logger.error(f"表 {table_name} 不存在,请先创建表结构")
                return False

            desc_rows = self.db_agent.execute_sql(
                f"DESCRIBE {table_name}",
                fetch=True,
            )
            # The original code indexed tuple rows (column name first). Dict
            # rows keyed by 'Field' are accepted as well, in case the agent is
            # configured with a dict cursor.
            columns = [
                row['Field'] if isinstance(row, dict) else row[0]
                for row in (desc_rows or [])
            ]
            missing_cols = [col for col in required_columns if col not in columns]
            if missing_cols:
                self.logger.error(f"表 {table_name} 缺少必要字段:{missing_cols}")
                return False

            self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
            return True

        except Exception as e:
            self.logger.exception(f"数据库验证失败: {str(e)}")
            return False

    def load_last_update_time(self) -> Optional[datetime]:
        """Load the timestamp saved by the previous run.

        The cache lives at ``<cwd>/output/last_update.pkl``, so the result
        depends on the working directory the process was started from.

        Returns:
            The cached ``datetime``, or None when the cache is missing,
            unreadable or does not contain a ``datetime``.
        """
        cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'rb') as f:
                    # SECURITY: unpickling executes code embedded in the file.
                    # This is only safe while the cache directory is writable
                    # by trusted users alone.
                    last_update = pickle.load(f)
                if not isinstance(last_update, datetime):
                    raise TypeError(
                        f"缓存内容不是 datetime 对象: {type(last_update).__name__}"
                    )
                self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
                return last_update
            except Exception as e:
                self.logger.exception(f"加载上次更新时间失败: {str(e)}")
        self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
        return None

    def save_last_update_time(self, last_update: datetime) -> None:
        """Persist *last_update* to ``<cwd>/output/last_update.pkl``.

        Failures are logged and swallowed, so a broken cache never aborts a run.
        """
        try:
            cache_dir = os.path.join(os.getcwd(), 'output')
            os.makedirs(cache_dir, exist_ok=True)
            cache_file = os.path.join(cache_dir, 'last_update.pkl')

            with open(cache_file, 'wb') as f:
                pickle.dump(last_update, f)
            self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
        except Exception as e:
            self.logger.exception(f"保存更新时间失败: {str(e)}")

    def fetch_single_rss(self, url: str, timeout: int = 15) -> Optional["feedparser.FeedParserDict"]:
        """Download and parse one RSS feed, retrying on HTTP/network errors.

        Up to three attempts are made, pausing 3 s and then 6 s between them.

        Args:
            url: Feed URL.
            timeout: Per-request timeout in seconds, passed to ``requests.get``.

        Returns:
            The parsed feed, or None when all three attempts failed.
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        for attempt in range(3):
            try:
                response = requests.get(url, headers=headers, timeout=timeout)
                response.raise_for_status()
                # Decode with the encoding guessed from the body rather than
                # the HTTP header before handing the text to feedparser.
                # NOTE(review): feedparser can also detect the encoding itself
                # when given ``response.content``; kept as-is to preserve the
                # current behaviour.
                response.encoding = response.apparent_encoding
                feed = feedparser.parse(response.text)

                # ``bozo`` flags a feed that was not well-formed; the parsed
                # result is still returned.
                if feed.bozo:
                    self.logger.warning(f"解析 {url} 存在潜在问题: {feed.bozo_exception}")

                self.logger.debug(f"成功获取 {url} 的RSS数据")
                return feed

            except requests.RequestException as e:
                self.logger.warning(f"第 {attempt + 1} 次获取 {url} 失败: {str(e)}")
                if attempt < 2:
                    # Linear back-off: 3 s after the first failure, 6 s after
                    # the second; no sleep after the final attempt.
                    time.sleep(3 * (attempt + 1))

        self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
        return None

    def fetch_all_rss(self, urls: List[str], timeout: int = 15,
                      max_workers: int = 3) -> Dict[str, "feedparser.FeedParserDict"]:
        """Download several feeds concurrently.

        Args:
            urls: Feed URLs to download.
            timeout: Per-request timeout forwarded to ``fetch_single_rss``.
            max_workers: Size of the thread pool (default 3).

        Returns:
            Mapping of URL to parsed feed, containing only the feeds that were
            fetched successfully.
        """
        feeds = {}
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_url = {
                executor.submit(self.fetch_single_rss, url, timeout): url
                for url in urls
            }

            for future in as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    feed = future.result()
                except Exception as e:
                    # Anything other than a RequestException (for example a
                    # parser failure) surfaces here; one bad feed must not
                    # stop the others.
                    self.logger.exception(f"处理 {url} 时发生异常: {str(e)}")
                else:
                    if feed:
                        feeds[url] = feed

        self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
        return feeds

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Cut *text* to at most *limit* characters, ending a cut with '...'."""
        if len(text) <= limit:
            return text
        return text[:limit - 3] + '...'

    @staticmethod
    def _entry_time(entry: Dict[str, Any]) -> datetime:
        """Work out the publication time of *entry* as a naive ``datetime``.

        Order of preference: the pre-parsed ``published_parsed`` /
        ``updated_parsed`` tuple, then the raw ``published`` / ``updated``
        string in RFC 822 form, then the current local time.
        """
        parsed = entry.get('published_parsed') or entry.get('updated_parsed')
        if parsed:
            try:
                return datetime(*parsed[:6])
            except (TypeError, ValueError):
                pass  # malformed tuple: fall back to the raw string

        raw = entry.get('published') or entry.get('updated') or ''
        try:
            # NOTE(review): feedparser documents the *_parsed tuples as UTC,
            # whereas this fallback keeps the feed's wall-clock time and
            # ``datetime.now()`` is local time. Confirm that mixing these is
            # acceptable for the last-update comparison.
            return datetime.strptime(raw, '%a, %d %b %Y %H:%M:%S %z').replace(tzinfo=None)
        except (TypeError, ValueError):
            return datetime.now()

    def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
        """Convert one feed entry into a row dict matching the table columns.

        Args:
            entry: A feed entry (any mapping with feedparser's entry keys).
            url: URL of the feed the entry came from.

        Returns:
            Dict keyed by the Chinese column names. Title is limited to 255
            characters, link and source URL to 1024. The summary falls back to
            the first 200 characters of the entry content, then to a
            placeholder. Timestamps are ``'%Y-%m-%d %H:%M:%S'`` strings.
        """
        title = self._truncate(entry.get('title') or '无标题', 255)
        link = self._truncate(entry.get('link') or '无链接', 1024)

        # Prefer the feed's own summary. An empty or missing summary falls
        # back to an excerpt of the first content item.
        summary = entry.get('summary') or ''
        content_list = entry.get('content') or []
        body = ''
        if content_list:
            first = content_list[0]
            # feedparser items are dict subclasses with attribute access;
            # plain dicts and attribute-only objects are handled as well.
            if isinstance(first, dict):
                body = first.get('value') or ''
            else:
                body = getattr(first, 'value', '') or ''

        if summary:
            description = summary
        elif body:
            # Only mark the excerpt with an ellipsis when it was really cut.
            description = body[:200] + '...' if len(body) > 200 else body
        else:
            description = '无内容摘要'

        entry_time = self._entry_time(entry)
        source_url = self._truncate(url or '未知来源', 1024)

        # Creation and update time are both "now" for a freshly collected row.
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        return {
            '文章标题': title,
            '文章链接': link,
            '文章摘要': description,
            '发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
            '来源URL': source_url,
            '创建时间': current_time,
            '更新时间': current_time,
        }

    def display_feed_info(self, feed: "feedparser.FeedParserDict", last_update: Optional[datetime] = None,
                          url: Optional[str] = None) -> Optional[datetime]:
        """Store the entries of *feed* that are newer than *last_update*.

        Args:
            feed: Parsed feed as returned by ``fetch_single_rss``.
            last_update: Entries published at or before this time are skipped.
                None means every entry is stored.
            url: Feed URL, recorded as the source of each row.

        Returns:
            The newest publication time among the stored entries, or
            *last_update* unchanged when nothing was stored. If the database
            write fails, *last_update* is also returned unchanged so the same
            entries are picked up again on the next run. None is returned for
            an empty feed.
        """
        if not feed:
            self.logger.warning("无法处理空的RSS源数据")
            return None

        self.logger.info(f"开始处理 RSS 源: {url}")
        entries = getattr(feed, 'entries', None) or []
        data_list = []
        new_last_update = last_update

        for i, entry in enumerate(entries, 1):
            entry_data = self.process_feed_entry(entry, url)
            # Round-trip through the stored string so the comparison uses
            # exactly the second-resolution value that is written to the DB.
            entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')

            # Skip entries already covered by the previous run.
            if last_update and entry_time <= last_update:
                continue

            if new_last_update is None or entry_time > new_last_update:
                new_last_update = entry_time

            self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
            data_list.append(entry_data)

        if data_list:
            result = self.write_to_database(pd.DataFrame(data_list))
            if isinstance(result, dict) and not result.get('success', True):
                # Do not advance the checkpoint past rows that were never
                # stored; otherwise they would be filtered out for good.
                self.logger.error(f"RSS 源 {url} 的数据写入失败,保留原更新时间以便下次重试")
                return last_update

        return new_last_update

    def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Insert the rows of *df* into the target table.

        Rows are inserted through ``db_agent.insert_from_df`` in chunks of 500
        with ``ignore_duplicates=True``.

        Returns:
            A ``_format_result`` dict. On success ``data`` holds
            ``success_count`` (the value returned by the agent, presumably the
            number of inserted rows; verify against MySQLAgent) and ``total``.
            On failure ``success`` is False and the exception is logged.
        """
        if df.empty:
            self.logger.info("没有新数据需要写入数据库")
            return self._format_result(True, "没有新数据需要写入")

        try:
            inserted_rows = self.db_agent.insert_from_df(
                table_name=table_name,
                df=df,
                chunk_size=500,
                ignore_duplicates=True,
            )
        except Exception as e:
            # Context goes into the record's ``extra`` via bind(). The message
            # template is formatted by loguru with the exception as argument,
            # so braces inside the error text cannot break formatting.
            self.logger.bind(
                error_type=type(e).__name__,
                table_name=table_name,
                record_count=len(df),
                sample_records=df.head(2).to_dict('records'),
            ).exception("数据库写入失败: {}", e)
            return self._format_result(False, f"数据库操作失败: {str(e)}")

        self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
        return self._format_result(
            True,
            f"成功写入 {inserted_rows}/{len(df)} 条记录",
            {"success_count": inserted_rows, "total": len(df)},
        )

    @classmethod
    def main(cls, rss_urls: Optional[List[str]] = None) -> None:
        """Run one complete collection cycle.

        Args:
            rss_urls: Feeds to poll. Defaults to ``DEFAULT_RSS_URLS``.

        Any unexpected exception is logged with its traceback and swallowed.
        """
        try:
            client = cls()

            if not client.verify_database():
                client.logger.error("数据库验证失败,程序终止")
                return

            feed_urls = list(rss_urls) if rss_urls is not None else list(cls.DEFAULT_RSS_URLS)

            last_update = client.load_last_update_time()
            if last_update:
                client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")

            client.logger.info("开始获取RSS源数据...")
            start_time = time.time()
            feeds = client.fetch_all_rss(feed_urls)
            client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}秒")

            # Keep the newest timestamp reported by any feed as the checkpoint
            # for the next run.
            new_last_update = None
            for url, feed in feeds.items():
                current_last_update = client.display_feed_info(feed, last_update, url)
                if current_last_update and (new_last_update is None or current_last_update > new_last_update):
                    new_last_update = current_last_update

            if new_last_update:
                client.save_last_update_time(new_last_update)
                client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
            else:
                client.logger.info("没有获取到新内容")

        except Exception as e:
            logger.exception(f"程序运行出错: {str(e)}")
|
||
|
||
|
||
# Run one collection cycle when the file is executed as a script.
if __name__ == "__main__":
    NewsAPIClient.main()
|