wxgf.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. from websocket import create_connection
  2. import os
  3. import logging
  4. import shutil
  5. import subprocess
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Magic bytes that a WXGF payload must start with.
WXGF_HEADER = b'wxgf'
# Reply from the decoding server that is treated as "decoding failed".
FAILURE_MESSAGE = b'FAILED'
# Annex-B start codes (4-byte and 3-byte forms) used to locate the HEVC
# bitstream inside a WXGF container.
_HEVC_START_CODE_4 = b"\x00\x00\x00\x01"
_HEVC_START_CODE_3 = b"\x00\x00\x01"
  11. def extract_hevc_bitstream_from_wxgf(data: bytes) -> bytes | None:
  12. """Extract Annex-B HEVC bitstream from WXGF container.
  13. Returns:
  14. HEVC bitstream bytes starting with a start-code, or None if unknown format.
  15. """
  16. if not data.startswith(WXGF_HEADER):
  17. return None
  18. start = data.find(_HEVC_START_CODE_4)
  19. if start < 0:
  20. start = data.find(_HEVC_START_CODE_3)
  21. if start < 0:
  22. return None
  23. return data[start:]
  24. def _subprocess_run_bytes(cmd: list[str], *, stdin: bytes) -> bytes | None:
  25. try:
  26. p = subprocess.run(
  27. cmd,
  28. input=stdin,
  29. stdout=subprocess.PIPE,
  30. stderr=subprocess.PIPE,
  31. check=False,
  32. )
  33. except FileNotFoundError:
  34. return None
  35. if p.returncode != 0:
  36. logger.debug(
  37. "Command failed (%s): rc=%d stderr=%s",
  38. " ".join(cmd),
  39. p.returncode,
  40. p.stderr[:2000].decode("utf-8", errors="replace"),
  41. )
  42. return None
  43. return p.stdout
  44. def _ffprobe_count_frames_hevc(hevc: bytes, *, ffprobe: str = "ffprobe") -> int | None:
  45. out = _subprocess_run_bytes(
  46. [
  47. ffprobe,
  48. "-v",
  49. "error",
  50. "-count_frames",
  51. "-select_streams",
  52. "v:0",
  53. "-show_entries",
  54. "stream=nb_read_frames",
  55. "-of",
  56. "default=nw=1:nk=1",
  57. "-f",
  58. "hevc",
  59. "-i",
  60. "pipe:0",
  61. ],
  62. stdin=hevc,
  63. )
  64. if out is None:
  65. return None
  66. try:
  67. return int(out.strip().splitlines()[-1])
  68. except Exception:
  69. return None
  70. def decode_wxgf_with_ffmpeg(
  71. data: bytes,
  72. *,
  73. ffmpeg: str = "ffmpeg",
  74. ffprobe: str = "ffprobe",
  75. ) -> bytes | None:
  76. """Decode WXGF into a standard image/animation using ffmpeg.
  77. Args:
  78. ffmpeg, ffprobe: path to ffmpeg and ffprobe executables.
  79. Returns:
  80. - PNG bytes for 1-frame WXGF
  81. - GIF bytes for multi-frame WXGF
  82. - None if decoding fails or ffmpeg/ffprobe is unavailable.
  83. """
  84. if shutil.which(ffmpeg) is None or shutil.which(ffprobe) is None:
  85. return None
  86. hevc = extract_hevc_bitstream_from_wxgf(data)
  87. if hevc is None:
  88. return None
  89. frames = _ffprobe_count_frames_hevc(hevc, ffprobe=ffprobe)
  90. if frames is not None and frames > 1:
  91. # Use palettegen/paletteuse for higher-quality gifs.
  92. out = _subprocess_run_bytes(
  93. [
  94. ffmpeg,
  95. "-hide_banner",
  96. "-loglevel",
  97. "error",
  98. "-f",
  99. "hevc",
  100. "-i",
  101. "pipe:0",
  102. "-filter_complex",
  103. "[0:v]split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse",
  104. "-loop",
  105. "0",
  106. "-f",
  107. "gif",
  108. "-",
  109. ],
  110. stdin=hevc,
  111. )
  112. if out is not None:
  113. return out
  114. # Default: decode the first frame to PNG (keeps quality and alpha).
  115. return _subprocess_run_bytes(
  116. [
  117. ffmpeg,
  118. "-hide_banner",
  119. "-loglevel",
  120. "error",
  121. "-f",
  122. "hevc",
  123. "-i",
  124. "pipe:0",
  125. "-frames:v",
  126. "1",
  127. "-f",
  128. "image2pipe",
  129. "-vcodec",
  130. "png",
  131. "-",
  132. ],
  133. stdin=hevc,
  134. )
  135. class WxgfDecoder:
  136. def __init__(self, server: str | None):
  137. """server: hostname:port"""
  138. if server is not None:
  139. if "://" not in server:
  140. server = "ws://" + server
  141. logger.info(f"Connecting to {server} ...")
  142. self.server = server
  143. self.ws = create_connection(server)
  144. def __del__(self):
  145. if self.has_server():
  146. self.ws.close()
  147. def has_server(self) -> bool:
  148. return hasattr(self, 'ws')
  149. def decode_with_server(self, data: bytes) -> bytes | None:
  150. assert data[:4] == WXGF_HEADER, data[:20]
  151. try:
  152. self.ws.send(data, opcode=0x2)
  153. except BrokenPipeError as e:
  154. logger.warning(f'Failed to send data to wxgf service. {e}. Reconnecting ..')
  155. self.ws = create_connection(self.server)
  156. self.ws.send(data, opcode=0x2)
  157. try:
  158. res = self.ws.recv()
  159. except Exception as e:
  160. logger.warning(f'Failed to recv data to wxgf service. {e}. Reconnecting ..')
  161. self.ws = create_connection(self.server)
  162. self.ws.send(data, opcode=0x2)
  163. res = self.ws.recv()
  164. if res == FAILURE_MESSAGE:
  165. return None
  166. return res
  167. def decode_with_cache(self, fname: str, data: bytes | None) -> bytes | None:
  168. """Decode and save cache.
  169. Args:
  170. fname: original file path. cache will be saved alongside.
  171. data: data to decode. None to use content of fname.
  172. """
  173. if data is None:
  174. with open(fname, 'rb') as f:
  175. data = f.read()
  176. out_fname = os.path.splitext(fname)[0] + '.dec'
  177. if os.path.exists(out_fname):
  178. with open(out_fname, 'rb') as f:
  179. return f.read()
  180. # Prefer host-side decoding via ffmpeg to avoid Android dependencies.
  181. res = decode_wxgf_with_ffmpeg(data)
  182. if res is None and self.has_server():
  183. res = self.decode_with_server(data)
  184. if res is not None:
  185. with open(out_fname, 'wb') as f:
  186. f.write(res)
  187. return res
  188. def is_wxgf_file(fname):
  189. with open(fname, 'rb') as f:
  190. return f.read(4) == WXGF_HEADER
  191. def is_wxgf_buffer(buf: bytes):
  192. return buf[:4] == WXGF_HEADER