Przeglądaj źródła

remove ensure_unicode

Yuxin Wu 1 rok temu
rodzic
commit
e1a1b0b946
7 zmienionych plików z 58 dodań i 67 usunięć
  1. 1 3
      dump-html.py
  2. 31 31
      plot-num-msg-by-time.py
  3. 1 2
      wechat/avatar.py
  4. 1 9
      wechat/common/textutil.py
  5. 2 4
      wechat/msg.py
  6. 11 10
      wechat/parser.py
  7. 11 8
      wechat/render.py

+ 1 - 3
dump-html.py

@@ -6,7 +6,6 @@ import logging
 
 from wechat.parser import WeChatDBParser
 from wechat.res import Resource
-from wechat.common.textutil import ensure_unicode
 from wechat.render import HTMLRender
 
 logger = logging.getLogger("wechat")
@@ -26,7 +25,6 @@ def get_args():
 if __name__ == '__main__':
     args = get_args()
 
-    name = ensure_unicode(args.name)
     output_file = args.output
 
     parser = WeChatDBParser(args.db)
@@ -36,7 +34,7 @@ if __name__ == '__main__':
     except KeyError:
         sys.stderr.write(u"Valid Contacts: {}\n".format(
             u'\n'.join(parser.all_chat_nicknames)))
-        sys.stderr.write(u"Couldn't find the chat {}.".format(name));
+        sys.stderr.write(u"Couldn't find the chat {}.".format(args.name));
         sys.exit(1)
 
     res = Resource(parser, args.res,

+ 31 - 31
plot-num-msg-by-time.py

@@ -2,45 +2,45 @@
 # -*- coding: UTF-8 -*-
 
 from wechat.parser import WeChatDBParser
-from wechat.common.textutil import ensure_unicode
 
 from datetime import timedelta, datetime
-import numpy as np
 import matplotlib.pyplot as plt
-import sys, os
-
-if len(sys.argv) != 3:
-    sys.exit("Usage: {0} <path to decoded_database.db> <name>".format(sys.argv[0]))
-
-db_file = sys.argv[1]
-name = ensure_unicode(sys.argv[2])
-every_k_days = 2
-
-parser = WeChatDBParser(db_file)
-msgs = parser.msgs_by_chat[name]
-times = [x.createTime for x in msgs]
-start_time = times[0]
-diffs = [(x - start_time).days for x in times]
-max_day = diffs[-1]
-
-width = 20
-numbers = range((max_day / width + 1) * width + 1)[::width]
-labels = [(start_time + timedelta(x)).strftime("%m/%d") for x in numbers]
-plt.xticks(numbers, labels)
-plt.xlabel("Date")
-plt.ylabel("Number of msgs in k days")
-plt.hist(diffs, bins=max_day / every_k_days)
-plt.show()
+import sys
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 3:
+        sys.exit("Usage: {0} <path to decoded_database.db> <name>".format(sys.argv[0]))
+
+    db_file = sys.argv[1]
+    name = sys.argv[2]
+    every_k_days = 2
+
+    parser = WeChatDBParser(db_file)
+    msgs = parser.msgs_by_chat[name]
+    times = [x.createTime for x in msgs]
+    start_time = times[0]
+    diffs = [(x - start_time).days for x in times]
+    max_day = diffs[-1]
+
+    width = 20
+    numbers = range((max_day / width + 1) * width + 1)[::width]
+    labels = [(start_time + timedelta(x)).strftime("%m/%d") for x in numbers]
+    plt.xticks(numbers, labels)
+    plt.xlabel("Date")
+    plt.ylabel("Number of msgs in k days")
+    plt.hist(diffs, bins=max_day / every_k_days)
+    plt.show()
 
 # statistics by hour
 # I'm in a different time zone in this period:
 #TZ_DELTA = {(datetime(2014, 7, 13), datetime(2014, 10, 1)): -15}
 #def real_hour(x):
-    #for k, v in TZ_DELTA.items():
-        #if x > k[0] and x < k[1]:
-            #print x
-            #return (x.hour + v + 24) % 24
-    #return x.hour
+        #for k, v in TZ_DELTA.items():
+            #if x > k[0] and x < k[1]:
+                #print x
+                #return (x.hour + v + 24) % 24
+        #return x.hour
 #hours = [real_hour(x) for x in times]
 #plt.ylabel("Number of msgs")
 #plt.xlabel("Hour in a day")

+ 1 - 2
wechat/avatar.py

@@ -10,7 +10,7 @@ import logging
 import sqlite3
 logger = logging.getLogger(__name__)
 
-from .common.textutil import ensure_unicode, md5
+from .common.textutil import md5
 
 
 def _filename_priority(s):
@@ -72,7 +72,6 @@ class AvatarReader(object):
         """ username: `username` field in db.rcontact"""
         if not self._use_avt:
             return None
-        username = ensure_unicode(username)
         avtid = md5(username.encode('utf-8'))
 
         if self.avt_db is not None:

+ 1 - 9
wechat/common/textutil.py

@@ -3,13 +3,6 @@
 import hashlib
 import base64
 
-def ensure_unicode(s):
-    if type(s) == str:
-        return s
-    elif type(s) == bytes:
-        return s.decode('utf-8')
-    raise TypeError(f"type of string is {type(s)}")
-
 
 def md5(s):
     m = hashlib.md5()
@@ -25,6 +18,5 @@ def get_file_md5(fname):
         return md5(f.read())
 
 def safe_filename(fname):
-    filename = ensure_unicode(fname)
     return "".join(
-        [c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
+        [c for c in fname if c.isalpha() or c.isdigit() or c ==' ']).rstrip()

+ 2 - 4
wechat/msg.py

@@ -30,8 +30,6 @@ import xml.etree.ElementTree as ET
 import logging
 logger = logging.getLogger(__name__)
 
-from .common.textutil import ensure_unicode
-
 
 class WeChatMsg(object):
 
@@ -150,9 +148,9 @@ class WeChatMsg(object):
             self.type,
             self.talker_nickname if not self.isSend else 'me',
             self.createTime,
-            ensure_unicode(self.msg_str()))
+            self.msg_str())
         if self.imgPath:
