markov_chains.lua 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. --[[
  2. Copyright 2019 NetherEran.
  3. This program is licensed under the MIT license
  4. This provides a markov chain object. This can be used to create
  5. themed nonsense based on input strings.
  6. How it works:
-cut input strings into word lists by splitting at whitespace
--put technical START and END markers at the start and the end of each list
-store the probability of each word following another word
-to create new sentences, start at the START marker and randomly
walk according to the stored probabilities until the END marker
or the maximum sentence length is reached
  13. ]]
  14. local MAX_SENTENCE_LENGTH = 30 --in words
  15. --we use numeric indices for start and end nodes so input (which is
  16. --strings) can't mess with it
  17. local START = 1
  18. local END = 2
  19. --node related stuff
  20. local node_prototype = {}
  21. local function link(node, to_link)
  22. node.link_count = node.link_count + 1
  23. node.links[to_link] = (node.links[to_link] or 0) + 1
  24. end
  25. --returns a node object that stores what it links, where it's linked
  26. --from and the weight of its links
  27. local function MarkovNode()
  28. local prototype = {}
  29. prototype.links = {} --word -> weight
  30. prototype.linked_by = {} --word array
  31. prototype.link_count = 0
  32. for k, v in pairs(node_prototype)
  33. do
  34. prototype[k] = v
  35. end
  36. return prototype
  37. end
  38. --graph related stuff
  39. local graph_prototype = {}
  40. --links two words or increases the weight if they're already linked
  41. function graph_prototype:link(word1, word2)
  42. link(self.nodes[word1], word2)
  43. local reverse = self.nodes[word2].linked_by
  44. for _, name in ipairs(reverse)
  45. do
  46. if name == word1
  47. then
  48. return
  49. end
  50. end
  51. table.insert(reverse, word1)
  52. end
  53. --splits a string at its space characters
  54. local function cutup(str)
  55. local pieces = {}
  56. pieces[0] = START
  57. local i = 1
  58. for k in string.gmatch(str, "([^%s]+)")
  59. do
  60. pieces[i] = k
  61. i = i + 1
  62. end
  63. pieces[i] = END
  64. if pieces[1] == END
  65. then
  66. return
  67. else
  68. return pieces
  69. end
  70. end
  71. --cuts a line into words and puts them into the graph as nodes and does weight stuff
  72. function graph_prototype:learn_line(line)
  73. local input = cutup(line)
  74. if not input
  75. then
  76. return
  77. end
  78. for i, v in ipairs(input)
  79. do
  80. self:link(input[i - 1], v)
  81. end
  82. end
  83. --read a file and treat each line as a line to learn from
  84. function graph_prototype:learn_from_file(filename)
  85. for line in io.lines(filename)
  86. do
  87. self:learn_line(line)
  88. end
  89. end
  90. --randomly pick an order of words based on the links and weights
  91. function graph_prototype:randomwalk()
  92. local path = {}
  93. local current = START
  94. local sentence_length = 0
  95. while current ~= END and sentence_length < MAX_SENTENCE_LENGTH
  96. do
  97. sentence_length = sentence_length + 1
  98. local node = self.nodes[current]
  99. local rand = math.random(node.link_count + 1)
  100. for k, v in pairs(node.links)
  101. do
  102. rand = rand - v
  103. if rand <= 0
  104. then
  105. current = k
  106. table.insert(path, k)
  107. break
  108. end
  109. end
  110. end
  111. path[#path] = nil
  112. return path
  113. end
  114. --returns a sentense based on a randomwalk
  115. function graph_prototype:get_sentence()
  116. local path = self:randomwalk()
  117. return table.concat(path, " ")
  118. end
  119. --counts the nodes of a graph without START and END
  120. function graph_prototype:count_known_words()
  121. local count = 0
  122. for name, _ in pairs(self.nodes)
  123. do
  124. if not(name == START or name == END)
  125. then
  126. count = count + 1
  127. end
  128. end
  129. return count
  130. end
  131. --checks if a table contains a value
  132. local function contains(table, value)
  133. for k, v in pairs(table)
  134. do
  135. if v == value
  136. then
  137. return true
  138. end
  139. end
  140. return false
  141. end
  142. --returns dead nodes. A node is dead when it can not be reached from
  143. --START or END can't be reached from it
  144. local function get_dead_nodes(graph)
  145. --calculate the distance to each node from START
  146. local sdistances = {}
  147. sdistances[START] = 0
  148. local to_visit = {START}
  149. while #to_visit > 0
  150. do
  151. local current = table.remove(to_visit)
  152. local current_distance = sdistances[current] + 1
  153. for node, _ in pairs(graph.nodes[current].links)
  154. do
  155. if current_distance < (sdistances[node] or math.huge)
  156. then
  157. sdistances[node] = current_distance
  158. if not contains(to_visit, node)
  159. then
  160. table.insert(to_visit, node)
  161. end
  162. end
  163. end
  164. end
  165. --calculate the distance to each node from END (same as above but backwards)
  166. local edistances = {}
  167. edistances[END] = 0
  168. table.insert(to_visit, END)
  169. while #to_visit > 0
  170. do
  171. local current = table.remove(to_visit)
  172. local current_distance = edistances[current] + 1
  173. for _, node in pairs(graph.nodes[current].linked_by)
  174. do
  175. if current_distance < (edistances[node] or math.huge)
  176. then
  177. edistances[node] = current_distance
  178. if not contains(to_visit, node)
  179. then
  180. table.insert(to_visit, node)
  181. end
  182. end
  183. end
  184. end
  185. --a node is dead when if and only if it has no distance to START or END
  186. local dead = {}
  187. for name, node in pairs(graph.nodes)
  188. do
  189. if not (edistances[name] and sdistances[name])
  190. then
  191. table.insert(dead, name)
  192. end
  193. end
  194. return dead
  195. end
  196. --removes a node and the links to it
  197. local function remove(graph, to_remove)
  198. --don't remove start or end
  199. if to_remove == START or to_remove == END
  200. then
  201. return
  202. end
  203. --remove links to to_remove
  204. for i, name in ipairs(graph.nodes[to_remove].linked_by)
  205. do
  206. local node = graph.nodes[name]
  207. node.link_count = node.link_count - node.links[to_remove]
  208. node.links[to_remove] = nil
  209. end
  210. --remove notes that to_remove links nodes
  211. for name, _ in pairs(graph.nodes[to_remove].links)
  212. do
  213. local node = graph.nodes[name]
  214. for i, v in ipairs(node.linked_by)
  215. do
  216. if v == to_remove
  217. then
  218. table.remove(node.linked_by, i)
  219. break
  220. end
  221. end
  222. end
  223. --actually remove node
  224. graph.nodes[to_remove] = nil
  225. end
  226. --removes a node, and the nodes that become dead by removing it
  227. function graph_prototype:remove_node(to_remove)
  228. remove(self, to_remove)
  229. local dead = get_dead_nodes(self)
  230. for i, name in ipairs(dead)
  231. do
  232. remove(self, name)
  233. end
  234. end
  235. --returns the node that is has appeared the least often
  236. local function get_least_used(nodes)
  237. local least_uses = math.huge
  238. local least_used = nil
  239. for name, node in pairs(nodes)
  240. do
  241. if node.link_count < least_uses and
  242. not(name == END or name == START)
  243. then
  244. least_uses = node.link_count
  245. least_used = name
  246. end
  247. end
  248. return least_used
  249. end
  250. --cuts the amount of known words to 'to' words maximum.
  251. function graph_prototype:cull(to)
  252. local count = self:count_known_words()
  253. while count > to
  254. do
  255. local least_used = get_least_used(self.nodes)
  256. if least_used
  257. then
  258. self:remove_node(least_used)
  259. end
  260. count = self:count_known_words()
  261. end
  262. end
  263. local nodes_meta = {}
  264. --create new MarkovNode when indexed
  265. function nodes_meta.__index(self, key)
  266. self[key] = MarkovNode()
  267. return self[key]
  268. end
  269. --returns a MarkovGraph object
  270. local function MarkovGraph()
  271. local prototype = {}
  272. prototype.nodes = {} --word -> MarkovNode
  273. setmetatable(prototype.nodes, nodes_meta)
  274. for k, v in pairs(graph_prototype)
  275. do
  276. prototype[k] = v
  277. end
  278. return prototype
  279. end
--the module's only export is the MarkovGraph constructor
return MarkovGraph