foreign.scm 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. ;;; guile-openai --- An OpenAI API client for Guile
  2. ;;; Copyright © 2023 Andrew Whatson <whatson@tailcall.au>
  3. ;;;
  4. ;;; This file is part of guile-openai.
  5. ;;;
  6. ;;; guile-openai is free software: you can redistribute it and/or modify
  7. ;;; it under the terms of the GNU Affero General Public License as
  8. ;;; published by the Free Software Foundation, either version 3 of the
  9. ;;; License, or (at your option) any later version.
  10. ;;;
  11. ;;; guile-openai is distributed in the hope that it will be useful, but
  12. ;;; WITHOUT ANY WARRANTY; without even the implied warranty of
  13. ;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. ;;; Affero General Public License for more details.
  15. ;;;
  16. ;;; You should have received a copy of the GNU Affero General Public
  17. ;;; License along with guile-openai. If not, see
  18. ;;; <https://www.gnu.org/licenses/>.
  19. (define-module (openai utils foreign)
  20. #:use-module (ice-9 match)
  21. #:use-module (ice-9 format)
  22. #:use-module (ice-9 vlist)
  23. #:use-module (srfi srfi-1)
  24. #:use-module (srfi srfi-9)
  25. #:use-module (srfi srfi-9 gnu)
  26. #:use-module ((system foreign) #:prefix ffi:)
  27. #:use-module ((system foreign) #:select (define-wrapped-pointer-type))
  28. #:use-module (system foreign-library)
  29. #:export (c-type?
  30. c-type-name
  31. c-type-size
  32. int8 uint8 int16 uint16 int32 uint32 int64 uint64
  33. float double complex-double complex-float
  34. int unsigned-int long unsigned-long short unsigned-short
  35. size_t ssize_t ptrdiff_t intptr_t uintptr_t
  36. void pointer cstring bool
  37. define-foreign-type
  38. define-foreign-arg-type
  39. define-foreign-return-type
  40. define-foreign-enum-type
  41. define-foreign-pointer-type
  42. define-foreign-library
  43. define-foreign-function
  44. define-foreign-functions))
  45. ;;; C type marshalling