-            ret = "{}|img:{}".format(ensure_unicode(ret.strip()), self.imgPath)
+            ret = "{}|img:{}".format(ret.strip(), self.imgPath)
             return ret
         else:
             return ret

+ 11 - 10
wechat/parser.py

@@ -8,7 +8,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 from .msg import WeChatMsg, TYPE_SYSTEM
-from .common.textutil import ensure_unicode
 
 """ tables in concern:
 emojiinfo
@@ -47,9 +46,9 @@ SELECT username,conRemark,nickname FROM rcontact
         for row in contacts:
             username, remark, nickname = row
             if remark:
-                self.contacts[username] = ensure_unicode(remark)
+                self.contacts[username] = remark
             else:
-                self.contacts[username] = ensure_unicode(nickname)
+                self.contacts[username] = nickname
 
         for k, v in self.contacts.items():
             self.contacts_rev[v].append(k)
@@ -147,16 +146,18 @@ SELECT {} FROM message
 
     # process the values in a row
     def _parse_msg_row(self, row):
-        """ parse a record of message into my format"""
+        """Parse a record of message into my format.
+
+        Note that messages are read in binary format.
+        """
         values = dict(zip(WeChatDBParser.FIELDS, row))
         values['createTime'] = datetime.fromtimestamp(values['createTime']/ 1000)
-        try:
-            values['content'].decode()
-        except:
-            logger.warning(f"Invalid byte sequence in message content (type={values['type']}, createTime={values['createTime']})")
-            values['content'] = 'FAILED TO DECODE'
         if values['content']:
-            values['content'] = ensure_unicode(values['content'])
+            try:
+                values['content'] = values['content'].decode()
+            except:
+                logger.warning(f"Invalid byte sequence in message content (type={values['type']}, createTime={values['createTime']})")
+                values['content'] = 'FAILED TO DECODE'
         else:
             values['content'] = ''
 

+ 11 - 8
wechat/render.py

@@ -21,7 +21,7 @@ except ImportError:
     css_compress = lambda x: x
 
 from .msg import *
-from .common.textutil import ensure_unicode, get_file_b64
+from .common.textutil import get_file_b64
 from .common.progress import ProgressReporter
 from .common.timer import timing
 from .smiley import SmileyProvider
@@ -52,8 +52,10 @@ def get_template(name: str | int) -> str | None:
 
 class HTMLRender(object):
     def __init__(self, parser, res=None):
-        self.html = ensure_unicode(open(HTML_FILE).read())
-        self.time_html = open(TIME_HTML_FILE).read()
+        with open(HTML_FILE) as f:
+            self.html = f.read()
+        with open(TIME_HTML_FILE) as f:
+            self.time_html = f.read()
         self.parser = parser
         self.res = res
         assert self.res is not None, \
@@ -64,8 +66,8 @@ class HTMLRender(object):
         self.css_string = []    # css to add
         for css in css_files:
             logger.info("Loading {}".format(os.path.basename(css)))
-            css = ensure_unicode((open(css).read()))
-            self.css_string.append(css)
+            with open(css) as f:
+                self.css_string.append(f.read())
 
         js_files = glob.glob(os.path.join(LIB_PATH, 'static/*.js'))
         # to load jquery before other js
@@ -73,8 +75,8 @@ class HTMLRender(object):
         self.js_string = []
         for js in js_files:
             logger.info("Loading {}".format(os.path.basename(js)))
-            js = ensure_unicode(open(js).read())
-            self.js_string.append(js)
+            with open(js) as f:
+                self.js_string.append(f.read())
 
         self.unknown_type_cnt = Counter()
 
@@ -229,7 +231,8 @@ class HTMLRender(object):
                            )
 
     def prepare_avatar_css(self, talkers):
-        avatar_tpl= ensure_unicode(open(FRIEND_AVATAR_CSS_FILE).read())
+        with open(FRIEND_AVATAR_CSS_FILE) as f:
+            avatar_tpl = f.read()
         my_avatar = self.res.get_avatar(self.parser.username)
         css = avatar_tpl.format(name='me', avatar=my_avatar)