main.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import asyncio
  2. import copy
  3. import datetime
  4. import gzip
  5. import os
  6. import pickle
  7. from time import time
  8. import pytz
  9. from tqdm import tqdm
  10. import utils.constants as constants
  11. from updates.epg import get_epg
  12. from updates.fofa import get_channels_by_fofa
  13. from updates.hotel import get_channels_by_hotel
  14. from updates.multicast import get_channels_by_multicast
  15. from updates.online_search import get_channels_by_online_search
  16. from updates.subscribe import get_channels_by_subscribe_urls
  17. from utils.channel import (
  18. get_channel_items,
  19. append_total_data,
  20. test_speed,
  21. write_channel_to_file, sort_channel_result,
  22. )
  23. from utils.config import config
  24. from utils.tools import (
  25. get_pbar_remaining,
  26. get_ip_address,
  27. process_nested_dict,
  28. format_interval,
  29. check_ipv6_support,
  30. get_urls_from_file,
  31. get_version_info,
  32. join_url,
  33. get_urls_len,
  34. merge_objects
  35. )
  36. from utils.types import CategoryChannelData
  37. class UpdateSource:
  38. def __init__(self):
  39. self.update_progress = None
  40. self.run_ui = False
  41. self.tasks = []
  42. self.channel_items: CategoryChannelData = {}
  43. self.hotel_fofa_result = {}
  44. self.hotel_foodie_result = {}
  45. self.multicast_result = {}
  46. self.subscribe_result = {}
  47. self.online_search_result = {}
  48. self.epg_result = {}
  49. self.channel_data: CategoryChannelData = {}
  50. self.pbar = None
  51. self.total = 0
  52. self.start_time = None
  53. self.stop_event = None
  54. self.ipv6_support = False
  55. self.now = None
  56. async def visit_page(self, channel_names: list[str] = None):
  57. tasks_config = [
  58. ("hotel_fofa", get_channels_by_fofa, "hotel_fofa_result"),
  59. ("multicast", get_channels_by_multicast, "multicast_result"),
  60. ("hotel_foodie", get_channels_by_hotel, "hotel_foodie_result"),
  61. ("subscribe", get_channels_by_subscribe_urls, "subscribe_result"),
  62. (
  63. "online_search",
  64. get_channels_by_online_search,
  65. "online_search_result",
  66. ),
  67. ("epg", get_epg, "epg_result"),
  68. ]
  69. for setting, task_func, result_attr in tasks_config:
  70. if (
  71. setting == "hotel_foodie" or setting == "hotel_fofa"
  72. ) and config.open_hotel == False:
  73. continue
  74. if config.open_method[setting]:
  75. if setting == "subscribe":
  76. subscribe_urls = get_urls_from_file(constants.subscribe_path)
  77. whitelist_urls = get_urls_from_file(constants.whitelist_path)
  78. if not os.getenv("GITHUB_ACTIONS") and config.cdn_url:
  79. subscribe_urls = [join_url(config.cdn_url, url) if "raw.githubusercontent.com" in url else url
  80. for url in subscribe_urls]
  81. task = asyncio.create_task(
  82. task_func(subscribe_urls,
  83. names=channel_names,
  84. whitelist=whitelist_urls,
  85. callback=self.update_progress
  86. )
  87. )
  88. elif setting == "hotel_foodie" or setting == "hotel_fofa":
  89. task = asyncio.create_task(task_func(callback=self.update_progress))
  90. else:
  91. task = asyncio.create_task(
  92. task_func(channel_names, callback=self.update_progress)
  93. )
  94. self.tasks.append(task)
  95. setattr(self, result_attr, await task)
  96. def pbar_update(self, name: str = "", item_name: str = ""):
  97. if self.pbar.n < self.total:
  98. self.pbar.update()
  99. self.update_progress(
  100. f"正在进行{name}, 剩余{self.total - self.pbar.n}个{item_name}, 预计剩余时间: {get_pbar_remaining(n=self.pbar.n, total=self.total, start_time=self.start_time)}",
  101. int((self.pbar.n / self.total) * 100),
  102. )
  103. async def main(self):
  104. try:
  105. main_start_time = time()
  106. if config.open_update:
  107. self.channel_items = get_channel_items()
  108. channel_names = [
  109. name
  110. for channel_obj in self.channel_items.values()
  111. for name in channel_obj.keys()
  112. ]
  113. if not channel_names:
  114. print(f"❌ No channel names found! Please check the {config.source_file}!")
  115. return
  116. await self.visit_page(channel_names)
  117. self.tasks = []
  118. append_total_data(
  119. self.channel_items.items(),
  120. self.channel_data,
  121. self.hotel_fofa_result,
  122. self.multicast_result,
  123. self.hotel_foodie_result,
  124. self.subscribe_result,
  125. self.online_search_result,
  126. )
  127. cache_result = self.channel_data
  128. test_result = {}
  129. if config.open_speed_test:
  130. urls_total = get_urls_len(self.channel_data)
  131. test_data = copy.deepcopy(self.channel_data)
  132. process_nested_dict(
  133. test_data,
  134. seen=set(),
  135. filter_host=config.speed_test_filter_host,
  136. ipv6_support=self.ipv6_support
  137. )
  138. self.total = get_urls_len(test_data)
  139. print(f"Total urls: {urls_total}, need to test speed: {self.total}")
  140. self.update_progress(
  141. f"正在进行测速, 共{urls_total}个接口, {self.total}个接口需要进行测速",
  142. 0,
  143. )
  144. self.start_time = time()
  145. self.pbar = tqdm(total=self.total, desc="Speed test")
  146. test_result = await test_speed(
  147. test_data,
  148. ipv6=self.ipv6_support,
  149. callback=lambda: self.pbar_update(name="测速", item_name="接口"),
  150. )
  151. cache_result = merge_objects(cache_result, test_result, match_key="url")
  152. self.pbar.close()
  153. self.channel_data = sort_channel_result(
  154. self.channel_data,
  155. result=test_result,
  156. filter_host=config.speed_test_filter_host,
  157. ipv6_support=self.ipv6_support
  158. )
  159. self.update_progress(f"正在生成结果文件", 0)
  160. write_channel_to_file(
  161. self.channel_data,
  162. epg=self.epg_result,
  163. ipv6=self.ipv6_support,
  164. first_channel_name=channel_names[0],
  165. )
  166. if config.open_history:
  167. if os.path.exists(constants.cache_path):
  168. with gzip.open(constants.cache_path, "rb") as file:
  169. try:
  170. cache = pickle.load(file)
  171. except EOFError:
  172. cache = {}
  173. cache_result = merge_objects(cache, cache_result, match_key="url")
  174. with gzip.open(constants.cache_path, "wb") as file:
  175. pickle.dump(cache_result, file)
  176. print(
  177. f"🥳 Update completed! Total time spent: {format_interval(time() - main_start_time)}."
  178. )
  179. if self.run_ui:
  180. open_service = config.open_service
  181. service_tip = ", 可使用以下地址进行观看" if open_service else ""
  182. tip = (
  183. f"✅ 服务启动成功{service_tip}"
  184. if open_service and config.open_update == False
  185. else f"🥳更新完成, 耗时: {format_interval(time() - main_start_time)}{service_tip}"
  186. )
  187. self.update_progress(
  188. tip,
  189. 100,
  190. finished=True,
  191. url=f"{get_ip_address()}" if open_service else None,
  192. now=self.now
  193. )
  194. except asyncio.exceptions.CancelledError:
  195. print("Update cancelled!")
  196. async def start(self, callback=None):
  197. def default_callback(self, *args, **kwargs):
  198. pass
  199. self.update_progress = callback or default_callback
  200. self.run_ui = True if callback else False
  201. if self.run_ui:
  202. self.update_progress(f"正在检查网络是否支持IPv6", 0)
  203. self.ipv6_support = config.ipv6_support or check_ipv6_support()
  204. if not os.getenv("GITHUB_ACTIONS") and config.update_interval:
  205. await self.scheduler(asyncio.Event())
  206. else:
  207. await self.main()
  208. def stop(self):
  209. for task in self.tasks:
  210. task.cancel()
  211. self.tasks = []
  212. if self.pbar:
  213. self.pbar.close()
  214. if self.stop_event:
  215. self.stop_event.set()
  216. async def scheduler(self, stop_event):
  217. self.stop_event = stop_event
  218. while not stop_event.is_set():
  219. self.now = datetime.datetime.now(pytz.timezone(config.time_zone))
  220. await self.main()
  221. next_time = self.now + datetime.timedelta(hours=config.update_interval)
  222. print(f"🕒 Next update time: {next_time:%Y-%m-%d %H:%M:%S}")
  223. try:
  224. await asyncio.wait_for(stop_event.wait(), timeout=config.update_interval * 3600)
  225. except asyncio.TimeoutError:
  226. continue
  227. if __name__ == "__main__":
  228. info = get_version_info()
  229. print(f"✡️ {info['name']} Version: {info['version']}")
  230. loop = asyncio.new_event_loop()
  231. asyncio.set_event_loop(loop)
  232. update_source = UpdateSource()
  233. loop.run_until_complete(update_source.start())