diff --git a/.gitignore b/.gitignore index 91ee9d1..b1e0542 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ dist *.egg-info *.pyc *.db +*.swp .idea diff --git a/README.md b/README.md index b3149a0..9d6dca5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,14 @@ +# 变更说明 + +作者:[@raptorz](https://github.com/raptorz) + +* 支持python3 +* 符合PEP8 +* 增加token管理,包括js_ticket +* 增加模板消息发送功能 +* 可选是否验证https证书(原版为不验证) +* event_map改为可扩展 + # 微信公众号Python-SDK 作者: [@jeff_kit](http://twitter.com/jeff_kit) diff --git a/wechat/crypt.py b/wechat/crypt.py index 283598d..4cf9da6 100644 --- a/wechat/crypt.py +++ b/wechat/crypt.py @@ -1,4 +1,4 @@ -#encoding=utf-8 +# encoding=utf-8 import base64 import string @@ -8,10 +8,21 @@ import struct from Crypto.Cipher import AES import xml.etree.cElementTree as ET -import sys import socket -reload(sys) -sys.setdefaultencoding('utf-8') +import sys + + +if sys.version < "3": + reload(sys) + sys.setdefaultencoding('utf-8') + PY3 = False + + def str2bytes(s): + return s +else: + def str2bytes(s): + return s.encode("utf-8") if isinstance(s, str) else s + WXBizMsgCrypt_OK = 0 WXBizMsgCrypt_ValidateSignature_Error = -40001 @@ -63,6 +74,13 @@ def getSHA1(self, token, timestamp, nonce, encrypt): except Exception: return WXBizMsgCrypt_ComputeSignature_Error, None + @staticmethod + def getSignature(token, timestamp, nonce): + sign_ele = [token, timestamp, nonce] + sign_ele.sort() + s = "".join(sign_ele) + return hashlib.sha1(str2bytes(s)).hexdigest() + class XMLParse: """提供提取消息格式中的密文及生成回复消息格式的接口""" @@ -82,7 +100,7 @@ def extract(self, xmltext): xml_tree = ET.fromstring(xmltext) encrypt = xml_tree.find("Encrypt") touser_name = xml_tree.find("ToUserName") - if touser_name != None: + if touser_name is not None: touser_name = touser_name.text return WXBizMsgCrypt_OK, encrypt.text, touser_name except Exception: @@ -139,7 +157,7 @@ class Prpcrypt(object): """提供接收和推送给公众平台消息的加解密接口""" def __init__(self, key): - #self.key = base64.b64decode(key+"=") + # self.key = base64.b64decode(key+"=") self.key = key # 设置加解密模式为AES的CBC模式 self.mode = AES.MODE_CBC @@ -178,8 +196,8 @@ def decrypt(self, text, appid): try: pad = ord(plain_text[-1]) # 去掉补位字符串 - #pkcs7 = PKCS7Encoder() - #plain_text = pkcs7.encode(plain_text) + # pkcs7 = PKCS7Encoder() + # plain_text = pkcs7.encode(plain_text) # 去除16位随机字符串 content = plain_text[16:-pad] xml_len = socket.ntohl(struct.unpack("I", content[:4])[0]) @@ -202,17 +220,21 @@ def get_random_str(self): class WXBizMsgCrypt(object): def __init__(self, sToken, sEncodingAESKey, sCorpId): + (sToken, sEncodingAESKey, sCorpId) = \ + map(str2bytes, (sToken, sEncodingAESKey, sCorpId)) try: self.key = base64.b64decode(sEncodingAESKey+"=") assert len(self.key) == 32 except: throw_exception("[error]: EncodingAESKey unvalid !", FormatException) - #return WXBizMsgCrypt_IllegalAesKey) + # return WXBizMsgCrypt_IllegalAesKey) self.m_sToken = sToken self.m_sCorpid = sCorpId def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): + (sMsgSignature, sTimeStamp, sNonce, sEchoStr) = \ + map(str2bytes, (sMsgSignature, sTimeStamp, sNonce, sEchoStr)) sha1 = SHA1() ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) @@ -225,6 +247,8 @@ def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): return ret, sReplyEchoStr def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): + (sReplyMsg, sNonce, timestamp) = \ + map(str2bytes, (sReplyMsg, sNonce, timestamp)) pc = Prpcrypt(self.key) ret, encrypt = pc.encrypt(sReplyMsg, self.m_sCorpid) if ret != 0: @@ -241,6 +265,8 @@ def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): return ret, xmlParse.generate(encrypt, signature, timestamp, sNonce) def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): + (sPostData, sMsgSignature, sTimeStamp, sNonce) = \ + map(str2bytes, (sPostData, sMsgSignature, sTimeStamp, sNonce)) xmlParse = XMLParse() ret, encrypt, touser_name = xmlParse.extract(sPostData) if ret != 0: diff --git a/wechat/enterprise.py b/wechat/enterprise.py index eb509ff..461b40d 100644 --- a/wechat/enterprise.py +++ b/wechat/enterprise.py @@ -1,4 +1,4 @@ -#encoding=utf-8 +# encoding=utf-8 import requests import time @@ -9,6 +9,8 @@ WxVideoResponse, WxNewsResponse, APIError, WxEmptyResponse from .official import WxApplication as BaseApplication, WxBaseApi from .crypt import WXBizMsgCrypt +import sys + __all__ = ['WxRequest', 'WxResponse', 'WxArticle', 'WxImage', 'WxVoice', 'WxVideo', 'WxLink', 'WxTextResponse', @@ -60,7 +62,7 @@ def process(self, params, xml=None, token=None, corp_id=None, self.pre_process() rsp = func(self.req) self.post_process() - result = rsp.as_xml().encode('UTF-8') + result = rsp.as_xml().encode('UTF-8') if not result: return '' @@ -114,7 +116,7 @@ def get_access_token(self, url=None, **kwargs): params.update(kwargs) rsp = requests.get(url or self.api_entry + 'cgi-bin/gettoken', params=params, - verify=False) + verify=WxBaseApi.VERIFY) return self._process_response(rsp) def departments(self): @@ -335,16 +337,12 @@ def delete_menu(self, agentid): # OAuth2 def authorize_url(self, appid, redirect_uri, response_type='code', scope='snsapi_base', state=None): - # 变态的微信实现,参数的顺序也有讲究。。艹!这个实现太恶心,太恶心! - url = 'https://open.weixin.qq.com/connect/oauth2/authorize?' - rd_uri = urllib.urlencode({'redirect_uri': redirect_uri}) - url += 'appid=%s&' % appid - url += rd_uri - url += '&response_type=' + response_type - url += '&scope=' + scope + params = dict(appid=appid, redirect_uri=redirect_uri, response_type=response_type, scope=scope) if state: - url += '&state=' + state - return url + '#wechat_redirect' + params['state'] = state + url = '?'.join(['https://open.weixin.qq.com/connect/oauth2/authorize', urllib.urlencode(sorted(params.items()))]) + url = '#'.join([url, 'wechat_redirect']) + return url def get_user_info(self, agentid, code): return self._get('cgi-bin/user/getuserinfo', diff --git a/wechat/models.py b/wechat/models.py index d2a2491..e9f54f6 100644 --- a/wechat/models.py +++ b/wechat/models.py @@ -1,8 +1,13 @@ -#encoding=utf-8 +# encoding=utf-8 from xml.dom import minidom import collections import time +import sys + +if sys.version > "3": + long = int + unicode = str def kv2element(key, value, doc): diff --git a/wechat/official.py b/wechat/official.py index 32caf6b..230dd32 100644 --- a/wechat/official.py +++ b/wechat/official.py @@ -1,18 +1,21 @@ # encoding=utf-8 -from hashlib import sha1 +from functools import wraps import requests import json import tempfile import shutil import os -from .crypt import WXBizMsgCrypt +from .crypt import WXBizMsgCrypt, SHA1 +import sys +from datetime import datetime, timedelta from .models import WxRequest, WxResponse from .models import WxMusic, WxArticle, WxImage, WxVoice, WxVideo, WxLink from .models import WxTextResponse, WxImageResponse, WxVoiceResponse,\ WxVideoResponse, WxMusicResponse, WxNewsResponse, APIError, WxEmptyResponse + __all__ = ['WxRequest', 'WxResponse', 'WxMusic', 'WxArticle', 'WxImage', 'WxVoice', 'WxVideo', 'WxLink', 'WxTextResponse', 'WxImageResponse', 'WxVoiceResponse', 'WxVideoResponse', @@ -28,15 +31,29 @@ class WxApplication(object): APP_ID = None ENCODING_AES_KEY = None + def __init__(self): + self.event_handlers = { + 'subscribe': self.on_subscribe, + 'unsubscribe': self.on_unsubscribe, + 'SCAN': self.on_scan, + 'LOCATION': self.on_location_update, + 'CLICK': self.on_click, + 'VIEW': self.on_view, + 'scancode_push': self.on_scancode_push, + 'scancode_waitmsg': self.on_scancode_waitmsg, + 'pic_sysphoto': self.on_pic_sysphoto, + 'pic_photo_or_album': self.on_pic_photo_or_album, + 'pic_weixin': self.on_pic_weixin, + 'location_select': self.on_location_select + } + def is_valid_params(self, params): timestamp = params.get('timestamp', '') nonce = params.get('nonce', '') signature = params.get('signature', '') echostr = params.get('echostr', '') - sign_ele = [self.token, timestamp, nonce] - sign_ele.sort() - if(signature == sha1(''.join(sign_ele)).hexdigest()): + if (signature == SHA1.getSignature(self.token, timestamp, nonce)): return True, echostr else: return None @@ -62,8 +79,7 @@ def process(self, params, xml=None, token=None, app_id=None, aes_key=None): timestamp = params.get('timestamp', '') nonce = params.get('nonce', '') if encrypt_type == 'aes': - cpt = WXBizMsgCrypt(self.token, - self.aes_key, self.app_id) + cpt = WXBizMsgCrypt(self.token, self.aes_key, self.app_id) err, xml = cpt.DecryptMsg(xml, msg_signature, timestamp, nonce) if err: return 'decrypt message error, code : %s' % err @@ -108,28 +124,14 @@ def on_video(self, video): def on_location(self, loc): return WxTextResponse(self.UNSUPPORT_TXT, loc) - def event_map(self): - if getattr(self, 'event_handlers', None): - return self.event_handlers - return { - 'subscribe': self.on_subscribe, - 'unsubscribe': self.on_unsubscribe, - 'SCAN': self.on_scan, - 'LOCATION': self.on_location_update, - 'CLICK': self.on_click, - 'VIEW': self.on_view, - 'scancode_push': self.on_scancode_push, - 'scancode_waitmsg': self.on_scancode_waitmsg, - 'pic_sysphoto': self.on_pic_sysphoto, - 'pic_photo_or_album': self.on_pic_photo_or_album, - 'pic_weixin': self.on_pic_weixin, - 'location_select': self.on_location_select, - } - def on_event(self, event): - func = self.event_map().get(event.Event, None) + func = self.event_handlers.get(event.Event, self.on_other_event) return func(event) + def on_other_event(self, event): + # Unhandled event + return WxEmptyResponse() + def on_subscribe(self, sub): return WxTextResponse(self.WELCOME_TXT, sub) @@ -186,29 +188,31 @@ def post_process(self, rsp=None): pass +def retry_token(fn): + def wrapper(self, *args, **kwargs): + content, err = fn(self, *args, **kwargs) + if not content and err and err.code in [40001, 40014, 42001]: + self.token_manager.refresh_token(self.get_access_token) + return fn(self, *args, **kwargs) + else: + return content, err + return wrapper + + class WxBaseApi(object): API_PREFIX = 'https://api.weixin.qq.com/cgi-bin/' + VERIFY = True - def __init__(self, appid, appsecret, api_entry=None): + def __init__(self, appid, appsecret, token_manager, api_entry=None): self.appid = appid self.appsecret = appsecret - self._access_token = None + self.token_manager = token_manager self.api_entry = api_entry or self.API_PREFIX @property def access_token(self): - if not self._access_token: - token, err = self.get_access_token() - if not err: - self._access_token = token['access_token'] - return self._access_token - else: - return None - return self._access_token - - def set_access_token(self, token): - self._access_token = token + return self.token_manager.get_token(self.get_access_token) def _process_response(self, rsp): if rsp.status_code != 200: @@ -221,14 +225,16 @@ def _process_response(self, rsp): return None, APIError(content['errcode'], content['errmsg']) return content, None + @retry_token def _get(self, path, params=None): if not params: params = {} params['access_token'] = self.access_token rsp = requests.get(self.api_entry + path, params=params, - verify=False) + verify=WxBaseApi.VERIFY) return self._process_response(rsp) + @retry_token def _post(self, path, data, ctype='json'): headers = {'Content-type': 'application/json'} path = self.api_entry + path @@ -238,7 +244,8 @@ def _post(self, path, data, ctype='json'): path += '?access_token=' + self.access_token if ctype == 'json': data = json.dumps(data, ensure_ascii=False).encode('utf-8') - rsp = requests.post(path, data=data, headers=headers, verify=False) + rsp = requests.post(path, data=data, headers=headers, + verify=WxBaseApi.VERIFY) return self._process_response(rsp) def upload_media(self, mtype, file_path=None, file_content=None, @@ -261,7 +268,7 @@ def upload_media(self, mtype, file_path=None, file_content=None, f.close() media = open(tmp_path, 'rb') rsp = requests.post(path, files={'media': media}, - verify=False) + verify=WxBaseApi.VERIFY) media.close() os.remove(tmp_path) return self._process_response(rsp) @@ -270,7 +277,7 @@ def download_media(self, media_id, to_path, url='media/get'): rsp = requests.get(self.api_entry + url, params={'media_id': media_id, 'access_token': self.access_token}, - verify=False) + verify=WxBaseApi.VERIFY) if rsp.status_code == 200: save_file = open(to_path, 'wb') save_file.write(rsp.content) @@ -309,7 +316,7 @@ def get_access_token(self, url=None, **kwargs): if kwargs: params.update(kwargs) rsp = requests.get(url or self.api_entry + 'token', params=params, - verify=False) + verify=WxBaseApi.VERIFY) return self._process_response(rsp) def user_info(self, user_id, lang='zh_CN'): @@ -383,6 +390,11 @@ def send_news(self, to_user, news): {'touser': to_user, 'msgtype': 'news', 'news': {'articles': news}}) + def send_template(self, to_user, template_id, url, data): + return self._post('message/template/send', + {'touser': to_user, 'template_id': template_id, + 'url': url, 'data': data}) + def create_group(self, name): return self._post('groups/create', {'group': {'name': name}}) diff --git a/wechat/token_manager.py b/wechat/token_manager.py new file mode 100644 index 0000000..fda27e0 --- /dev/null +++ b/wechat/token_manager.py @@ -0,0 +1,82 @@ +# encoding=utf-8 + +from time import time, sleep + +import redis + +import logging + +logger = logging.getLogger(__name__) + + +class TokenManager(object): + def get_token(self, fn_get_access_token): + token = self.token + expires = self.expires + if not token and not expires: + for i in xrang(12): + sleep(5) + if self.token: + break + elif not token or expires and expires < time(): + self.expires = None + self.refresh_token(fn_get_access_token) + return self.token + + def refresh_token(self, fn_get_access_token): + token, err = fn_get_access_token() + if token and not err: + self.token = token['access_token'] + self.expires = time() + token['expires_in'] + else: + self.token = None + + +class LocalTokenManager(TokenManager): + def __init__(self): + self._access_token = None + self._expires = time() + + @property + def token(self): + return self._access_token + + @token.setter + def token(self, token): + self._access_token = token + + @property + def expires(self): + return self._expires + + @expires.setter + def expires(self, expires): + self._expires = expires + + +class RedisTokenManager(TokenManager): + def __init__(self, postfix="", **kwargs): + self.token_name = "_".join(["access_token", postfix]) + self.expires_name = "_".join(["access_token_expires", postfix]) + self.redis = redis.Redis(**kwargs) + if not self.expires: + self.expires = time() + + @property + def token(self): + token = self.redis.get(self.token_name) + return str(token, "utf-8") if token and isinstance( + token, bytes) else token + + @token.setter + def token(self, token): + self.redis.set(self.token_name, token) + + @property + def expires(self): + expires = self.redis.get(self.expires_name) + return expires + + @expires.setter + def expires(self, expires): + self.redis.set(self.expires_name, expires)