;; Record describing a C type as seen from Scheme.  Fields, in the order
;; taken by the raw constructor `%make-c-type':
;;   name      - read back with `c-type-name'
;;   repr      - read back with `c-type-repr'; elsewhere in this file it is
;;               handed to `ffi:sizeof' and `foreign-library-function'
;;   wrapper   - procedure applied to a raw foreign result followed by the
;;               original call arguments; #f when `make-c-type' was given
;;               no #:wrap-result
;;   unwrapper - procedure applied to one Scheme argument before the
;;               foreign call; #f when `make-c-type' was given no
;;               #:unwrap-args
(define-record-type <c-type>
  (%make-c-type name repr wrapper unwrapper)
  c-type?
  (name c-type-name)
  (repr c-type-repr)
  (wrapper c-type-wrapper)
  (unwrapper c-type-unwrapper))
  53. (define* (make-c-type name repr #:key wrap-result unwrap-args)
  54. (%make-c-type name repr wrap-result unwrap-args))
  55. (define* (print-c-type type #:optional port)
  56. (format port "#<c-type ~a ~a>"
  57. (c-type-name type)
  58. (c-type-name (get-base-type (c-type-repr type)))))
  59. (define (c-type-size type)
  60. (ffi:sizeof (c-type-repr type)))
;; Install `print-c-type' as the printer used for <c-type> records.
(set-record-type-printer! <c-type> print-c-type)
  62. (define-syntax-rule (define-foreign-type type-name base args ...)
  63. (define type-name
  64. (make-c-type (symbol->string 'type-name)
  65. (c-type-repr base)
  66. args ...)))
  67. (define-syntax-rule (define-foreign-arg-type type-name base unwrapper)
  68. (define-foreign-type type-name base #:unwrap-args unwrapper))
  69. (define-syntax-rule (define-foreign-return-type type-name base wrapper)
  70. (define-foreign-type type-name base #:wrap-result wrapper))
  71. ;;; Base types
;; Registry of base types: a vhash keyed (via eqv?) by repr, mapping to the
;; <c-type> registered for that repr.  Starts empty and is extended by
;; `register-base-type!'.
(define %base-types vlist-null)
  73. (define (register-base-type! type)
  74. (let ((repr (c-type-repr type)))
  75. (unless (has-base-type? repr)
  76. (set! %base-types (vhash-consv repr type %base-types)))))
  77. (define (has-base-type? repr)
  78. (and (vhash-assv repr %base-types) #t))
  79. (define (get-base-type repr)
  80. (match (vhash-assv repr %base-types)
  81. ((_ . type) type)))
  82. (define-syntax-rule (define-base-type type-name repr)
  83. (begin
  84. (define type-name
  85. (make-c-type (symbol->string 'type-name) repr
  86. #:wrap-result (lambda (res . _) res)
  87. #:unwrap-args (lambda (arg) arg)))
  88. (register-base-type! type-name)))
;; One base type per (system foreign) type descriptor, plus `pointer' for
;; the '* descriptor.  The order of these forms is significant:
;; `register-base-type!' keeps the first type registered for a given repr,
;; so if two descriptors below are eqv? the earlier definition is the one
;; `get-base-type' returns.
;; NOTE(review): presumably the platform-dependent names further down
;; (int, long, size_t, ...) share reprs with the fixed-width types above
;; them -- confirm against (system foreign).
(define-base-type int8 ffi:int8)
(define-base-type uint8 ffi:uint8)
(define-base-type int16 ffi:int16)
(define-base-type uint16 ffi:uint16)
(define-base-type int32 ffi:int32)
(define-base-type uint32 ffi:uint32)
(define-base-type int64 ffi:int64)
(define-base-type uint64 ffi:uint64)
(define-base-type float ffi:float)
(define-base-type double ffi:double)
(define-base-type complex-double ffi:complex-double)
(define-base-type complex-float ffi:complex-float)
(define-base-type int ffi:int)
(define-base-type unsigned-int ffi:unsigned-int)
(define-base-type long ffi:long)
(define-base-type unsigned-long ffi:unsigned-long)
(define-base-type short ffi:short)
(define-base-type unsigned-short ffi:unsigned-short)
(define-base-type size_t ffi:size_t)
(define-base-type ssize_t ffi:ssize_t)
(define-base-type ptrdiff_t ffi:ptrdiff_t)
(define-base-type intptr_t ffi:intptr_t)
(define-base-type uintptr_t ffi:uintptr_t)
(define-base-type void ffi:void)
(define-base-type pointer '*)
  114. ;;; Common types
  115. (define-foreign-type cstring pointer
  116. #:wrap-result (lambda (ptr . _) (ffi:pointer->string ptr))
  117. #:unwrap-args ffi:string->pointer)
  118. (define-foreign-type bool int
  119. #:wrap-result (lambda (int . _) (not (zero? int)))
  120. #:unwrap-args (lambda (bool) (if bool 1 0)))
  121. ;;; Enum types
  122. (define-syntax-rule (define-foreign-enum-type enum-name enum-base
  123. enumerator? enumerator-list
  124. int->enumerator enumerator->int
  125. (enumerator ...))
  126. (begin
  127. (define (enumerator? sym)
  128. (and (enumerator->int sym) #t))
  129. (define (enumerator-list)
  130. (%dfe-enum-symbols (enumerator ...)))
  131. (define enumerator->int
  132. (let ((lookup (alist->vhash (map cons
  133. (%dfe-enum-symbols (enumerator ...))
  134. (%dfe-enum-values (enumerator ...)))
  135. hashq)))
  136. (lambda (sym)
  137. (and=> (vhash-assq sym lookup) cdr))))
  138. (define int->enumerator
  139. (let ((lookup (alist->vhash (map cons
  140. (%dfe-enum-values (enumerator ...))
  141. (%dfe-enum-symbols (enumerator ...)))
  142. hashv)))
  143. (lambda (int)
  144. (and=> (vhash-assv int lookup) cdr))))
  145. (define-foreign-type enum-name enum-base
  146. #:wrap-result (lambda (int . _) (int->enumerator int))
  147. #:unwrap-args enumerator->int)))
  148. (define-syntax %dfe-enum-symbols
  149. (syntax-rules (=>)
  150. ((_ (args ...))
  151. (%dfe-enum-symbols (args ...) ()))
  152. ((_ (symbol => value args ...) (syms ...))
  153. (%dfe-enum-symbols (args ...) (syms ... symbol)))
  154. ((_ (symbol args ...) (syms ...))
  155. (%dfe-enum-symbols (args ...) (syms ... symbol)))
  156. ((_ () (syms ...))
  157. '(syms ...))))
  158. (define-syntax %dfe-enum-values
  159. (syntax-rules (=>)
  160. ((_ (args ...))
  161. (%dfe-enum-values (args ...) () -1))
  162. ((_ (symbol => value args ...) (vals ...) previous)
  163. (%dfe-enum-values (args ...) (vals ... value) value))
  164. ((_ (symbol args ...) (vals ...) previous)
  165. (%dfe-enum-values (args ...) (vals ... (1+ previous)) (1+ previous)))
  166. ((_ () (vals ...) previous)
  167. (list vals ...))))
  168. ;;; Pointer types
  169. (define-syntax-rule (define-foreign-pointer-type pointer-name record-type
  170. record? pointer->record record->pointer)
  171. (begin
  172. (define-wrapped-pointer-type record-type
  173. record? pointer->record record->pointer
  174. (lambda (rec port)
  175. (let ((address (ffi:pointer-address (record->pointer rec))))
  176. (format port "#<~a 0x~x>" 'pointer-name address))))
  177. (define-foreign-type pointer-name pointer
  178. #:wrap-result (lambda (ptr . _) (pointer->record ptr))
  179. #:unwrap-args record->pointer)))
  180. ;;; Function wrappers
  181. (define-syntax-rule (define-foreign-library library path args ...)
  182. (define library
  183. (load-foreign-library path args ...)))
  184. (define-syntax-rule (define-foreign-function library
  185. (function-name signature ...))
  186. (define function-name
  187. (apply wrapped-foreign-library-function library
  188. (symbol->string 'function-name)
  189. (%dff-parse-signature (signature ...)))))
  190. (define-syntax %dff-parse-signature
  191. (syntax-rules (->)
  192. ((_ (-> return-type) arg-types ...)
  193. (list #:return-type return-type
  194. #:arg-types (list arg-types ...)))
  195. ((_ (next rest ...) arg-types ...)
  196. (%dff-parse-signature (rest ...) arg-types ... next))))
  197. (define-syntax-rule (define-foreign-functions library
  198. (function-name signature ...) ...)
  199. (begin
  200. (define-foreign-function library
  201. (function-name signature ...))
  202. ...))
  203. (define* (wrapped-foreign-library-function library function-name
  204. #:key return-type arg-types)
  205. (let* ((wrap-result (c-type-wrapper return-type))
  206. (unwrappers (map c-type-unwrapper arg-types))
  207. (unwrap-args (lambda (args)
  208. (map (lambda (unwrap arg)
  209. (unwrap arg))
  210. unwrappers args)))
  211. (foreign-function
  212. (foreign-library-function library function-name
  213. #:return-type (c-type-repr return-type)
  214. #:arg-types (map c-type-repr arg-types))))
  215. (lambda args
  216. (let* ((raw-args (unwrap-args args))
  217. (raw-result (apply foreign-function raw-args))
  218. (result (apply wrap-result raw-result args)))
  219. result))))