From 8b5332d8e822e00b6d4a99e4e4f09682bdbd1c3e Mon Sep 17 00:00:00 2001 From: wchiways Date: Wed, 4 Feb 2026 23:24:17 +0800 Subject: [PATCH 1/5] refactor: replace os.path with pathlib and use f-strings --- weibo_spider/__main__.py | 4 +- weibo_spider/config_util.py | 59 +++++++------ .../downloader/avatar_picture_downloader.py | 12 +-- weibo_spider/downloader/downloader.py | 19 +++-- weibo_spider/downloader/img_downloader.py | 20 ++--- weibo_spider/downloader/video_downloader.py | 10 +-- weibo_spider/parser/index_parser.py | 6 +- weibo_spider/parser/info_parser.py | 20 ++--- weibo_spider/parser/page_parser.py | 84 +++++++++---------- weibo_spider/parser/util.py | 23 +++-- weibo_spider/spider.py | 70 +++++++--------- weibo_spider/user.py | 10 +-- weibo_spider/weibo.py | 14 ++-- weibo_spider/writer/csv_writer.py | 2 +- weibo_spider/writer/json_writer.py | 6 +- weibo_spider/writer/kafka_writer.py | 4 +- weibo_spider/writer/mongo_writer.py | 8 +- weibo_spider/writer/mysql_writer.py | 8 +- weibo_spider/writer/post_writer.py | 4 +- weibo_spider/writer/sqlite_writer.py | 4 +- weibo_spider/writer/txt_writer.py | 11 ++- 21 files changed, 192 insertions(+), 206 deletions(-) diff --git a/weibo_spider/__main__.py b/weibo_spider/__main__.py index f1eafa65..8eb4ecb7 100644 --- a/weibo_spider/__main__.py +++ b/weibo_spider/__main__.py @@ -1,8 +1,8 @@ -import os import sys +from pathlib import Path from absl import app -sys.path.append(os.path.abspath(os.path.dirname(os.getcwd()))) +sys.path.append(str(Path.cwd().parent.absolute())) from weibo_spider.spider import main app.run(main) diff --git a/weibo_spider/config_util.py b/weibo_spider/config_util.py index f5dc682d..254b1eea 100644 --- a/weibo_spider/config_util.py +++ b/weibo_spider/config_util.py @@ -1,9 +1,9 @@ import codecs import logging -import os import sys import browser_cookie3 from datetime import datetime +from pathlib import Path import json logger = logging.getLogger('spider.config_util') @@ -28,87 +28,86 @@ def validate_config(config): argument_list = ['filter', 'pic_download', 'video_download'] for argument in argument_list: if config[argument] != 0 and config[argument] != 1: - logger.warning(u'%s值应为0或1,请重新输入', config[argument]) + logger.warning(f'{config[argument]}值应为0或1,请重新输入') sys.exit() # 验证since_date since_date = config['since_date'] if (not _is_date(str(since_date))) and (not isinstance(since_date, int)): - logger.warning(u'since_date值应为yyyy-mm-dd形式或整数,请重新输入') + logger.warning('since_date值应为yyyy-mm-dd形式或整数,请重新输入') sys.exit() # 验证end_date end_date = str(config['end_date']) if (not _is_date(end_date)) and (end_date != 'now'): - logger.warning(u'end_date值应为yyyy-mm-dd形式或"now",请重新输入') + logger.warning('end_date值应为yyyy-mm-dd形式或"now",请重新输入') sys.exit() # 验证random_wait_pages random_wait_pages = config['random_wait_pages'] if not isinstance(random_wait_pages, list): - logger.warning(u'random_wait_pages参数值应为list类型,请重新输入') + logger.warning('random_wait_pages参数值应为list类型,请重新输入') sys.exit() if (not isinstance(min(random_wait_pages), int)) or (not isinstance( max(random_wait_pages), int)): - logger.warning(u'random_wait_pages列表中的值应为整数类型,请重新输入') + logger.warning('random_wait_pages列表中的值应为整数类型,请重新输入') sys.exit() if min(random_wait_pages) < 1: - logger.warning(u'random_wait_pages列表中的值应大于0,请重新输入') + logger.warning('random_wait_pages列表中的值应大于0,请重新输入') sys.exit() # 验证random_wait_seconds random_wait_seconds = config['random_wait_seconds'] if not isinstance(random_wait_seconds, list): - logger.warning(u'random_wait_seconds参数值应为list类型,请重新输入') + 
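A quick equivalence check for the __main__.py change: both the old os.path expression and the new pathlib one resolve to the parent of the current working directory. Standalone, safe to run anywhere:

import os
from pathlib import Path

# Old form: abspath(dirname(getcwd())); new form: Path.cwd().parent.
# Both name the parent of the current working directory.
old = os.path.abspath(os.path.dirname(os.getcwd()))
new = str(Path.cwd().parent.absolute())
assert old == new, (old, new)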
logger.warning('random_wait_seconds参数值应为list类型,请重新输入') sys.exit() if (not isinstance(min(random_wait_seconds), int)) or (not isinstance( max(random_wait_seconds), int)): - logger.warning(u'random_wait_seconds列表中的值应为整数类型,请重新输入') + logger.warning('random_wait_seconds列表中的值应为整数类型,请重新输入') sys.exit() if min(random_wait_seconds) < 1: - logger.warning(u'random_wait_seconds列表中的值应大于0,请重新输入') + logger.warning('random_wait_seconds列表中的值应大于0,请重新输入') sys.exit() # 验证global_wait global_wait = config['global_wait'] if not isinstance(global_wait, list): - logger.warning(u'global_wait参数值应为list类型,请重新输入') + logger.warning('global_wait参数值应为list类型,请重新输入') sys.exit() for g in global_wait: if not isinstance(g, list): - logger.warning(u'global_wait参数内的值应为长度为2的list类型,请重新输入') + logger.warning('global_wait参数内的值应为长度为2的list类型,请重新输入') sys.exit() if len(g) != 2: - logger.warning(u'global_wait参数内的list长度应为2,请重新输入') + logger.warning('global_wait参数内的list长度应为2,请重新输入') sys.exit() for i in g: if (not isinstance(i, int)) or i < 1: - logger.warning(u'global_wait列表中的值应为大于0的整数,请重新输入') + logger.warning('global_wait列表中的值应为大于0的整数,请重新输入') sys.exit() # 验证write_mode write_mode = ['txt', 'csv', 'json', 'mongo', 'mysql', 'sqlite', 'kafka','post'] if not isinstance(config['write_mode'], list): - logger.warning(u'write_mode值应为list类型') + logger.warning('write_mode值应为list类型') sys.exit() for mode in config['write_mode']: if mode not in write_mode: logger.warning( - u'%s为无效模式,请从txt、csv、json、post、mongo、sqlite, kafka和mysql中挑选一个或多个作为write_mode', - mode) + f'{mode}为无效模式,请从txt、csv、json、post、mongo、sqlite, kafka和mysql中挑选一个或多个作为write_mode') sys.exit() # 验证user_id_list user_id_list = config['user_id_list'] if (not isinstance(user_id_list, list)) and (not user_id_list.endswith('.txt')): - logger.warning(u'user_id_list值应为list类型或txt文件路径') + logger.warning('user_id_list值应为list类型或txt文件路径') sys.exit() if not isinstance(user_id_list, list): - if not os.path.isabs(user_id_list): - user_id_list = os.getcwd() + os.sep + user_id_list - if not os.path.isfile(user_id_list): - logger.warning(u'不存在%s文件', user_id_list) + if not Path(user_id_list).is_absolute(): + user_id_list = str(Path.cwd() / user_id_list) + if not Path(user_id_list).is_file(): + logger.warning(f'不存在{user_id_list}文件') sys.exit() @@ -119,7 +118,7 @@ def get_user_config_list(file_name, default_since_date): lines = f.read().splitlines() lines = [line.decode('utf-8-sig') for line in lines] except UnicodeDecodeError: - logger.error(u'%s文件应为utf-8编码,请先将文件编码转为utf-8再运行程序', file_name) + logger.error(f'{file_name}文件应为utf-8编码,请先将文件编码转为utf-8再运行程序') sys.exit() user_config_list = [] for line in lines: @@ -143,7 +142,7 @@ def update_user_config_file(user_config_file_path, user_uri, nickname, start_time): """更新用户配置文件""" if not user_config_file_path: - user_config_file_path = os.getcwd() + os.sep + 'user_id_list.txt' + user_config_file_path = str(Path.cwd() / 'user_id_list.txt') with open(user_config_file_path, 'rb') as f: lines = f.read().splitlines() lines = [line.decode('utf-8-sig') for line in lines] @@ -169,8 +168,8 @@ def update_user_config_file(user_config_file_path, user_uri, nickname, def add_user_uri_list(user_config_file_path, user_uri_list): """向user_id_list.txt文件添加若干user_uri""" if not user_config_file_path: - user_config_file_path = os.getcwd() + os.sep + 'user_id_list.txt' - if os.path.isfile(user_config_file_path): + user_config_file_path = str(Path.cwd() / 'user_id_list.txt') + if Path(user_config_file_path).is_file(): user_uri_list[0] = '\n' + user_uri_list[0] with codecs.open(user_config_file_path, 'a', 
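One pathlib semantic worth double-checking in avatar_picture_downloader.py above: file_name = url[index:] keeps the leading slash, and joining a rooted segment with / makes pathlib discard everything before it, whereas the old file_dir + os.sep + file_name merely produced a harmless double separator. A minimal demonstration, with a hypothetical URL:

from pathlib import Path

file_dir = Path('weibo') / '头像图片'
url = 'https://example.sinaimg.cn/crop.0.0.180.180/avatar.jpg'  # hypothetical

# url[url.rfind('/'):] keeps the '/', so the join resets to the filesystem root:
assert file_dir / url[url.rfind('/'):] == Path('/avatar.jpg')

# Taking the bare basename keeps the file inside file_dir:
assert file_dir / url.rsplit('/', 1)[-1] == file_dir / 'avatar.jpg'

Switching to file_name = url[index + 1:] (or url.rsplit('/', 1)[-1]) would preserve the pre-refactor behavior.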
encoding='utf-8') as f: f.write('\n'.join(user_uri_list)) @@ -182,13 +181,13 @@ def get_cookie(): cookies_dict = {cookie.name: cookie.value for cookie in chrome_cookies} return cookies_dict except Exception as e: - logger.error(u'Failed to obtain weibo.cn cookie from Chrome browser: %s', str(e)) + logger.error(f'Failed to obtain weibo.cn cookie from Chrome browser: {e}') raise def update_cookie_config(cookie, user_config_file_path): """Update cookie in config.json""" if not user_config_file_path: - user_config_file_path = os.getcwd() + os.sep + 'config.json' + user_config_file_path = str(Path.cwd() / 'config.json') try: with codecs.open(user_config_file_path, 'r', encoding='utf-8') as f: config = json.load(f) @@ -200,7 +199,7 @@ def update_cookie_config(cookie, user_config_file_path): with codecs.open(user_config_file_path, 'w', encoding='utf-8') as f: json.dump(config, f, indent=4, ensure_ascii=False) except Exception as e: - logger.error(u'Failed to update cookie in config file: %s', str(e)) + logger.error(f'Failed to update cookie in config file: {e}') raise def check_cookie(user_config_file_path): @@ -213,5 +212,5 @@ def check_cookie(user_config_file_path): else: update_cookie_config(cookie, user_config_file_path) except Exception as e: - logger.error(u'Check for cookie failed: %s', str(e)) + logger.error(f'Check for cookie failed: {e}') raise diff --git a/weibo_spider/downloader/avatar_picture_downloader.py b/weibo_spider/downloader/avatar_picture_downloader.py index 3eb935ed..cac1e24e 100644 --- a/weibo_spider/downloader/avatar_picture_downloader.py +++ b/weibo_spider/downloader/avatar_picture_downloader.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from .img_downloader import ImgDownloader @@ -6,17 +6,17 @@ class AvatarPictureDownloader(ImgDownloader): def __init__(self, file_dir, file_download_timeout): super().__init__(file_dir, file_download_timeout) - self.describe = u'头像图片' + self.describe = '头像图片' self.key = 'avatar_pictures' async def handle_download(self, urls, session): """处理下载相关操作""" - file_dir = self.file_dir + os.sep + self.describe - if not os.path.isdir(file_dir): - os.makedirs(file_dir) + file_dir = self.file_dir / self.describe + if not file_dir.is_dir(): + file_dir.mkdir(parents=True, exist_ok=True) for i, url in enumerate(urls): index = url.rfind('/') file_name = url[index:] - file_path = file_dir + os.sep + file_name + file_path = file_dir / file_name await self.download_one_file(url, file_path, 'xxx', session) \ No newline at end of file diff --git a/weibo_spider/downloader/downloader.py b/weibo_spider/downloader/downloader.py index 75914e2f..b4b65618 100644 --- a/weibo_spider/downloader/downloader.py +++ b/weibo_spider/downloader/downloader.py @@ -1,10 +1,10 @@ # -*- coding: UTF-8 -*- import asyncio import logging -import os import sys import random from abc import ABC, abstractmethod +from pathlib import Path import aiohttp from tqdm import tqdm @@ -14,7 +14,7 @@ class Downloader(ABC): def __init__(self, file_dir, file_download_timeout): - self.file_dir = file_dir + self.file_dir = Path(file_dir) self.describe = '' self.key = '' self.file_download_timeout = [5, 5, 10] @@ -33,7 +33,8 @@ async def handle_download(self, urls, w, session): async def download_one_file(self, url, file_path, weibo_id, session): """下载单个文件(图片/视频)""" try: - if not os.path.isfile(file_path): + file_path = Path(file_path) + if not file_path.is_file(): # 随机延时,模拟人工操作 await asyncio.sleep(random.uniform(0.5, 1.5)) @@ -59,11 +60,11 @@ async def download_one_file(self, url, file_path, 
weibo_id, session): if last_exception: raise last_exception - return os.path.isfile(file_path) + return file_path.is_file() except Exception as e: - error_file = self.file_dir + os.sep + 'not_downloaded.txt' + error_file = self.file_dir / 'not_downloaded.txt' with open(error_file, 'ab') as f: - url = weibo_id + ':' + file_path + ':' + url + '\n' + url = f'{weibo_id}:{file_path}:{url}\n' f.write(url.encode(sys.stdout.encoding)) logger.exception(e) return False @@ -71,11 +72,11 @@ async def download_one_file(self, url, file_path, weibo_id, session): async def download_files(self, weibos, session): """下载文件(图片/视频)""" try: - logger.info(u'即将进行%s下载', self.describe) + logger.info(f'即将进行{self.describe}下载') for w in tqdm(weibos, desc='Download progress'): - if getattr(w, self.key) != u'无': + if getattr(w, self.key) != '无': await self.handle_download(getattr(w, self.key), w, session) - logger.info(u'%s下载完毕,保存路径:', self.describe) + logger.info(f'{self.describe}下载完毕,保存路径:') logger.info(self.file_dir) except Exception as e: logger.exception(e) diff --git a/weibo_spider/downloader/img_downloader.py b/weibo_spider/downloader/img_downloader.py index 61aabe01..cd863a2d 100644 --- a/weibo_spider/downloader/img_downloader.py +++ b/weibo_spider/downloader/img_downloader.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from .downloader import Downloader @@ -6,15 +6,15 @@ class ImgDownloader(Downloader): def __init__(self, file_dir, file_download_timeout): super().__init__(file_dir, file_download_timeout) - self.describe = u'图片' + self.describe = '图片' self.key = '' async def handle_download(self, urls, w, session): """处理下载相关操作""" - file_prefix = w.publish_time[:10].replace('-', '') + '_' + w.id - file_dir = self.file_dir + os.sep + self.describe - if not os.path.isdir(file_dir): - os.makedirs(file_dir) + file_prefix = f"{w.publish_time[:10].replace('-', '')}_{w.id}" + file_dir = self.file_dir / self.describe + if not file_dir.is_dir(): + file_dir.mkdir(parents=True, exist_ok=True) media_key = self.key or 'original_pictures' if ',' in urls: url_list = urls.split(',') @@ -24,8 +24,8 @@ async def handle_download(self, urls, w, session): file_suffix = '.jpg' else: file_suffix = url[index:] - file_name = file_prefix + '_' + str(i + 1) + file_suffix - file_path = file_dir + os.sep + file_name + file_name = f"{file_prefix}_{i + 1}{file_suffix}" + file_path = file_dir / file_name ok = await self.download_one_file(url, file_path, w.id, session) if ok: w.media.setdefault(media_key, []).append({ @@ -38,8 +38,8 @@ async def handle_download(self, urls, w, session): file_suffix = '.jpg' else: file_suffix = urls[index:] - file_name = file_prefix + file_suffix - file_path = file_dir + os.sep + file_name + file_name = f"{file_prefix}{file_suffix}" + file_path = file_dir / file_name ok = await self.download_one_file(urls, file_path, w.id, session) if ok: w.media.setdefault(media_key, []).append({ diff --git a/weibo_spider/downloader/video_downloader.py b/weibo_spider/downloader/video_downloader.py index 737794b5..77fbc111 100644 --- a/weibo_spider/downloader/video_downloader.py +++ b/weibo_spider/downloader/video_downloader.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from .downloader import Downloader @@ -6,15 +6,15 @@ class VideoDownloader(Downloader): def __init__(self, file_dir, file_download_timeout): super().__init__(file_dir, file_download_timeout) - self.describe = u'视频' + self.describe = '视频' self.key = 'video_url' async def handle_download(self, urls, w, session): """处理下载相关操作""" - file_prefix = 
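For readers skimming the downloader.py hunk: the file_download_timeout list is [max retries, connect timeout, read timeout]. A self-contained sketch of that retry shape with aiohttp, using hypothetical names (the real download_one_file additionally records failures in not_downloaded.txt):

import asyncio
from pathlib import Path

import aiohttp


async def fetch_file(url: str, file_path: Path,
                     file_download_timeout=(5, 5, 10)) -> bool:
    retries, connect, read = file_download_timeout
    timeout = aiohttp.ClientTimeout(sock_connect=connect, sock_read=read)
    last_exception = None
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for _ in range(retries):
            try:
                async with session.get(url) as resp:
                    file_path.write_bytes(await resp.read())
                return file_path.is_file()
            except asyncio.TimeoutError as e:  # retry only on timeouts
                last_exception = e
                await asyncio.sleep(1)
    if last_exception:
        raise last_exception
    return False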
w.publish_time[:10].replace('-', '') + '_' + w.id + file_prefix = f"{w.publish_time[:10].replace('-', '')}_{w.id}" file_suffix = '.mp4' - file_name = file_prefix + file_suffix - file_path = self.file_dir + os.sep + file_name + file_name = f"{file_prefix}{file_suffix}" + file_path = self.file_dir / file_name ok = await self.download_one_file(urls, file_path, w.id, session) if ok: w.media.setdefault('video', []).append({ diff --git a/weibo_spider/parser/index_parser.py b/weibo_spider/parser/index_parser.py index 4e9a7036..889130ed 100644 --- a/weibo_spider/parser/index_parser.py +++ b/weibo_spider/parser/index_parser.py @@ -11,7 +11,7 @@ class IndexParser(Parser): def __init__(self, cookie, user_uri, selector=None): self.cookie = cookie self.user_uri = user_uri - self.url = 'https://weibo.cn/%s/profile' % (user_uri) + self.url = f'https://weibo.cn/{user_uri}/profile' self.selector = selector if selector is not None else handle_html(self.cookie, self.url) def _get_user_id(self): @@ -19,7 +19,7 @@ def _get_user_id(self): user_id = self.user_uri url_list = self.selector.xpath("//div[@class='u']//a") for url in url_list: - if (url.xpath('string(.)')) == u'资料': + if (url.xpath('string(.)')) == '资料': if url.xpath('@href') and url.xpath('@href')[0].endswith( '/info'): link = url.xpath('@href')[0] @@ -49,7 +49,7 @@ async def get_user_async(self, session): user_id = self._get_user_id() from .util import handle_html_async # Local import if needed or top level - info_url = 'https://weibo.cn/%s/info' % (user_id) + info_url = f'https://weibo.cn/{user_id}/info' info_selector = await handle_html_async(self.cookie, info_url, session) self.user = InfoParser(self.cookie, diff --git a/weibo_spider/parser/info_parser.py b/weibo_spider/parser/info_parser.py index c4b597ff..cc5aaa01 100644 --- a/weibo_spider/parser/info_parser.py +++ b/weibo_spider/parser/info_parser.py @@ -11,7 +11,7 @@ class InfoParser(Parser): def __init__(self, cookie, user_id, selector=None): self.cookie = cookie - self.url = 'https://weibo.cn/%s/info' % (user_id) + self.url = f'https://weibo.cn/{user_id}/info' self.selector = selector if selector is not None else handle_html(self.cookie, self.url) def extract_user_info(self): @@ -20,13 +20,13 @@ def extract_user_info(self): user = User() nickname = self.selector.xpath('//title/text()')[0] nickname = nickname[:-3] - if nickname == u'登录 - 新' or nickname == u'新浪': - logger.warning(u'cookie错误或已过期,请按照README中方法重新获取') + if nickname == '登录 - 新' or nickname == '新浪': + logger.warning('cookie错误或已过期,请按照README中方法重新获取') sys.exit() user.nickname = nickname basic_info = self.selector.xpath("//div[@class='c'][3]/text()") - zh_list = [u'性别', u'地区', u'生日', u'简介', u'认证', u'达人'] + zh_list = ['性别', '地区', '生日', '简介', '认证', '达人'] en_list = [ 'gender', 'location', 'birthday', 'description', 'verified_reason', 'talent' @@ -37,19 +37,19 @@ def extract_user_info(self): i.split(':', 1)[1].replace('\u3000', '')) experienced = self.selector.xpath("//div[@class='tip'][2]/text()") - if experienced and experienced[0] == u'学习经历': + if experienced and experienced[0] == '学习经历': user.education = self.selector.xpath( "//div[@class='c'][4]/text()")[0][1:].replace( - u'\xa0', u' ') + '\xa0', ' ') if self.selector.xpath( - "//div[@class='tip'][3]/text()")[0] == u'工作经历': + "//div[@class='tip'][3]/text()")[0] == '工作经历': user.work = self.selector.xpath( "//div[@class='c'][5]/text()")[0][1:].replace( - u'\xa0', u' ') - elif experienced and experienced[0] == u'工作经历': + '\xa0', ' ') + elif experienced and experienced[0] == '工作经历': 
user.work = self.selector.xpath( "//div[@class='c'][4]/text()")[0][1:].replace( - u'\xa0', u' ') + '\xa0', ' ') return user except Exception as e: logger.exception(e) diff --git a/weibo_spider/parser/page_parser.py b/weibo_spider/parser/page_parser.py index 6bc28cf8..c9368a3c 100644 --- a/weibo_spider/parser/page_parser.py +++ b/weibo_spider/parser/page_parser.py @@ -27,7 +27,7 @@ def __init__(self, cookie, user_config, page, filter, selector=None, defer_fetch self.since_date = user_config['since_date'] self.end_date = user_config['end_date'] self.page = page - self.url = 'https://weibo.cn/%s/profile?page=%d' % (self.user_uri, page) + self.url = f'https://weibo.cn/{self.user_uri}/profile?page={page}' if self.end_date != 'now': since_date = self.since_date.split(' ')[0].split('-') end_date = self.end_date.split(' ')[0].split('-') @@ -37,8 +37,7 @@ def __init__(self, cookie, user_config, page, filter, selector=None, defer_fetch date[i] = '0' + date[i] starttime = ''.join(since_date) endtime = ''.join(end_date) - self.url = 'https://weibo.cn/%s/profile?starttime=%s&endtime=%s&advancedfilter=1&page=%d' % ( - self.user_uri, starttime, endtime, page) + self.url = f'https://weibo.cn/{self.user_uri}/profile?starttime={starttime}&endtime={endtime}&advancedfilter=1&page={page}' self.selector = selector self.to_continue = True is_exist = '' @@ -110,9 +109,9 @@ def get_original_weibo(self, info, weibo_id): """获取原创微博""" try: weibo_content = handle_garbled(info) - weibo_content = weibo_content[:weibo_content.rfind(u'赞')] + weibo_content = weibo_content[:weibo_content.rfind('赞')] a_text = info.xpath('div//a/text()') - if u'全文' in a_text: + if '全文' in a_text: wb_content = CommentParser(self.cookie, weibo_id).get_long_weibo() if wb_content: @@ -126,25 +125,22 @@ def get_retweet(self, info, weibo_id): try: weibo_content = handle_garbled(info) weibo_content = weibo_content[weibo_content.find(':') + - 1:weibo_content.rfind(u'赞')] - weibo_content = weibo_content[:weibo_content.rfind(u'赞')] + 1:weibo_content.rfind('赞')] + weibo_content = weibo_content[:weibo_content.rfind('赞')] a_text = info.xpath('div//a/text()') - if u'全文' in a_text: + if '全文' in a_text: wb_content = CommentParser(self.cookie, weibo_id).get_long_retweet() if wb_content: weibo_content = wb_content retweet_reason = handle_garbled(info.xpath('div')[-1]) - retweet_reason = retweet_reason[:retweet_reason.rindex(u'赞')] + retweet_reason = retweet_reason[:retweet_reason.rindex('赞')] original_user = info.xpath("div/span[@class='cmt']/a/text()") if original_user: original_user = original_user[0] - weibo_content = (retweet_reason + '\n' + u'原始用户: ' + - original_user + '\n' + u'转发内容: ' + - weibo_content) + weibo_content = (f'{retweet_reason}\n原始用户: {original_user}\n转发内容: {weibo_content}') else: - weibo_content = (retweet_reason + '\n' + u'转发内容: ' + - weibo_content) + weibo_content = (f'{retweet_reason}\n转发内容: {weibo_content}') return weibo_content except Exception as e: logger.exception(e) @@ -165,7 +161,7 @@ def get_article_url(self, info): """获取微博头条文章的url""" article_url = '' text = handle_garbled(info) - if text.startswith(u'发布了头条文章') or text.startswith(u'我发表了头条文章'): + if text.startswith('发布了头条文章') or text.startswith('我发表了头条文章'): url = info.xpath('.//a/@href') if url and url[0].startswith('https://weibo.com/ttarticle'): article_url = url[0] @@ -176,19 +172,19 @@ def get_publish_place(self, info): try: div_first = info.xpath('div')[0] a_list = div_first.xpath('a') - publish_place = u'无' + publish_place = '无' for a in a_list: if ('place.weibo.com' in 
a.xpath('@href')[0] - and a.xpath('text()')[0] == u'显示地图'): + and a.xpath('text()')[0] == '显示地图'): weibo_a = div_first.xpath("span[@class='ctt']/a") if len(weibo_a) >= 1: publish_place = weibo_a[-1] - if (u'视频' == div_first.xpath( + if ('视频' == div_first.xpath( "span[@class='ctt']/a/text()")[-1][-2:]): if len(weibo_a) >= 2: publish_place = weibo_a[-2] else: - publish_place = u'无' + publish_place = '无' publish_place = handle_garbled(publish_place) break return publish_place @@ -200,26 +196,26 @@ def get_publish_time(self, info): try: str_time = info.xpath("div/span[@class='ct']") str_time = handle_garbled(str_time[0]) - publish_time = str_time.split(u'来自')[0] - if u'刚刚' in publish_time: + publish_time = str_time.split('来自')[0] + if '刚刚' in publish_time: publish_time = datetime.now().strftime('%Y-%m-%d %H:%M') - elif u'分钟' in publish_time: - minute = publish_time[:publish_time.find(u'分钟')] + elif '分钟' in publish_time: + minute = publish_time[:publish_time.find('分钟')] minute = timedelta(minutes=int(minute)) publish_time = (datetime.now() - minute).strftime('%Y-%m-%d %H:%M') - elif u'今天' in publish_time: + elif '今天' in publish_time: today = datetime.now().strftime('%Y-%m-%d') time = publish_time[3:] - publish_time = today + ' ' + time + publish_time = f'{today} {time}' if len(publish_time) > 16: publish_time = publish_time[:16] - elif u'月' in publish_time: + elif '月' in publish_time: year = datetime.now().strftime('%Y') month = publish_time[0:2] day = publish_time[3:5] time = publish_time[7:12] - publish_time = year + '-' + month + '-' + day + ' ' + time + publish_time = f'{year}-{month}-{day} {time}' else: publish_time = publish_time[:16] return publish_time @@ -231,10 +227,10 @@ def get_publish_tool(self, info): try: str_time = info.xpath("div/span[@class='ct']") str_time = handle_garbled(str_time[0]) - if len(str_time.split(u'来自')) > 1: - publish_tool = str_time.split(u'来自')[1] + if len(str_time.split('来自')) > 1: + publish_tool = str_time.split('来自')[1] else: - publish_tool = u'无' + publish_tool = '无' return publish_tool except Exception as e: logger.exception(e) @@ -246,7 +242,7 @@ def get_weibo_footer(self, info): pattern = r'\d+' str_footer = info.xpath('div')[-1] str_footer = handle_garbled(str_footer) - str_footer = str_footer[str_footer.rfind(u'赞'):] + str_footer = str_footer[str_footer.rfind('赞'):] weibo_footer = re.findall(pattern, str_footer, re.M) up_num = int(weibo_footer[0]) @@ -270,14 +266,14 @@ def get_picture_urls(self, info, is_original): original_pictures = self.extract_picture_urls(info, weibo_id) picture_urls['original_pictures'] = original_pictures if not self.filter: - picture_urls['retweet_pictures'] = u'无' + picture_urls['retweet_pictures'] = '无' else: retweet_url = info.xpath("div/a[@class='cc']/@href")[0] retweet_id = retweet_url.split('/')[-1].split('?')[0] retweet_pictures = self.extract_picture_urls(info, retweet_id) picture_urls['retweet_pictures'] = retweet_pictures a_list = info.xpath('div[last()]/a/@href') - original_picture = u'无' + original_picture = '无' for a in a_list: if a.endswith(('.gif', '.jpeg', '.jpg', '.png')): original_picture = a @@ -289,13 +285,13 @@ def get_picture_urls(self, info, is_original): def get_video_url(self, info): """获取微博视频url""" - video_url = u'无' + video_url = '无' weibo_id = info.xpath('@id')[0][2:] try: video_page_url = '' a_text = info.xpath('./div[1]//a/text()') - if u'全文' in a_text: + if '全文' in a_text: video_page_url = CommentParser(self.cookie, weibo_id).get_video_page_url() else: @@ -328,7 +324,7 @@ def get_one_weibo(self, 
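The branches of get_publish_time are easier to follow as a pure function; the sketch below uses the same slicing as the parser ('刚刚', 'N分钟', '今天 HH:MM', 'MM月DD日 HH:MM'), with a hypothetical helper name:

from datetime import datetime, timedelta


def normalize_publish_time(publish_time: str) -> str:
    now = datetime.now()
    if '刚刚' in publish_time:
        return now.strftime('%Y-%m-%d %H:%M')
    if '分钟' in publish_time:
        minutes = int(publish_time[:publish_time.find('分钟')])
        return (now - timedelta(minutes=minutes)).strftime('%Y-%m-%d %H:%M')
    if '今天' in publish_time:
        return f"{now.strftime('%Y-%m-%d')} {publish_time[3:]}"[:16]
    if '月' in publish_time:
        month, day, time = publish_time[0:2], publish_time[3:5], publish_time[7:12]
        return f"{now.strftime('%Y')}-{month}-{day} {time}"
    return publish_time[:16]


assert len(normalize_publish_time('10分钟前')) == 16
assert normalize_publish_time('02月14日 10:30').endswith('-02-14 10:30')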
info): picture_urls = self.get_picture_urls(info, is_original) weibo.original_pictures = picture_urls[ 'original_pictures'] # 原创图片url - if weibo.original_pictures != u'无': + if weibo.original_pictures != '无': weibo.original_pictures_list = [ u.strip() for u in weibo.original_pictures.split(',') if u.strip() @@ -336,7 +332,7 @@ def get_one_weibo(self, info): if not self.filter: weibo.retweet_pictures = picture_urls[ 'retweet_pictures'] # 转发图片url - if weibo.retweet_pictures != u'无': + if weibo.retweet_pictures != '无': weibo.retweet_pictures_list = [ u.strip() for u in weibo.retweet_pictures.split(',') @@ -352,7 +348,7 @@ def get_one_weibo(self, info): weibo.comment_num = footer['comment_num'] # 评论数 else: weibo = None - logger.info(u'正在过滤转发微博') + logger.info('正在过滤转发微博') return weibo except Exception as e: logger.exception(e) @@ -361,9 +357,9 @@ def extract_picture_urls(self, info, weibo_id): """提取微博原始图片url""" try: a_list = info.xpath('div/a/@href') - first_pic = 'https://weibo.cn/mblog/pic/' + weibo_id - all_pic = 'https://weibo.cn/mblog/picAll/' + weibo_id - picture_urls = u'无' + first_pic = f'https://weibo.cn/mblog/pic/{weibo_id}' + all_pic = f'https://weibo.cn/mblog/picAll/{weibo_id}' + picture_urls = '无' if first_pic in ''.join(a_list): if all_pic in ''.join(a_list): preview_picture_list = MblogPicAllParser( @@ -386,11 +382,11 @@ def extract_picture_urls(self, info, weibo_id): break else: logger.warning( - u'爬虫微博可能被设置成了"不显示图片",请前往' - u'"https://weibo.cn/account/customize/pic",修改为"显示"' + '爬虫微博可能被设置成了"不显示图片",请前往' + '"https://weibo.cn/account/customize/pic",修改为"显示"' ) sys.exit() return picture_urls except Exception as e: logger.exception(e) - return u'无' + return '无' diff --git a/weibo_spider/parser/util.py b/weibo_spider/parser/util.py index f6238cf6..52c9ba14 100644 --- a/weibo_spider/parser/util.py +++ b/weibo_spider/parser/util.py @@ -2,6 +2,7 @@ import json import logging import sys +from pathlib import Path import aiohttp import requests @@ -28,13 +29,12 @@ async def handle_html_async(cookie, url, session): if GENERATE_TEST_DATA: import io - import os - resp_file = os.path.join(TEST_DATA_DIR, '%s.html' % hash_url(url)) + resp_file = str(Path(TEST_DATA_DIR) / f'{hash_url(url)}.html') with io.open(resp_file, 'wb') as f: f.write(content) - with io.open(os.path.join(TEST_DATA_DIR, URL_MAP_FILE), 'r+') as f: + with io.open(Path(TEST_DATA_DIR) / URL_MAP_FILE, 'r+') as f: url_map = json.loads(f.read()) url_map[url] = resp_file f.seek(0) @@ -56,13 +56,12 @@ def handle_html(cookie, url): if GENERATE_TEST_DATA: import io - import os - resp_file = os.path.join(TEST_DATA_DIR, '%s.html' % hash_url(url)) + resp_file = str(Path(TEST_DATA_DIR) / f'{hash_url(url)}.html') with io.open(resp_file, 'w', encoding='utf-8') as f: f.write(resp.text) - with io.open(os.path.join(TEST_DATA_DIR, URL_MAP_FILE), 'r+') as f: + with io.open(Path(TEST_DATA_DIR) / URL_MAP_FILE, 'r+') as f: url_map = json.loads(f.read()) url_map[url] = resp_file f.seek(0) @@ -83,12 +82,12 @@ def handle_garbled(info): else: info_str = str(info) # 若不支持 xpath,将其转换为字符串 - info = info_str.replace(u'\u200b', '').encode( + info = info_str.replace('\u200b', '').encode( sys.stdout.encoding, 'ignore').decode(sys.stdout.encoding) return info except Exception as e: logger.exception(e) - return u'无' + return '无' def bid2mid(bid): @@ -134,7 +133,7 @@ def to_video_download_url(cookie, video_page_url): if not video_url: # 说明该视频为直播 video_url = '' except json.decoder.JSONDecodeError: - logger.warning(u'当前账号没有浏览该视频的权限') + logger.warning('当前账号没有浏览该视频的权限') 
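For context on handle_garbled (unchanged above except for dropping the u'' prefixes): it strips zero-width spaces and round-trips text through the console encoding, silently dropping anything the console cannot represent. A standalone sketch:

import sys


def handle_garbled_demo(text: str) -> str:
    encoding = sys.stdout.encoding or 'utf-8'
    # Drop zero-width spaces, then discard characters the console
    # encoding cannot represent (the 'ignore' error handler).
    return (text.replace('\u200b', '')
                .encode(encoding, 'ignore')
                .decode(encoding))


print(handle_garbled_demo('转发\u200b微博'))  # '转发微博' on a UTF-8 console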
return video_url @@ -146,10 +145,10 @@ def string_to_int(string): return 0 if isinstance(string, int): return string - elif string.endswith(u'万+'): + elif string.endswith('万+'): string = string[:-2] + '0000' - elif string.endswith(u'万'): + elif string.endswith('万'): string = float(string[:-1]) * 10000 - elif string.endswith(u'亿'): + elif string.endswith('亿'): string = float(string[:-1]) * 100000000 return int(string) diff --git a/weibo_spider/spider.py b/weibo_spider/spider.py index f79b5bba..d128b894 100644 --- a/weibo_spider/spider.py +++ b/weibo_spider/spider.py @@ -9,10 +9,11 @@ import shutil import sys import asyncio -import aiohttp +from pathlib import Path from datetime import date, datetime, timedelta from time import sleep +import aiohttp from absl import app, flags from tqdm import tqdm @@ -29,8 +30,7 @@ flags.DEFINE_string('user_id_list', None, 'The path to user_id_list.txt.') flags.DEFINE_string('output_dir', None, 'The dir path to store results.') -logging_path = os.path.split( - os.path.realpath(__file__))[0] + os.sep + 'logging.conf' +logging_path = Path(__file__).parent / 'logging.conf' logging.config.fileConfig(logging_path) logger = logging.getLogger('spider') @@ -83,10 +83,10 @@ def __init__(self, config): if FLAGS.user_id_list: user_id_list = FLAGS.user_id_list if not isinstance(user_id_list, list): - if not os.path.isabs(user_id_list): - user_id_list = os.getcwd() + os.sep + user_id_list - if not os.path.isfile(user_id_list): - logger.warning('不存在%s文件', user_id_list) + if not Path(user_id_list).is_absolute(): + user_id_list = str(Path.cwd() / user_id_list) + if not Path(user_id_list).is_file(): + logger.warning(f'不存在{user_id_list}文件') sys.exit() self.user_config_file_path = user_id_list if FLAGS.u: @@ -142,7 +142,7 @@ def write_user(self, user): async def get_user_info(self, user_uri): """获取用户信息""" - url = 'https://weibo.cn/%s/profile' % (user_uri) + url = f'https://weibo.cn/{user_uri}/profile' selector = await handle_html_async(self.cookie, url, self.session) self.user = await IndexParser(self.cookie, user_uri, selector=selector).get_user_async(self.session) self.page_count += 1 @@ -167,7 +167,7 @@ async def get_weibo_info(self): if since_date <= now: # Async fetch page num user_uri = self.user_config['user_uri'] - url = 'https://weibo.cn/%s/profile' % (user_uri) + url = f'https://weibo.cn/{user_uri}/profile' selector = await handle_html_async(self.cookie, url, self.session) page_num = IndexParser(self.cookie, user_uri, selector=selector).get_page_num() @@ -177,7 +177,7 @@ async def get_weibo_info(self): wait_seconds = int( self.global_wait[0][1] * min(1, self.page_count / self.global_wait[0][0])) - logger.info(u'即将进入全局等待时间,%d秒后程序继续执行' % wait_seconds) + logger.info(f'即将进入全局等待时间,{wait_seconds}秒后程序继续执行') for i in tqdm(range(wait_seconds)): await asyncio.sleep(1) self.page_count = 0 @@ -204,12 +204,7 @@ async def get_weibo_info(self): weibos, self.weibo_id_list, to_continue = parser.get_one_page(self.weibo_id_list) logger.info( - u'%s已获取%s(%s)的第%d页微博%s', - '-' * 30, - self.user.nickname, - self.user.id, - page, - '-' * 30, + f"{'-' * 30}已获取{self.user.nickname}({self.user.id})的第{page}页微博{'-' * 30}" ) self.page_count += 1 if weibos: @@ -223,8 +218,7 @@ async def get_weibo_info(self): random_pages = random.randint(*self.random_wait_pages) if self.page_count >= self.global_wait[0][0]: - logger.info(u'即将进入全局等待时间,%d秒后程序继续执行' % - self.global_wait[0][1]) + logger.info(f'即将进入全局等待时间,{self.global_wait[0][1]}秒后程序继续执行') for i in tqdm(range(self.global_wait[0][1])): await 
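Spot checks for the string_to_int branches above (expected values follow directly from the shown code; note '万+' is truncated rather than rounded):

from weibo_spider.parser.util import string_to_int

assert string_to_int('328') == 328
assert string_to_int('3万') == 30000
assert string_to_int('1.2万') == 12000
assert string_to_int('2万+') == 20000      # '2' + '0000'
assert string_to_int('1亿') == 100000000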
asyncio.sleep(1) self.page_count = 0 @@ -247,16 +241,16 @@ def _get_filepath(self, type): if self.result_dir_name: dir_name = self.user.id if FLAGS.output_dir is not None: - file_dir = FLAGS.output_dir + os.sep + dir_name + file_dir = Path(FLAGS.output_dir) / dir_name else: - file_dir = (os.getcwd() + os.sep + 'weibo' + os.sep + dir_name) + file_dir = Path.cwd() / 'weibo' / dir_name if type == 'img' or type == 'video': - file_dir = file_dir + os.sep + type - if not os.path.isdir(file_dir): - os.makedirs(file_dir) + file_dir = file_dir / type + if not file_dir.is_dir(): + file_dir.mkdir(parents=True, exist_ok=True) if type == 'img' or type == 'video': return file_dir - file_path = file_dir + os.sep + self.user.id + '.' + type + file_path = file_dir / f'{self.user.id}.{type}' return file_path except Exception as e: logger.exception(e) @@ -347,10 +341,10 @@ async def get_one_user(self, user_config): await self.write_weibo(weibos) self.got_num += len(weibos) if not self.filter: - logger.info(u'共爬取' + str(self.got_num) + u'条微博') + logger.info(f'共爬取{self.got_num}条微博') else: - logger.info(u'共爬取' + str(self.got_num) + u'条原创微博') - logger.info(u'信息抓取完毕') + logger.info(f'共爬取{self.got_num}条原创微博') + logger.info('信息抓取完毕') logger.info('*' * 100) except Exception as e: logger.exception(e) @@ -360,7 +354,7 @@ async def start(self): try: if not self.user_config_list: logger.info( - u'没有配置有效的user_id,请通过config.json或user_id_list.txt配置user_id') + '没有配置有效的user_id,请通过config.json或user_id_list.txt配置user_id') return async with aiohttp.ClientSession() as session: @@ -381,17 +375,15 @@ async def start(self): def _get_config(): """获取config.json数据""" - src = os.path.split( - os.path.realpath(__file__))[0] + os.sep + 'config_sample.json' - config_path = os.getcwd() + os.sep + 'config.json' + src = Path(__file__).parent / 'config_sample.json' + config_path = Path.cwd() / 'config.json' if FLAGS.config_path: - config_path = FLAGS.config_path - elif not os.path.isfile(config_path): + config_path = Path(FLAGS.config_path) + elif not config_path.is_file(): shutil.copy(src, config_path) - logger.info(u'请先配置当前目录(%s)下的config.json文件,' - u'如果想了解config.json参数的具体意义及配置方法,请访问\n' - u'https://github.com/dataabc/weiboSpider#2程序设置' % - os.getcwd()) + logger.info(f'请先配置当前目录({Path.cwd()})下的config.json文件,' + '如果想了解config.json参数的具体意义及配置方法,请访问\n' + 'https://github.com/dataabc/weiboSpider#2程序设置') sys.exit() try: with open(config_path) as f: @@ -402,8 +394,8 @@ def _get_config(): config = json.loads(f.read()) return config except ValueError: - logger.error(u'config.json 格式不正确,请访问 ' - u'https://github.com/dataabc/weiboSpider#2程序设置') + logger.error('config.json 格式不正确,请访问 ' + 'https://github.com/dataabc/weiboSpider#2程序设置') sys.exit() async def async_main(_): diff --git a/weibo_spider/user.py b/weibo_spider/user.py index 9820fdf1..12de454d 100644 --- a/weibo_spider/user.py +++ b/weibo_spider/user.py @@ -31,9 +31,9 @@ def to_dict(self): def __str__(self): """打印微博用户信息""" result = '' - result += u'用户昵称: %s\n' % self.nickname - result += u'用户id: %s\n' % self.id - result += u'微博数: %d\n' % self.weibo_num - result += u'关注数: %d\n' % self.following - result += u'粉丝数: %d\n' % self.followers + result += f'用户昵称: {self.nickname}\n' + result += f'用户id: {self.id}\n' + result += f'微博数: {self.weibo_num}\n' + result += f'关注数: {self.following}\n' + result += f'粉丝数: {self.followers}\n' return result diff --git a/weibo_spider/weibo.py b/weibo_spider/weibo.py index b1d3c2da..aa5e6604 100644 --- a/weibo_spider/weibo.py +++ b/weibo_spider/weibo.py @@ -32,11 +32,11 @@ def 
__init__(self): def __str__(self): """打印一条微博""" result = self.content + '\n' - result += u'微博发布位置:%s\n' % self.publish_place - result += u'发布时间:%s\n' % self.publish_time - result += u'发布工具:%s\n' % self.publish_tool - result += u'点赞数:%d\n' % self.up_num - result += u'转发数:%d\n' % self.retweet_num - result += u'评论数:%d\n' % self.comment_num - result += u'url:https://weibo.cn/comment/%s\n' % self.id + result += f'微博发布位置:{self.publish_place}\n' + result += f'发布时间:{self.publish_time}\n' + result += f'发布工具:{self.publish_tool}\n' + result += f'点赞数:{self.up_num}\n' + result += f'转发数:{self.retweet_num}\n' + result += f'评论数:{self.comment_num}\n' + result += f'url:https://weibo.cn/comment/{self.id}\n' return result diff --git a/weibo_spider/writer/csv_writer.py b/weibo_spider/writer/csv_writer.py index 646ea4a1..30a5d8c5 100644 --- a/weibo_spider/writer/csv_writer.py +++ b/weibo_spider/writer/csv_writer.py @@ -41,6 +41,6 @@ def write_weibo(self, weibos): newline='') as f: writer = csv.writer(f) writer.writerows(result_data) - logger.info(u'%d条微博写入csv文件完毕,保存路径:%s', len(weibos), self.file_path) + logger.info(f'{len(weibos)}条微博写入csv文件完毕,保存路径:{self.file_path}') except Exception as e: logger.exception(e) diff --git a/weibo_spider/writer/json_writer.py b/weibo_spider/writer/json_writer.py index 924b097b..026e3c75 100644 --- a/weibo_spider/writer/json_writer.py +++ b/weibo_spider/writer/json_writer.py @@ -1,7 +1,7 @@ import codecs import json import logging -import os +from pathlib import Path from .writer import Writer @@ -43,10 +43,10 @@ def _update_json_data(self, data, weibo_info): def write_weibo(self, weibos): """将爬到的信息写入json文件""" data = {} - if os.path.isfile(self.file_path): + if Path(self.file_path).is_file(): with codecs.open(self.file_path, 'r', encoding='utf-8') as f: data = json.load(f) data = self._update_json_data(data, [w.to_dict() for w in weibos]) with codecs.open(self.file_path, 'w', encoding='utf-8') as f: f.write(json.dumps(data, indent=4, ensure_ascii=False)) - logger.info(u'%d条微博写入json文件完毕,保存路径:%s', len(weibos), self.file_path) + logger.info(f'{len(weibos)}条微博写入json文件完毕,保存路径:{self.file_path}') diff --git a/weibo_spider/writer/kafka_writer.py b/weibo_spider/writer/kafka_writer.py index e157ad9d..ed4d6f69 100644 --- a/weibo_spider/writer/kafka_writer.py +++ b/weibo_spider/writer/kafka_writer.py @@ -13,7 +13,7 @@ def __init__(self, kafka_config): from kafka import KafkaProducer except ImportError: logger.warning( - u'系统中可能没有安装kafka库,请先运行 pip install kafka-python ,再运行程序') + '系统中可能没有安装kafka库,请先运行 pip install kafka-python ,再运行程序') sys.exit() self.kafka_config = kafka_config @@ -23,7 +23,7 @@ def __init__(self, kafka_config): ).encode('UTF-8')) self.weibo_topics = list(kafka_config['weibo_topics']) self.user_topics = list(kafka_config['user_topics']) - logger.info('{}', kafka_config) + logger.info(f'{kafka_config}') def write_weibo(self, weibo): for w in weibo: diff --git a/weibo_spider/writer/mongo_writer.py b/weibo_spider/writer/mongo_writer.py index 763dfb72..2fd5714c 100644 --- a/weibo_spider/writer/mongo_writer.py +++ b/weibo_spider/writer/mongo_writer.py @@ -20,7 +20,7 @@ def _info_to_mongodb(self, collection, info_list): import pymongo except ImportError: logger.warning( - u'系统中可能没有安装pymongo库,请先运行 pip install pymongo ,再运行程序') + '系统中可能没有安装pymongo库,请先运行 pip install pymongo ,再运行程序') sys.exit() try: from pymongo import MongoClient @@ -42,7 +42,7 @@ def _info_to_mongodb(self, collection, info_list): collection.update_one({'id': info['id']}, {'$set': info}) except 
pymongo.errors.ServerSelectionTimeoutError: logger.warning( - u'系统中可能没有安装或启动MongoDB数据库,请先根据系统环境安装或启动MongoDB,再运行程序') + '系统中可能没有安装或启动MongoDB数据库,请先根据系统环境安装或启动MongoDB,再运行程序') sys.exit() def write_weibo(self, weibos): @@ -52,11 +52,11 @@ def write_weibo(self, weibos): w.user_id = self.user.id weibo_list.append(w.to_dict()) self._info_to_mongodb('weibo', weibo_list) - logger.info(u'%d条微博写入MongoDB数据库完毕', len(weibos)) + logger.info(f'{len(weibos)}条微博写入MongoDB数据库完毕') def write_user(self, user): """将爬取的用户信息写入MongoDB数据库""" self.user = user user_list = [user.to_dict()] self._info_to_mongodb('user', user_list) - logger.info(u'%s信息写入MongoDB数据库完毕', user.nickname) + logger.info(f'{user.nickname}信息写入MongoDB数据库完毕') diff --git a/weibo_spider/writer/mysql_writer.py b/weibo_spider/writer/mysql_writer.py index 1d631b3e..7ae2473e 100644 --- a/weibo_spider/writer/mysql_writer.py +++ b/weibo_spider/writer/mysql_writer.py @@ -31,13 +31,13 @@ def _mysql_create_database(self, sql): import pymysql except ImportError: logger.warning( - u'系统中可能没有安装pymysql库,请先运行 pip install pymysql ,再运行程序') + '系统中可能没有安装pymysql库,请先运行 pip install pymysql ,再运行程序') sys.exit() try: connection = pymysql.connect(**self.mysql_config) self._mysql_create(connection, sql) except pymysql.OperationalError: - logger.warning(u'系统中可能没有安装或正确配置MySQL数据库,请先根据系统环境安装或配置MySQL,再运行程序') + logger.warning('系统中可能没有安装或正确配置MySQL数据库,请先根据系统环境安装或配置MySQL,再运行程序') sys.exit() def _mysql_create_table(self, sql): @@ -108,7 +108,7 @@ def write_weibo(self, weibos): weibo.user_id = self.user.id weibo_list.append(weibo.to_dict()) self._mysql_insert('weibo', weibo_list) - logger.info(u'%d条微博写入MySQL数据库完毕', len(weibos)) + logger.info(f'{len(weibos)}条微博写入MySQL数据库完毕') except Exception as e: logger.exception(e) @@ -137,6 +137,6 @@ def write_user(self, user): ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""" self._mysql_create_table(create_table) self._mysql_insert('user', [user.to_dict()]) - logger.info(u'%s信息写入MySQL数据库完毕', user.nickname) + logger.info(f'{user.nickname}信息写入MySQL数据库完毕') except Exception as e: logger.exception(e) diff --git a/weibo_spider/writer/post_writer.py b/weibo_spider/writer/post_writer.py index ef5b75d3..f6c54aac 100644 --- a/weibo_spider/writer/post_writer.py +++ b/weibo_spider/writer/post_writer.py @@ -54,6 +54,6 @@ def write_weibo(self, weibos): data = self._update_json_data(data, [w.to_dict() for w in weibos]) if data: self.send_post_request_with_token(self.api_url, data, self.api_token, 3, 2) - logger.info(u'%d条微博通过POST发送到 %s', len(weibos), self.api_url) + logger.info(f'{len(weibos)}条微博通过POST发送到 {self.api_url}') else: - logger.info(u'没有获取到微博,略过API POST') + logger.info('没有获取到微博,略过API POST') diff --git a/weibo_spider/writer/sqlite_writer.py b/weibo_spider/writer/sqlite_writer.py index ced3e840..e9c29e10 100644 --- a/weibo_spider/writer/sqlite_writer.py +++ b/weibo_spider/writer/sqlite_writer.py @@ -79,7 +79,7 @@ def write_weibo(self, weibos): weibo.user_id = self.user.id weibo_list.append(weibo.to_dict()) self._sqlite_insert('weibo', weibo_list) - logger.info(u'%d条微博写入sqlite数据库完毕', len(weibos)) + logger.info(f'{len(weibos)}条微博写入sqlite数据库完毕') def write_user(self, user): """将爬取的用户信息写入sqlite数据库""" @@ -105,4 +105,4 @@ def write_user(self, user): )""" self._sqlite_create_table(create_table) self._sqlite_insert('user', [user.to_dict()]) - logger.info(u'%s信息写入sqlite数据库完毕', user.nickname) + logger.info(f'{user.nickname}信息写入sqlite数据库完毕') diff --git a/weibo_spider/writer/txt_writer.py b/weibo_spider/writer/txt_writer.py index 2b708070..a6597052 100644 --- 
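A design note on _info_to_mongodb above: the insert-then-update-on-conflict pair can be collapsed into a single upsert per document. A sketch with pymongo (the localhost:27017 address is a placeholder; the writer takes its connection settings from mongo_config):

from pymongo import MongoClient


def upsert_info(info_list, db='weibo', coll='weibo'):
    collection = MongoClient('localhost', 27017)[db][coll]
    for info in info_list:
        # update if a document with this 'id' exists, insert otherwise
        collection.update_one({'id': info['id']}, {'$set': info}, upsert=True)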
a/weibo_spider/writer/txt_writer.py +++ b/weibo_spider/writer/txt_writer.py @@ -10,15 +10,15 @@ class TxtWriter(Writer): def __init__(self, file_path, filter): self.file_path = file_path - self.user_header = u'用户信息' + self.user_header = '用户信息' self.user_desc = [('nickname', '用户昵称'), ('id', '用户id'), ('weibo_num', '微博数'), ('following', '关注数'), ('followers', '粉丝数')] if filter: - self.weibo_header = u'原创微博内容' + self.weibo_header = '原创微博内容' else: - self.weibo_header = u'微博内容' + self.weibo_header = '微博内容' self.weibo_desc = [('publish_place', '微博位置'), ('publish_time', '发布时间'), ('up_num', '点赞数'), ('retweet_num', '转发数'), ('comment_num', '评论数'), ('publish_tool', '发布工具')] @@ -31,8 +31,7 @@ def write_user(self, user): with open(self.file_path, 'ab') as f: f.write((self.user_header + ':\n' + user_info + '\n\n').encode( sys.stdout.encoding)) - logger.info(u'%s信息写入txt文件完毕,保存路径:%s', self.user.nickname, - self.file_path) + logger.info(f'{self.user.nickname}信息写入txt文件完毕,保存路径:{self.file_path}') def write_weibo(self, weibo): """将爬取的信息写入txt文件""" @@ -52,6 +51,6 @@ def write_weibo(self, weibo): with open(self.file_path, 'ab') as f: f.write((weibo_header + result).encode(sys.stdout.encoding)) - logger.info(u'%d条微博写入txt文件完毕,保存路径:%s', len(weibo), self.file_path) + logger.info(f'{len(weibo)}条微博写入txt文件完毕,保存路径:{self.file_path}') except Exception as e: logger.exception(e) From c5a4b898dab4cf0094930d3fe776d3471af16263 Mon Sep 17 00:00:00 2001 From: wchiways Date: Thu, 5 Feb 2026 00:02:32 +0800 Subject: [PATCH 2/5] refactor: use dataclasses for User/Weibo and introduce SpiderConfig --- tests/test_spider_init.py | 28 +++++++ weibo_spider/config.py | 53 +++++++++++++ weibo_spider/downloader/__init__.py | 5 +- weibo_spider/spider.py | 117 ++++++++++++++-------------- weibo_spider/user.py | 46 +++++------ weibo_spider/weibo.py | 52 ++++++------- weibo_spider/writer/__init__.py | 6 +- 7 files changed, 190 insertions(+), 117 deletions(-) create mode 100644 tests/test_spider_init.py create mode 100644 weibo_spider/config.py diff --git a/tests/test_spider_init.py b/tests/test_spider_init.py new file mode 100644 index 00000000..7738303c --- /dev/null +++ b/tests/test_spider_init.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import MagicMock, patch +from weibo_spider.spider import Spider +from weibo_spider.config import SpiderConfig + +class TestSpiderInit(unittest.TestCase): + @patch('weibo_spider.spider.FLAGS') + def test_init(self, mock_flags): + # Mock flags + mock_flags.user_id_list = None + mock_flags.u = None + mock_flags.output_dir = None + + config_dict = { + 'user_id_list': ['123'], + 'filter': 1, + 'since_date': '2023-01-01', + 'write_mode': ['csv'], + 'cookie': 'cookie' + } + config = SpiderConfig(**config_dict) + spider = Spider(config) + self.assertEqual(spider.filter, 1) + self.assertEqual(spider.since_date, '2023-01-01') + self.assertIsInstance(spider.config, SpiderConfig) + +if __name__ == '__main__': + unittest.main() diff --git a/weibo_spider/config.py b/weibo_spider/config.py new file mode 100644 index 00000000..59a1956d --- /dev/null +++ b/weibo_spider/config.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Union, Dict, Any +from .config_util import validate_config + +@dataclass +class SpiderConfig: + """微博爬虫配置类 + + Attributes: + user_id_list: 要爬取的微博用户ID列表,可以是ID列表,也可以是包含ID的字典列表,或txt文件路径。 + cookie: 微博的Cookie,用于身份验证。 + filter: 过滤类型,0表示抓取全部微博,1表示只抓取原创微博。 + since_date: 抓取起始时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或整数(表示从今天起的前n天)。 + end_date: 
抓取结束时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或"now"表示当前时间。 + random_wait_pages: 随机等待页数范围 [min, max],每爬取多少页暂停一次。 + random_wait_seconds: 随机等待时间范围 [min, max](秒),每次暂停等待的时长。 + global_wait: 全局等待配置 [[页数, 秒数], ...],例如每爬取1000页等待3600秒。 + write_mode: 结果保存类型列表,可包含 "txt", "csv", "json", "mongo", "mysql", "sqlite", "kafka", "post"。 + pic_download: 是否下载微博图片,0不下载,1下载。 + video_download: 是否下载微博视频,0不下载,1下载。 + file_download_timeout: 文件下载超时设置 [重试次数, 连接超时, 读取超时]。 + result_dir_name: 结果目录命名方式,0使用用户昵称,1使用用户ID。 + mysql_config: MySQL数据库连接配置字典。 + sqlite_config: SQLite数据库连接路径。 + kafka_config: Kafka配置字典。 + mongo_config: MongoDB配置字典。 + post_config: POST请求配置字典(用于数据推送)。 + """ + user_id_list: Union[List[Union[str, Dict[str, str]]], str] + cookie: str + filter: int = 0 + since_date: Union[int, str] = 0 + end_date: str = "now" + random_wait_pages: List[int] = field(default_factory=lambda: [1, 5]) + random_wait_seconds: List[int] = field(default_factory=lambda: [6, 10]) + global_wait: List[List[int]] = field(default_factory=lambda: [[1000, 3600]]) + write_mode: List[str] = field(default_factory=lambda: ["csv"]) + pic_download: int = 0 + video_download: int = 0 + file_download_timeout: List[int] = field(default_factory=lambda: [5, 5, 10]) + result_dir_name: int = 0 + mysql_config: Optional[Dict[str, Any]] = None + sqlite_config: Optional[str] = None + kafka_config: Optional[Dict[str, Any]] = None + mongo_config: Optional[Dict[str, Any]] = None + post_config: Optional[Dict[str, Any]] = None + + def __post_init__(self): + """初始化后的验证逻辑""" + # 使用现有的 validate_config 函数进行校验 + # 将 dataclass 转换为字典以适配旧有的验证函数 + config_dict = self.__dict__ + validate_config(config_dict) \ No newline at end of file diff --git a/weibo_spider/downloader/__init__.py b/weibo_spider/downloader/__init__.py index 53e9dfdf..ed339dc8 100644 --- a/weibo_spider/downloader/__init__.py +++ b/weibo_spider/downloader/__init__.py @@ -1,9 +1,10 @@ +from .downloader import Downloader from .origin_picture_downloader import OriginPictureDownloader from .retweet_picture_downloader import RetweetPictureDownloader from .avatar_picture_downloader import AvatarPictureDownloader from .video_downloader import VideoDownloader __all__ = [ - OriginPictureDownloader, RetweetPictureDownloader, AvatarPictureDownloader, - VideoDownloader + Downloader, OriginPictureDownloader, RetweetPictureDownloader, + AvatarPictureDownloader, VideoDownloader ] diff --git a/weibo_spider/spider.py b/weibo_spider/spider.py index d128b894..aa771f63 100644 --- a/weibo_spider/spider.py +++ b/weibo_spider/spider.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: UTF-8 -*- +from typing import Dict, Any, List, Optional import json import logging import logging.config @@ -18,10 +19,12 @@ from tqdm import tqdm from . 
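A small caveat on __post_init__ above: self.__dict__ is the dataclass's live attribute mapping, so if validate_config ever wrote into the dict it received, the write would land on the instance (the current validator only reads). dataclasses.asdict(self) would hand over a copy instead. The difference:

from dataclasses import asdict, dataclass, field


@dataclass
class Demo:
    user_id_list: list = field(default_factory=lambda: ['123'])


d = Demo()
d.__dict__['user_id_list'] = 'mutated'      # live mapping: hits the instance
assert d.user_id_list == 'mutated'

d2 = Demo()
asdict(d2)['user_id_list'] = 'mutated'      # copy: instance untouched
assert d2.user_id_list == ['123']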
import config_util, datetime_util -from .downloader import AvatarPictureDownloader +from .config import SpiderConfig +from .downloader import AvatarPictureDownloader, Downloader from .parser import AlbumParser, IndexParser, PageParser, PhotoParser from .parser.util import handle_html_async from .user import User +from .writer import Writer FLAGS = flags.FLAGS @@ -36,52 +39,44 @@ class Spider: - def __init__(self, config): + def __init__(self, config: SpiderConfig) -> None: """Weibo类初始化""" - self.filter = config[ - 'filter'] # 取值范围为0、1,程序默认值为0,代表要爬取用户的全部微博,1代表只爬取用户的原创微博 - since_date = config['since_date'] + self.config = config + self.filter: int = config.filter + since_date = config.since_date if isinstance(since_date, int): since_date = date.today() - timedelta(since_date) - self.since_date = str( + self.since_date: str = str( since_date) # 起始时间,即爬取发布日期从该值到结束时间的微博,形式为yyyy-mm-dd - self.end_date = config[ - 'end_date'] # 结束时间,即爬取发布日期从起始时间到该值的微博,形式为yyyy-mm-dd,特殊值"now"代表现在 - random_wait_pages = config['random_wait_pages'] - self.random_wait_pages = [ - min(random_wait_pages), - max(random_wait_pages) - ] # 随机等待频率,即每爬多少页暂停一次 - random_wait_seconds = config['random_wait_seconds'] - self.random_wait_seconds = [ - min(random_wait_seconds), - max(random_wait_seconds) - ] # 随机等待时间,即每次暂停要sleep多少秒 - self.global_wait = config['global_wait'] # 配置全局等待时间,如每爬1000页等待3600秒等 - self.page_count = 0 # 统计每次全局等待后,爬取了多少页,若页数满足全局等待要求就进入下一次全局等待 - self.write_mode = config[ - 'write_mode'] # 结果信息保存类型,为list形式,可包含txt、csv、json、mongo和mysql五种类型 - self.pic_download = config[ - 'pic_download'] # 取值范围为0、1,程序默认值为0,代表不下载微博原始图片,1代表下载 - self.video_download = config[ - 'video_download'] # 取值范围为0、1,程序默认为0,代表不下载微博视频,1代表下载 - self.file_download_timeout = config.get( - 'file_download_timeout', - [5, 5, 10 - ]) # 控制文件下载“超时”时的操作,值是list形式,包含三个数字,依次分别是最大超时重试次数、最大连接时间和最大读取时间 - self.result_dir_name = config.get( - 'result_dir_name', 0) # 结果目录名,取值为0或1,决定结果文件存储在用户昵称文件夹里还是用户id文件夹里 - self.cookie = config['cookie'] - self.mysql_config = config.get('mysql_config') # MySQL数据库连接配置,可以不填 - - self.sqlite_config = config.get('sqlite_config') - self.kafka_config = config.get('kafka_config') - self.mongo_config = config.get('mongo_config') - self.post_config = config.get('post_config') - self.user_config_file_path = '' - user_id_list = config['user_id_list'] + self.end_date: str = str(config.end_date) # 结束时间,即爬取发布日期从起始时间到该值的微博,形式为yyyy-mm-dd,特殊值"now"代表现在 + + self.random_wait_pages: List[int] = [ + min(config.random_wait_pages), + max(config.random_wait_pages) + ] + self.random_wait_seconds: List[int] = [ + min(config.random_wait_seconds), + max(config.random_wait_seconds) + ] + self.global_wait: List[List[int]] = config.global_wait + self.page_count: int = 0 + self.write_mode: List[str] = config.write_mode + self.pic_download: int = config.pic_download + self.video_download: int = config.video_download + self.file_download_timeout: List[int] = config.file_download_timeout + self.result_dir_name: int = config.result_dir_name + self.cookie: str = config.cookie + self.mysql_config: Optional[Dict[str, Any]] = config.mysql_config + self.sqlite_config: Optional[str] = config.sqlite_config + self.kafka_config: Optional[Dict[str, Any]] = config.kafka_config + self.mongo_config: Optional[Dict[str, Any]] = config.mongo_config + self.post_config: Optional[Dict[str, Any]] = config.post_config + + self.user_config_file_path: str = '' + user_id_list = config.user_id_list if FLAGS.user_id_list: user_id_list = FLAGS.user_id_list + if not isinstance(user_id_list, list): if 
not Path(user_id_list).is_absolute(): user_id_list = str(Path.cwd() / user_id_list) @@ -120,34 +115,37 @@ def __init__(self, config): user_id_list, self.since_date) for user_config in user_config_list: user_config['end_date'] = self.end_date - self.user_config_list = user_config_list # 要爬取的微博用户的user_config列表 - self.user_config = {} # 用户配置,包含用户id和since_date - self.new_since_date = '' # 完成某用户爬取后,自动生成对应用户新的since_date - self.user = User() # 存储爬取到的用户信息 - self.got_num = 0 # 存储爬取到的微博数 - self.weibo_id_list = [] # 存储爬取到的所有微博id - self.session = None # aiohttp session - - async def write_weibo(self, weibos): + self.user_config_list: List[Dict[str, str]] = user_config_list # 要爬取的微博用户的user_config列表 + self.user_config: Dict[str, str] = {} # 用户配置,包含用户id和since_date + self.new_since_date: str = '' # 完成某用户爬取后,自动生成对应用户新的since_date + self.user: User = User() # 存储爬取到的用户信息 + self.got_num: int = 0 # 存储爬取到的微博数 + self.weibo_id_list: List[str] = [] # 存储爬取到的所有微博id + self.session: Optional[aiohttp.ClientSession] = None # aiohttp session + + self.writers: List[Writer] = [] + self.downloaders: List[Downloader] = [] + + async def write_weibo(self, weibos: List[Any]) -> None: """将爬取到的信息写入文件或数据库""" for downloader in self.downloaders: await downloader.download_files(weibos, self.session) for writer in self.writers: writer.write_weibo(weibos) - def write_user(self, user): + def write_user(self, user: User) -> None: """将用户信息写入数据库""" for writer in self.writers: writer.write_user(user) - async def get_user_info(self, user_uri): + async def get_user_info(self, user_uri: str) -> None: """获取用户信息""" url = f'https://weibo.cn/{user_uri}/profile' selector = await handle_html_async(self.cookie, url, self.session) self.user = await IndexParser(self.cookie, user_uri, selector=selector).get_user_async(self.session) self.page_count += 1 - async def download_user_avatar(self, user_uri): + async def download_user_avatar(self, user_uri: str) -> None: """下载用户头像""" # Note: This remains synchronous for now as it's a minor part of the flow avatar_album_url = PhotoParser(self.cookie, @@ -234,7 +232,7 @@ async def get_weibo_info(self): except Exception as e: logger.exception(e) - def _get_filepath(self, type): + def _get_filepath(self, type: str) -> Path: """获取结果文件路径""" try: dir_name = self.user.nickname @@ -254,8 +252,9 @@ def _get_filepath(self, type): return file_path except Exception as e: logger.exception(e) + return Path() # Return empty path on error to match signature - def initialize_info(self, user_config): + def initialize_info(self, user_config: Dict[str, str]) -> None: """初始化爬虫信息""" self.got_num = 0 self.user_config = user_config @@ -322,7 +321,7 @@ def initialize_info(self, user_config): VideoDownloader(self._get_filepath('video'), self.file_download_timeout)) - async def get_one_user(self, user_config): + async def get_one_user(self, user_config: Dict[str, str]) -> None: """获取一个用户的微博""" try: await self.get_user_info(user_config['user_uri']) @@ -349,7 +348,7 @@ async def get_one_user(self, user_config): except Exception as e: logger.exception(e) - async def start(self): + async def start(self) -> None: """运行爬虫""" try: if not self.user_config_list: @@ -400,8 +399,8 @@ def _get_config(): async def async_main(_): try: - config = _get_config() - config_util.validate_config(config) + config_dict = _get_config() + config = SpiderConfig(**config_dict) wb = Spider(config) await wb.start() # 爬取微博信息 except Exception as e: diff --git a/weibo_spider/user.py b/weibo_spider/user.py index 12de454d..20efb5b9 100644 --- a/weibo_spider/user.py +++ 
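One consequence of SpiderConfig(**config_dict) in async_main: dataclass constructors raise TypeError for any key config.json carries that the class does not declare, where the old dict-based path ignored extras. If backward compatibility with such files matters, filtering first is enough (a sketch, not part of the patch):

import dataclasses

from weibo_spider.config import SpiderConfig


def load_config(config_dict: dict) -> SpiderConfig:
    known = {f.name for f in dataclasses.fields(SpiderConfig)}
    return SpiderConfig(**{k: v for k, v in config_dict.items() if k in known})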
b/weibo_spider/user.py @@ -1,34 +1,26 @@ -class User: - __slots__ = ( - 'id', 'nickname', 'gender', 'location', 'birthday', 'description', - 'verified_reason', 'talent', 'education', 'work', 'weibo_num', - 'following', 'followers' - ) - - def __init__(self): - self.id = '' - - self.nickname = '' +from dataclasses import dataclass, asdict - self.gender = '' - self.location = '' - self.birthday = '' - self.description = '' - self.verified_reason = '' - self.talent = '' - - self.education = '' - self.work = '' - - self.weibo_num = 0 - self.following = 0 - self.followers = 0 +@dataclass(slots=True) +class User: + id: str = '' + nickname: str = '' + gender: str = '' + location: str = '' + birthday: str = '' + description: str = '' + verified_reason: str = '' + talent: str = '' + education: str = '' + work: str = '' + weibo_num: int = 0 + following: int = 0 + followers: int = 0 - def to_dict(self): + def to_dict(self) -> dict: """将对象转换为字典""" - return {slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)} + return asdict(self) - def __str__(self): + def __str__(self) -> str: """打印微博用户信息""" result = '' result += f'用户昵称: {self.nickname}\n' diff --git a/weibo_spider/weibo.py b/weibo_spider/weibo.py index aa5e6604..bb92f9b7 100644 --- a/weibo_spider/weibo.py +++ b/weibo_spider/weibo.py @@ -1,35 +1,31 @@ +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any + +@dataclass(slots=True) class Weibo: - __slots__ = ( - 'id', 'user_id', 'content', 'article_url', 'original_pictures', - 'retweet_pictures', 'original', 'video_url', 'original_pictures_list', - 'retweet_pictures_list', 'media', 'publish_place', 'publish_time', - 'publish_tool', 'up_num', 'retweet_num', 'comment_num' - ) + id: str = '' + user_id: str = '' + content: str = '' + article_url: str = '' + original_pictures: str = '' # Usually a string of URLs separated by comma or '无' + retweet_pictures: str = '' # Usually a string of URLs separated by comma or '无' + original: bool = True + video_url: str = '' + original_pictures_list: List[str] = field(default_factory=list) + retweet_pictures_list: List[str] = field(default_factory=list) + media: Dict[str, Any] = field(default_factory=dict) # Assuming media can be complex + publish_place: str = '' + publish_time: str = '' + publish_tool: str = '' + up_num: int = 0 + retweet_num: int = 0 + comment_num: int = 0 - def to_dict(self): + def to_dict(self) -> dict: """将对象转换为字典""" - return {slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)} - - def __init__(self): - self.id = '' - self.user_id = '' - self.content = '' - self.article_url = '' - self.original_pictures = [] - self.retweet_pictures = [] - self.original = True - self.video_url = '' - self.original_pictures_list = [] - self.retweet_pictures_list = [] - self.media = {} - self.publish_place = '' - self.publish_time = '' - self.publish_tool = '' - self.up_num = 0 - self.retweet_num = 0 - self.comment_num = 0 + return asdict(self) - def __str__(self): + def __str__(self) -> str: """打印一条微博""" result = self.content + '\n' result += f'微博发布位置:{self.publish_place}\n' diff --git a/weibo_spider/writer/__init__.py b/weibo_spider/writer/__init__.py index f6b24bd6..3b6770a1 100644 --- a/weibo_spider/writer/__init__.py +++ b/weibo_spider/writer/__init__.py @@ -1,3 +1,4 @@ +from .writer import Writer from .csv_writer import CsvWriter from .json_writer import JsonWriter from .mongo_writer import MongoWriter @@ -7,4 +8,7 @@ from .kafka_writer import KafkaWriter from .post_writer import 
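Both models now use @dataclass(slots=True) with asdict replacing the hand-written to_dict. Note the slots argument to dataclass() only exists on Python 3.10+; on older interpreters the decorator raises TypeError. A minimal check of the pattern:

from dataclasses import asdict, dataclass, field


@dataclass(slots=True)  # requires Python 3.10+
class WeiboSketch:
    id: str = ''
    up_num: int = 0
    original_pictures_list: list = field(default_factory=list)


w = WeiboSketch(id='HxYz123', up_num=3)
assert asdict(w) == {'id': 'HxYz123', 'up_num': 3,
                     'original_pictures_list': []}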
PostWriter
 
-__all__ = [CsvWriter, TxtWriter, JsonWriter, MongoWriter, MySqlWriter, SqliteWriter, KafkaWriter, PostWriter]
+__all__ = [
+    'Writer', 'CsvWriter', 'TxtWriter', 'JsonWriter', 'MongoWriter',
+    'MySqlWriter', 'SqliteWriter', 'KafkaWriter', 'PostWriter'
+]

From 2506b5772ac4052c658fec6cb17c87e4971dcb14 Mon Sep 17 00:00:00 2001
From: wchiways
Date: Thu, 5 Feb 2026 00:27:17 +0800
Subject: [PATCH 3/5] refactor: eliminate code duplication and enhance config
 validation with pydantic

- move date validation to datetime_util.is_valid_date
- implement SpiderConfig using pydantic for robust validation
- catch ValidationError in spider.py for user-friendly error reporting
- add unit tests for config and datetime utilities
---
 requirements.txt              |   3 +-
 tests/test_config.py          |  53 ++++++++++
 tests/test_datetime_util.py   |  23 +++++
 weibo_spider/config.py        | 178 ++++++++++++++++++++++++----------
 weibo_spider/config_util.py   | 111 +--------------------
 weibo_spider/datetime_util.py |  11 ++-
 weibo_spider/spider.py        |   4 +
 7 files changed, 223 insertions(+), 160 deletions(-)
 create mode 100644 tests/test_config.py
 create mode 100644 tests/test_datetime_util.py

diff --git a/requirements.txt b/requirements.txt
index eff3a468..91eccc7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ requests==2.32.4
 tqdm==4.66.3
 absl-py==0.12.0
 browser_cookie3==0.20.1
-aiohttp
\ No newline at end of file
+aiohttp
+pydantic
\ No newline at end of file

diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 00000000..37dc8c1b
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,53 @@
+import unittest
+from pydantic import ValidationError
+from weibo_spider.config import SpiderConfig
+
+class TestSpiderConfig(unittest.TestCase):
+    def setUp(self):
+        self.base_config = {
+            'user_id_list': ['123456'],
+            'cookie': 'test_cookie',
+            'filter': 0,
+            'since_date': '2023-01-01',
+            'end_date': 'now',
+            'write_mode': ['csv']
+        }
+
+    def test_valid_config(self):
+        config = SpiderConfig(**self.base_config)
+        self.assertEqual(config.user_id_list, ['123456'])
+        self.assertEqual(config.filter, 0)
+
+    def test_invalid_filter(self):
+        config = self.base_config.copy()
+        config['filter'] = 2  # Invalid
+        with self.assertRaises(ValidationError):
+            SpiderConfig(**config)
+
+    def test_invalid_date(self):
+        config = self.base_config.copy()
+        config['since_date'] = 'invalid-date'
+        with self.assertRaises(ValidationError):
+            SpiderConfig(**config)
+
+    def test_invalid_write_mode(self):
+        config = self.base_config.copy()
+        config['write_mode'] = ['invalid_mode']
+        with self.assertRaises(ValidationError):
+            SpiderConfig(**config)
+
+    def test_wait_ranges(self):
+        config = self.base_config.copy()
+        # Test start > end
+        config['random_wait_pages'] = [5, 1]
+        with self.assertRaises(ValidationError) as cm:
+            SpiderConfig(**config)
+        self.assertIn('等待范围起始值不能大于结束值', str(cm.exception))
+
+        # Test negative values
+        config['random_wait_pages'] = [-1, 5]
+        with self.assertRaises(ValidationError):
+            SpiderConfig(**config)
+
+if __name__ == '__main__':
+    unittest.main()

diff --git a/tests/test_datetime_util.py b/tests/test_datetime_util.py
new file mode 100644
index 00000000..0e8916fe
--- /dev/null
+++ b/tests/test_datetime_util.py
@@ -0,0 +1,23 @@
+import unittest
+from datetime import datetime
+from weibo_spider.datetime_util import str_to_time, is_valid_date
+
+class TestDatetimeUtil(unittest.TestCase):
+    def test_str_to_time(self):
+        # Test YYYY-MM-DD
+        dt1 = str_to_time('2023-01-01')
+        self.assertEqual(dt1, datetime(2023, 1, 1))
+
+        # 
Test YYYY-MM-DD HH:MM + dt2 = str_to_time('2023-01-01 12:30') + self.assertEqual(dt2, datetime(2023, 1, 1, 12, 30)) + + # Test invalid format (should raise ValueError) + with self.assertRaises(ValueError): + str_to_time('invalid-date') + + def test_is_valid_date(self): + self.assertTrue(is_valid_date('2023-01-01')) + self.assertTrue(is_valid_date('2023-01-01 12:30')) + self.assertFalse(is_valid_date('invalid-date')) + self.assertFalse(is_valid_date('2023/01/01')) # Wrong separator diff --git a/weibo_spider/config.py b/weibo_spider/config.py index 59a1956d..ffe255d5 100644 --- a/weibo_spider/config.py +++ b/weibo_spider/config.py @@ -1,53 +1,127 @@ -from dataclasses import dataclass, field from typing import List, Optional, Union, Dict, Any -from .config_util import validate_config - -@dataclass -class SpiderConfig: - """微博爬虫配置类 - - Attributes: - user_id_list: 要爬取的微博用户ID列表,可以是ID列表,也可以是包含ID的字典列表,或txt文件路径。 - cookie: 微博的Cookie,用于身份验证。 - filter: 过滤类型,0表示抓取全部微博,1表示只抓取原创微博。 - since_date: 抓取起始时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或整数(表示从今天起的前n天)。 - end_date: 抓取结束时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或"now"表示当前时间。 - random_wait_pages: 随机等待页数范围 [min, max],每爬取多少页暂停一次。 - random_wait_seconds: 随机等待时间范围 [min, max](秒),每次暂停等待的时长。 - global_wait: 全局等待配置 [[页数, 秒数], ...],例如每爬取1000页等待3600秒。 - write_mode: 结果保存类型列表,可包含 "txt", "csv", "json", "mongo", "mysql", "sqlite", "kafka", "post"。 - pic_download: 是否下载微博图片,0不下载,1下载。 - video_download: 是否下载微博视频,0不下载,1下载。 - file_download_timeout: 文件下载超时设置 [重试次数, 连接超时, 读取超时]。 - result_dir_name: 结果目录命名方式,0使用用户昵称,1使用用户ID。 - mysql_config: MySQL数据库连接配置字典。 - sqlite_config: SQLite数据库连接路径。 - kafka_config: Kafka配置字典。 - mongo_config: MongoDB配置字典。 - post_config: POST请求配置字典(用于数据推送)。 - """ - user_id_list: Union[List[Union[str, Dict[str, str]]], str] - cookie: str - filter: int = 0 - since_date: Union[int, str] = 0 - end_date: str = "now" - random_wait_pages: List[int] = field(default_factory=lambda: [1, 5]) - random_wait_seconds: List[int] = field(default_factory=lambda: [6, 10]) - global_wait: List[List[int]] = field(default_factory=lambda: [[1000, 3600]]) - write_mode: List[str] = field(default_factory=lambda: ["csv"]) - pic_download: int = 0 - video_download: int = 0 - file_download_timeout: List[int] = field(default_factory=lambda: [5, 5, 10]) - result_dir_name: int = 0 - mysql_config: Optional[Dict[str, Any]] = None - sqlite_config: Optional[str] = None - kafka_config: Optional[Dict[str, Any]] = None - mongo_config: Optional[Dict[str, Any]] = None - post_config: Optional[Dict[str, Any]] = None - - def __post_init__(self): - """初始化后的验证逻辑""" - # 使用现有的 validate_config 函数进行校验 - # 将 dataclass 转换为字典以适配旧有的验证函数 - config_dict = self.__dict__ - validate_config(config_dict) \ No newline at end of file +from pydantic import BaseModel, Field, field_validator, ConfigDict +from pathlib import Path +from datetime import datetime +from .datetime_util import is_valid_date + + +class SpiderConfig(BaseModel): + """微博爬虫配置类""" + model_config = ConfigDict(arbitrary_types_allowed=True) + + user_id_list: Union[List[Union[str, Dict[str, str]]], str] = Field( + description="要爬取的微博用户ID列表,可以是ID列表,也可以是包含ID的字典列表,或txt文件路径。" + ) + cookie: str = Field(description="微博的Cookie,用于身份验证。") + filter: int = Field(default=0, description="过滤类型,0表示抓取全部微博,1表示只抓取原创微博。") + since_date: Union[int, str] = Field( + default=0, + description="抓取起始时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或整数(表示从今天起的前n天)。" + ) + end_date: str = Field( + default="now", + description="抓取结束时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或'now'表示当前时间。" + ) + random_wait_pages: List[int] = 
Field(
+        default_factory=lambda: [1, 5],
+        description="随机等待页数范围 [min, max],每爬取多少页暂停一次。"
+    )
+    random_wait_seconds: List[int] = Field(
+        default_factory=lambda: [6, 10],
+        description="随机等待时间范围 [min, max](秒),每次暂停等待的时长。"
+    )
+    global_wait: List[List[int]] = Field(
+        default_factory=lambda: [[1000, 3600]],
+        description="全局等待配置 [[页数, 秒数], ...],例如每爬取1000页等待3600秒。"
+    )
+    write_mode: List[str] = Field(
+        default_factory=lambda: ["csv"],
+        description="结果保存类型列表,可包含 'txt', 'csv', 'json', 'mongo', 'mysql', 'sqlite', 'kafka', 'post'。"
+    )
+    pic_download: int = Field(default=0, description="是否下载微博图片,0不下载,1下载。")
+    video_download: int = Field(default=0, description="是否下载微博视频,0不下载,1下载。")
+    file_download_timeout: List[int] = Field(
+        default_factory=lambda: [5, 5, 10],
+        description="文件下载超时设置 [重试次数, 连接超时, 读取超时]。"
+    )
+    result_dir_name: int = Field(default=0, description="结果目录命名方式,0使用用户昵称,1使用用户ID。")
+    mysql_config: Optional[Dict[str, Any]] = Field(default=None, description="MySQL数据库连接配置字典。")
+    sqlite_config: Optional[str] = Field(default=None, description="SQLite数据库连接路径。")
+    kafka_config: Optional[Dict[str, Any]] = Field(default=None, description="Kafka配置字典。")
+    mongo_config: Optional[Dict[str, Any]] = Field(default=None, description="MongoDB配置字典。")
+    post_config: Optional[Dict[str, Any]] = Field(default=None, description="POST请求配置字典(用于数据推送)。")
+
+    @field_validator('filter', 'pic_download', 'video_download')
+    @classmethod
+    def check_binary(cls, v: int) -> int:
+        if v not in (0, 1):
+            raise ValueError(f'值应为0或1, 得到: {v}')
+        return v
+
+    @field_validator('since_date')
+    @classmethod
+    def check_since_date(cls, v: Union[int, str]) -> Union[int, str]:
+        if isinstance(v, int):
+            return v
+        if not is_valid_date(str(v)):
+            raise ValueError(f'since_date值应为yyyy-mm-dd形式或整数, 得到: {v}')
+        return v
+
+    @field_validator('end_date')
+    @classmethod
+    def check_end_date(cls, v: str) -> str:
+        if v == 'now' or is_valid_date(v):
+            return v
+        raise ValueError(f'end_date值应为yyyy-mm-dd形式或"now", 得到: {v}')
+
+    @field_validator('random_wait_pages', 'random_wait_seconds')
+    @classmethod
+    def check_wait_range(cls, v: List[int]) -> List[int]:
+        if not isinstance(v, list) or len(v) != 2:
+            raise ValueError('参数值应为包含两个整数的list类型')
+        if not all(isinstance(i, int) for i in v):
+            raise ValueError('列表中的值应为整数类型')
+        if min(v) < 1:
+            raise ValueError('列表中的值应大于0')
+        if v[0] > v[1]:
+            raise ValueError('等待范围起始值不能大于结束值')
+        return v
+
+    @field_validator('global_wait')
+    @classmethod
+    def check_global_wait(cls, v: List[List[int]]) -> List[List[int]]:
+        for g in v:
+            if not isinstance(g, list) or len(g) != 2:
+                raise ValueError('参数内的值应为长度为2的list类型')
+            if not all(isinstance(i, int) and i >= 1 for i in g):
+                raise ValueError('列表中的值应为大于0的整数')
+        return v
+
+    @field_validator('write_mode')
+    @classmethod
+    def check_write_mode(cls, v: List[str]) -> List[str]:
+        valid_modes = {'txt', 'csv', 'json', 'mongo', 'mysql', 'sqlite', 'kafka', 'post'}
+        for mode in v:
+            if mode not in valid_modes:
+                raise ValueError(f'{mode}为无效模式')
+        return v
+
+    @field_validator('user_id_list')
+    @classmethod
+    def check_user_id_list(cls, v: Union[List[Union[str, Dict[str, str]]], str]) -> Union[List[Union[str, Dict[str, str]]], str]:
+        if isinstance(v, list):
+            return v
+        if not v.endswith('.txt'):
+            raise ValueError('user_id_list值应为list类型或txt文件路径')
+
+        path = Path(v)
+        if not path.is_absolute():
+            path = Path.cwd() / v
+        if not path.is_file():
+            raise ValueError(f'不存在{path}文件')
+        return str(path)
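A minimal sketch of the call-site behavior this patch buys (the values are illustrative only, mirroring the test fixture above; this snippet is not part of the patch itself):

from pydantic import ValidationError

from weibo_spider.config import SpiderConfig

try:
    # Deliberately bad range: start > end trips check_wait_range.
    config = SpiderConfig(
        user_id_list=['123456'],
        cookie='test_cookie',
        since_date='2023-01-01',
        random_wait_pages=[5, 1],
    )
except ValidationError as e:
    # pydantic collects every violated constraint into one report,
    # e.g. '等待范围起始值不能大于结束值', instead of the removed
    # validate_config pattern of logging a warning and calling sys.exit().
    print(e)

diff --git a/weibo_spider/config_util.py b/weibo_spider/config_util.py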
index 254b1eea..aa439f4f 100644 --- a/weibo_spider/config_util.py +++ b/weibo_spider/config_util.py @@ -1,118 +1,17 @@ import codecs import logging -import sys import browser_cookie3 from datetime import datetime from pathlib import Path import json +from .datetime_util import is_valid_date logger = logging.getLogger('spider.config_util') -def _is_date(date_str): - """判断日期格式是否正确""" - try: - if ':' in date_str: - datetime.strptime(date_str, '%Y-%m-%d %H:%M') - else: - datetime.strptime(date_str, '%Y-%m-%d') - return True - except ValueError: - return False - - -def validate_config(config): - """验证配置是否正确""" - - # 验证filter、pic_download、video_download - argument_list = ['filter', 'pic_download', 'video_download'] - for argument in argument_list: - if config[argument] != 0 and config[argument] != 1: - logger.warning(f'{config[argument]}值应为0或1,请重新输入') - sys.exit() - - # 验证since_date - since_date = config['since_date'] - if (not _is_date(str(since_date))) and (not isinstance(since_date, int)): - logger.warning('since_date值应为yyyy-mm-dd形式或整数,请重新输入') - sys.exit() - - # 验证end_date - end_date = str(config['end_date']) - if (not _is_date(end_date)) and (end_date != 'now'): - logger.warning('end_date值应为yyyy-mm-dd形式或"now",请重新输入') - sys.exit() - - # 验证random_wait_pages - random_wait_pages = config['random_wait_pages'] - if not isinstance(random_wait_pages, list): - logger.warning('random_wait_pages参数值应为list类型,请重新输入') - sys.exit() - if (not isinstance(min(random_wait_pages), int)) or (not isinstance( - max(random_wait_pages), int)): - logger.warning('random_wait_pages列表中的值应为整数类型,请重新输入') - sys.exit() - if min(random_wait_pages) < 1: - logger.warning('random_wait_pages列表中的值应大于0,请重新输入') - sys.exit() - - # 验证random_wait_seconds - random_wait_seconds = config['random_wait_seconds'] - if not isinstance(random_wait_seconds, list): - logger.warning('random_wait_seconds参数值应为list类型,请重新输入') - sys.exit() - if (not isinstance(min(random_wait_seconds), int)) or (not isinstance( - max(random_wait_seconds), int)): - logger.warning('random_wait_seconds列表中的值应为整数类型,请重新输入') - sys.exit() - if min(random_wait_seconds) < 1: - logger.warning('random_wait_seconds列表中的值应大于0,请重新输入') - sys.exit() - - # 验证global_wait - global_wait = config['global_wait'] - if not isinstance(global_wait, list): - logger.warning('global_wait参数值应为list类型,请重新输入') - sys.exit() - for g in global_wait: - if not isinstance(g, list): - logger.warning('global_wait参数内的值应为长度为2的list类型,请重新输入') - sys.exit() - if len(g) != 2: - logger.warning('global_wait参数内的list长度应为2,请重新输入') - sys.exit() - for i in g: - if (not isinstance(i, int)) or i < 1: - logger.warning('global_wait列表中的值应为大于0的整数,请重新输入') - sys.exit() - - # 验证write_mode - write_mode = ['txt', 'csv', 'json', 'mongo', 'mysql', 'sqlite', 'kafka','post'] - if not isinstance(config['write_mode'], list): - logger.warning('write_mode值应为list类型') - sys.exit() - for mode in config['write_mode']: - if mode not in write_mode: - logger.warning( - f'{mode}为无效模式,请从txt、csv、json、post、mongo、sqlite, kafka和mysql中挑选一个或多个作为write_mode') - sys.exit() - - # 验证user_id_list - user_id_list = config['user_id_list'] - if (not isinstance(user_id_list, - list)) and (not user_id_list.endswith('.txt')): - logger.warning('user_id_list值应为list类型或txt文件路径') - sys.exit() - if not isinstance(user_id_list, list): - if not Path(user_id_list).is_absolute(): - user_id_list = str(Path.cwd() / user_id_list) - if not Path(user_id_list).is_file(): - logger.warning(f'不存在{user_id_list}文件') - sys.exit() - - def get_user_config_list(file_name, default_since_date): 
"""获取文件中的微博id信息""" + import sys with open(file_name, 'rb') as f: try: lines = f.read().splitlines() @@ -126,8 +25,8 @@ def get_user_config_list(file_name, default_since_date): if len(info) > 0 and info[0].isdigit(): user_config = {} user_config['user_uri'] = info[0] - if len(info) > 2 and _is_date(info[2]): - if len(info) > 3 and _is_date(info[2] + ' ' + info[3]): + if len(info) > 2 and is_valid_date(info[2]): + if len(info) > 3 and is_valid_date(info[2] + ' ' + info[3]): user_config['since_date'] = info[2] + ' ' + info[3] else: user_config['since_date'] = info[2] @@ -155,7 +54,7 @@ def update_user_config_file(user_config_file_path, user_uri, nickname, info.append(start_time) if len(info) == 2: info.append(start_time) - if len(info) > 3 and _is_date(info[2] + ' ' + info[3]): + if len(info) > 3 and is_valid_date(info[2] + ' ' + info[3]): del info[3] if len(info) > 2: info[2] = start_time diff --git a/weibo_spider/datetime_util.py b/weibo_spider/datetime_util.py index 1228af00..2b3b6375 100644 --- a/weibo_spider/datetime_util.py +++ b/weibo_spider/datetime_util.py @@ -1,10 +1,19 @@ from datetime import datetime -def str_to_time(text): +def str_to_time(text: str) -> datetime: """将字符串转换成时间类型""" if ':' in text: result = datetime.strptime(text, '%Y-%m-%d %H:%M') else: result = datetime.strptime(text, '%Y-%m-%d') return result + + +def is_valid_date(date_str: str) -> bool: + """判断日期格式是否正确""" + try: + str_to_time(date_str) + return True + except ValueError: + return False diff --git a/weibo_spider/spider.py b/weibo_spider/spider.py index aa771f63..968b0f7f 100644 --- a/weibo_spider/spider.py +++ b/weibo_spider/spider.py @@ -2,6 +2,7 @@ # -*- coding: UTF-8 -*- from typing import Dict, Any, List, Optional +from pydantic import ValidationError import json import logging import logging.config @@ -403,6 +404,9 @@ async def async_main(_): config = SpiderConfig(**config_dict) wb = Spider(config) await wb.start() # 爬取微博信息 + except ValidationError as e: + logger.error(f"配置验证失败:\n{e}") + sys.exit(1) except Exception as e: logger.exception(e) From 7e1d4b06efc0e42d281322b49c0983da15fff66f Mon Sep 17 00:00:00 2001 From: wchiways Date: Thu, 5 Feb 2026 00:48:47 +0800 Subject: [PATCH 4/5] feat: implement full-link async I/O for writers - Update Writer base class to use async/await - Integrate aiofiles for asynchronous file writing (txt, csv, json) - Migrate PostWriter to use aiohttp for async network requests - Update Spider to await writer methods - Add async writer tests --- requirements.txt | 3 +- tests/test_writers_async.py | 85 ++++++++++++++++++++++++++++ weibo_spider/spider.py | 8 +-- weibo_spider/writer/csv_writer.py | 20 +++++-- weibo_spider/writer/json_writer.py | 17 ++++-- weibo_spider/writer/kafka_writer.py | 4 +- weibo_spider/writer/mongo_writer.py | 4 +- weibo_spider/writer/mysql_writer.py | 4 +- weibo_spider/writer/post_writer.py | 45 ++++++++------- weibo_spider/writer/sqlite_writer.py | 4 +- weibo_spider/writer/txt_writer.py | 13 +++-- weibo_spider/writer/writer.py | 4 +- 12 files changed, 158 insertions(+), 53 deletions(-) create mode 100644 tests/test_writers_async.py diff --git a/requirements.txt b/requirements.txt index 91eccc7c..5b1a911a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ tqdm==4.66.3 absl-py==0.12.0 browser_cookie3==0.20.1 aiohttp -pydantic \ No newline at end of file +pydantic +aiofiles \ No newline at end of file diff --git a/tests/test_writers_async.py b/tests/test_writers_async.py new file mode 100644 index 00000000..5ef10cca --- /dev/null +++ 
b/tests/test_writers_async.py @@ -0,0 +1,85 @@ +import unittest +import asyncio +import os +import shutil +import sys +from unittest.mock import MagicMock, AsyncMock, patch +from datetime import datetime + +# Mock aiofiles before importing modules that use it +sys.modules['aiofiles'] = MagicMock() + +from weibo_spider.writer.txt_writer import TxtWriter +from weibo_spider.writer.csv_writer import CsvWriter +from weibo_spider.writer.json_writer import JsonWriter +from weibo_spider.user import User +from weibo_spider.weibo import Weibo + +class TestWritersAsync(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.user = User( + id='123456', + nickname='test_user', + weibo_num=100, + following=50, + followers=200 + ) + self.weibo = Weibo( + id='w1', + user_id='123456', + content='Test Weibo Content', + publish_time='2023-01-01', + up_num=10, + retweet_num=5, + comment_num=2 + ) + self.weibos = [self.weibo] + + @patch('weibo_spider.writer.txt_writer.aiofiles.open') + async def test_txt_writer(self, mock_aio_open): + file_path = 'test.txt' + writer = TxtWriter(file_path, filter=0) + + mock_f = AsyncMock() + mock_aio_open.return_value.__aenter__.return_value = mock_f + + await writer.write_user(self.user) + mock_f.write.assert_called() + + await writer.write_weibo(self.weibos) + mock_f.write.assert_called() + + @patch('weibo_spider.writer.csv_writer.aiofiles.open') + async def test_csv_writer(self, mock_aio_open): + file_path = 'test.csv' + writer = CsvWriter(file_path, filter=0) + + mock_f = AsyncMock() + mock_aio_open.return_value.__aenter__.return_value = mock_f + + await writer.write_weibo(self.weibos) + mock_f.write.assert_called() + # Verify content was written + args, _ = mock_f.write.call_args + self.assertIn('w1', args[0]) + self.assertIn('Test Weibo Content', args[0]) + + @patch('weibo_spider.writer.json_writer.aiofiles.open') + async def test_json_writer(self, mock_aio_open): + file_path = 'test.json' + writer = JsonWriter(file_path) + + mock_f = AsyncMock() + mock_aio_open.return_value.__aenter__.return_value = mock_f + # Mock read to return empty valid json + mock_f.read.return_value = '{}' + + await writer.write_user(self.user) + await writer.write_weibo(self.weibos) + + mock_f.write.assert_called() + args, _ = mock_f.write.call_args + self.assertIn('test_user', args[0]) + +if __name__ == '__main__': + unittest.main() diff --git a/weibo_spider/spider.py b/weibo_spider/spider.py index 968b0f7f..f605b326 100644 --- a/weibo_spider/spider.py +++ b/weibo_spider/spider.py @@ -132,12 +132,12 @@ async def write_weibo(self, weibos: List[Any]) -> None: for downloader in self.downloaders: await downloader.download_files(weibos, self.session) for writer in self.writers: - writer.write_weibo(weibos) + await writer.write_weibo(weibos) - def write_user(self, user: User) -> None: + async def write_user(self, user: User) -> None: """将用户信息写入数据库""" for writer in self.writers: - writer.write_user(user) + await writer.write_user(user) async def get_user_info(self, user_uri: str) -> None: """获取用户信息""" @@ -330,7 +330,7 @@ async def get_one_user(self, user_config: Dict[str, str]) -> None: logger.info('*' * 100) self.initialize_info(user_config) - self.write_user(self.user) + await self.write_user(self.user) logger.info('*' * 100) # 下载用户头像相册中的图片。 diff --git a/weibo_spider/writer/csv_writer.py b/weibo_spider/writer/csv_writer.py index 30a5d8c5..7776da56 100644 --- a/weibo_spider/writer/csv_writer.py +++ b/weibo_spider/writer/csv_writer.py @@ -1,5 +1,7 @@ import csv import logging +import io +import 
aiofiles from .writer import Writer @@ -29,18 +31,24 @@ def __init__(self, file_path, filter): except Exception as e: logger.exception(e) - def write_user(self, user): + async def write_user(self, user): self.user = user - def write_weibo(self, weibos): + async def write_weibo(self, weibos): """将爬取的信息写入csv文件""" try: result_data = [[getattr(w, kv[1]) for kv in self.result_headers] for w in weibos] - with open(self.file_path, 'a', encoding='utf-8-sig', - newline='') as f: - writer = csv.writer(f) - writer.writerows(result_data) + + output = io.StringIO() + writer = csv.writer(output) + writer.writerows(result_data) + content = output.getvalue() + output.close() + + async with aiofiles.open(self.file_path, 'a', encoding='utf-8-sig', + newline='') as f: + await f.write(content) logger.info(f'{len(weibos)}条微博写入csv文件完毕,保存路径:{self.file_path}') except Exception as e: logger.exception(e) diff --git a/weibo_spider/writer/json_writer.py b/weibo_spider/writer/json_writer.py index 026e3c75..26a1e09c 100644 --- a/weibo_spider/writer/json_writer.py +++ b/weibo_spider/writer/json_writer.py @@ -2,6 +2,7 @@ import json import logging from pathlib import Path +import aiofiles from .writer import Writer @@ -12,7 +13,7 @@ class JsonWriter(Writer): def __init__(self, file_path): self.file_path = file_path - def write_user(self, user): + async def write_user(self, user): self.user = user def _update_json_data(self, data, weibo_info): @@ -40,13 +41,17 @@ def _update_json_data(self, data, weibo_info): data['weibo'] = weibo_info return data - def write_weibo(self, weibos): + async def write_weibo(self, weibos): """将爬到的信息写入json文件""" data = {} if Path(self.file_path).is_file(): - with codecs.open(self.file_path, 'r', encoding='utf-8') as f: - data = json.load(f) + async with aiofiles.open(self.file_path, 'r', encoding='utf-8') as f: + content = await f.read() + if content: + data = json.loads(content) + data = self._update_json_data(data, [w.to_dict() for w in weibos]) - with codecs.open(self.file_path, 'w', encoding='utf-8') as f: - f.write(json.dumps(data, indent=4, ensure_ascii=False)) + + async with aiofiles.open(self.file_path, 'w', encoding='utf-8') as f: + await f.write(json.dumps(data, indent=4, ensure_ascii=False)) logger.info(f'{len(weibos)}条微博写入json文件完毕,保存路径:{self.file_path}') diff --git a/weibo_spider/writer/kafka_writer.py b/weibo_spider/writer/kafka_writer.py index ed4d6f69..feec20cb 100644 --- a/weibo_spider/writer/kafka_writer.py +++ b/weibo_spider/writer/kafka_writer.py @@ -25,13 +25,13 @@ def __init__(self, kafka_config): self.user_topics = list(kafka_config['user_topics']) logger.info(f'{kafka_config}') - def write_weibo(self, weibo): + async def write_weibo(self, weibo): for w in weibo: w.user_id = self.user.id for topic in self.weibo_topics: self.producer.send(topic, value=w.to_dict()) - def write_user(self, user): + async def write_user(self, user): self.user = user for topic in self.user_topics: diff --git a/weibo_spider/writer/mongo_writer.py b/weibo_spider/writer/mongo_writer.py index 2fd5714c..5547350d 100644 --- a/weibo_spider/writer/mongo_writer.py +++ b/weibo_spider/writer/mongo_writer.py @@ -45,7 +45,7 @@ def _info_to_mongodb(self, collection, info_list): '系统中可能没有安装或启动MongoDB数据库,请先根据系统环境安装或启动MongoDB,再运行程序') sys.exit() - def write_weibo(self, weibos): + async def write_weibo(self, weibos): """将爬取的微博信息写入MongoDB数据库""" weibo_list = [] for w in weibos: @@ -54,7 +54,7 @@ def write_weibo(self, weibos): self._info_to_mongodb('weibo', weibo_list) logger.info(f'{len(weibos)}条微博写入MongoDB数据库完毕') - 
def write_user(self, user): + async def write_user(self, user): """将爬取的用户信息写入MongoDB数据库""" self.user = user user_list = [user.to_dict()] diff --git a/weibo_spider/writer/mysql_writer.py b/weibo_spider/writer/mysql_writer.py index 7ae2473e..b98bc4f4 100644 --- a/weibo_spider/writer/mysql_writer.py +++ b/weibo_spider/writer/mysql_writer.py @@ -78,7 +78,7 @@ def _mysql_insert(self, table, data_list): finally: connection.close() - def write_weibo(self, weibos): + async def write_weibo(self, weibos): """将爬取的微博信息写入MySQL数据库""" # 创建'weibo'表 try: @@ -112,7 +112,7 @@ def write_weibo(self, weibos): except Exception as e: logger.exception(e) - def write_user(self, user): + async def write_user(self, user): """将爬取的用户信息写入MySQL数据库""" try: self.user = user diff --git a/weibo_spider/writer/post_writer.py b/weibo_spider/writer/post_writer.py index f6c54aac..e2e016ae 100644 --- a/weibo_spider/writer/post_writer.py +++ b/weibo_spider/writer/post_writer.py @@ -2,11 +2,10 @@ import json import logging import os -import requests +import asyncio +import aiohttp from .writer import Writer -from time import sleep -from requests.exceptions import RequestException logger = logging.getLogger('spider.post_writer') @@ -17,7 +16,7 @@ def __init__(self, post_config): self.api_token = post_config.get('api_token', None) self.dba_password = post_config.get('dba_password', None) - def write_user(self, user): + async def write_user(self, user): self.user = user def _update_json_data(self, data, weibo_info): @@ -29,31 +28,37 @@ def _update_json_data(self, data, weibo_info): data['weibo'] = weibo_info return data - def send_post_request_with_token(self, url, data, token, max_retries, backoff_factor): + async def send_post_request_with_token(self, url, data, token, max_retries, backoff_factor): headers = { 'Content-Type': 'application/json', 'api-token': f'{token}', } - for attempt in range(max_retries + 1): - try: - response = requests.post(url, json=data, headers=headers) - if response.status_code == requests.codes.ok: - return response.json() - else: - raise RequestException(f"Unexpected response status: {response.status_code}") - except RequestException as e: - if attempt < max_retries: - sleep(backoff_factor * (attempt + 1)) # 逐步增加等待时间,避免频繁重试 - continue - else: - logger.error(f"在尝试{max_retries}次发出POST连接后,请求失败:{e}") + async with aiohttp.ClientSession() as session: + for attempt in range(max_retries + 1): + try: + async with session.post(url, json=data, headers=headers) as response: + if response.status == 200: + return await response.json() + else: + # Continue to next attempt on non-200 status + if attempt < max_retries: + await asyncio.sleep(backoff_factor * (attempt + 1)) + continue + logger.error(f"Unexpected response status: {response.status}") + return None + except Exception as e: + if attempt < max_retries: + await asyncio.sleep(backoff_factor * (attempt + 1)) # 逐步增加等待时间,避免频繁重试 + continue + else: + logger.error(f"在尝试{max_retries}次发出POST连接后,请求失败:{e}") - def write_weibo(self, weibos): + async def write_weibo(self, weibos): """将爬到的信息POST到API""" data = {} data = self._update_json_data(data, [w.to_dict() for w in weibos]) if data: - self.send_post_request_with_token(self.api_url, data, self.api_token, 3, 2) + await self.send_post_request_with_token(self.api_url, data, self.api_token, 3, 2) logger.info(f'{len(weibos)}条微博通过POST发送到 {self.api_url}') else: logger.info('没有获取到微博,略过API POST') diff --git a/weibo_spider/writer/sqlite_writer.py b/weibo_spider/writer/sqlite_writer.py index e9c29e10..918fb8cf 100644 --- 
a/weibo_spider/writer/sqlite_writer.py
+++ b/weibo_spider/writer/sqlite_writer.py
@@ -50,7 +50,7 @@ def _sqlite_insert(self, table, data_list):
         finally:
             connection.close()
 
-    def write_weibo(self, weibos):
+    async def write_weibo(self, weibos):
         """将爬取的微博信息写入sqlite数据库"""
         # 创建'weibo'表
         create_table = """
@@ -81,7 +81,7 @@ def write_weibo(self, weibos):
         self._sqlite_insert('weibo', weibo_list)
         logger.info(f'{len(weibos)}条微博写入sqlite数据库完毕')
 
-    def write_user(self, user):
+    async def write_user(self, user):
         """将爬取的用户信息写入sqlite数据库"""
         self.user = user
 
diff --git a/weibo_spider/writer/txt_writer.py b/weibo_spider/writer/txt_writer.py
index a6597052..198bbb9e 100644
--- a/weibo_spider/writer/txt_writer.py
+++ b/weibo_spider/writer/txt_writer.py
@@ -1,5 +1,6 @@
 import logging
 import sys
+import aiofiles
 
 from .writer import Writer
 
@@ -23,17 +24,17 @@ def __init__(self, file_path, filter):
                            ('up_num', '点赞数'), ('retweet_num', '转发数'),
                            ('comment_num', '评论数'), ('publish_tool', '发布工具')]
 
-    def write_user(self, user):
+    async def write_user(self, user):
         self.user = user
         user_info = '\n'.join(
             [v + ':' + str(getattr(self.user, k)) for k, v in self.user_desc])
 
-        with open(self.file_path, 'ab') as f:
-            f.write((self.user_header + ':\n' + user_info + '\n\n').encode(
+        async with aiofiles.open(self.file_path, 'ab') as f:
+            await f.write((self.user_header + ':\n' + user_info + '\n\n').encode(
                 sys.stdout.encoding))
         logger.info(f'{self.user.nickname}信息写入txt文件完毕,保存路径:{self.file_path}')
 
-    def write_weibo(self, weibo):
+    async def write_weibo(self, weibo):
         """将爬取的信息写入txt文件"""
 
         weibo_header = ''
@@ -49,8 +50,8 @@ def write_weibo(self, weibo):
                              for k, v in self.weibo_desc]))
             result = '\n\n'.join(temp_result) + '\n\n'
 
-            with open(self.file_path, 'ab') as f:
-                f.write((weibo_header + result).encode(sys.stdout.encoding))
+            async with aiofiles.open(self.file_path, 'ab') as f:
+                await f.write((weibo_header + result).encode(sys.stdout.encoding))
             logger.info(f'{len(weibo)}条微博写入txt文件完毕,保存路径:{self.file_path}')
         except Exception as e:
             logger.exception(e)

diff --git a/weibo_spider/writer/writer.py b/weibo_spider/writer/writer.py
index 45366510..1a72c596 100644
--- a/weibo_spider/writer/writer.py
+++ b/weibo_spider/writer/writer.py
@@ -7,11 +7,11 @@ def __init__(self):
         pass
 
     @abstractmethod
-    def write_weibo(self, weibo):
+    async def write_weibo(self, weibo):
         """给定微博信息,写入对应文本或数据库"""
         pass
 
     @abstractmethod
-    def write_user(self, user):
+    async def write_user(self, user):
         """给定用户信息,写入对应文本或数据库"""
         pass

From 573c4de8a802e781b1fd1b4e34cf6b377e8176ef Mon Sep 17 00:00:00 2001
From: Chiway Wang <105968622+wchiways@users.noreply.github.com>
Date: Thu, 5 Feb 2026 01:10:25 +0800
Subject: [PATCH 5/5] fix: import sys in check_cookie function

Fix the undefined-name error that broke the build: the module-level
import sys was removed from config_util.py in an earlier refactor, but
check_cookie still references sys.

---
 weibo_spider/config_util.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/weibo_spider/config_util.py b/weibo_spider/config_util.py
index aa439f4f..56b1322e 100644
--- a/weibo_spider/config_util.py
+++ b/weibo_spider/config_util.py
@@ -103,6 +103,7 @@ def update_cookie_config(cookie, user_config_file_path):
 
 def check_cookie(user_config_file_path):
     """Checks if user is logged in"""
+    import sys
     try:
         cookie = get_cookie()
         if cookie.get("MLOGIN", '0') == '0':