main.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import asyncio
  2. import functools
  3. import time
  4. import chromadb
  5. from aiohttp import web
  6. from chromadb.config import Settings
  7. from chromadb.utils import embedding_functions
  8. from chromadb.api.models.Collection import Collection
  9. import logging
  10. from pathlib import Path
  11. from config import DATABASE_DIR, HTTP_HOST, HTTP_PORT, QUIET, WEBDATA_DIR, COLLECTION_NAME
  12. from rich.logging import RichHandler
  13. from rich.console import Console
  14. from concurrent.futures import ThreadPoolExecutor
  15. import uuid
  16. log = logging.getLogger(__name__)
  17. COLLECTION: Collection = None
  18. executor = ThreadPoolExecutor(1)
  19. routes = web.RouteTableDef()
  20. WEBDATA_DIR = Path(WEBDATA_DIR)
  21. async def async_run(func, /, *args, **kwargs):
  22. loop = asyncio.get_running_loop()
  23. task = loop.run_in_executor(executor, functools.partial(func, *args, **kwargs))
  24. results = await asyncio.gather(task)
  25. return results[0]
  26. @routes.get('/')
  27. async def serve_index(request):
  28. return web.FileResponse(WEBDATA_DIR / 'index.html')
  29. @routes.get('/favicon.ico')
  30. async def serve_favicon(request):
  31. return web.FileResponse(WEBDATA_DIR / 'favicon.ico')
  32. @routes.post('/api/v1/search')
  33. async def add_items(request):
  34. data = await request.json()
  35. results = await async_run(
  36. COLLECTION.query,
  37. query_texts=[data.get('text', '')],
  38. n_results=10,
  39. )
  40. items = [
  41. [
  42. results['ids'][0][i],
  43. results['distances'][0][i],
  44. results['documents'][0][i],
  45. results['metadatas'][0][i],
  46. ] for i in range(len(results['ids'][0]))
  47. ]
  48. return web.json_response(items)
  49. @routes.post('/api/v1/item')
  50. async def add_items(request):
  51. data = await request.json()
  52. ids = []
  53. metadatas = []
  54. documents = []
  55. body = data.get('body', '')
  56. source = data.get('source')
  57. search_sources = source == '#'
  58. if search_sources:
  59. source = ''
  60. divisions = body.split('\n\n')
  61. for items in divisions:
  62. lines = [line for line in items.split('\n') if line.strip()]
  63. if len(lines) == 0:
  64. continue
  65. if search_sources and lines[0].startswith('#'):
  66. source = lines[0].lstrip('#').strip()
  67. continue
  68. title = ''
  69. if len(lines) >= 2:
  70. title = lines[0]
  71. lines = lines[1:]
  72. else:
  73. pass
  74. doc = ' '.join(lines)
  75. doc = ' '.join(doc.split()).strip() # this just replaces extra whitespaces with one.
  76. if len(doc) == 0:
  77. continue
  78. ids.append(uuid.uuid4().hex)
  79. documents.append(doc)
  80. metadatas.append({'type': 'question', 'source': source, 'added': int(time.time()), 'title': title})
  81. await async_run(COLLECTION.upsert, ids=ids, metadatas=metadatas, documents=documents)
  82. return web.json_response({'count': await async_run(COLLECTION.count)})
  83. @routes.delete('/api/v1/item')
  84. async def delete_item(request: web.BaseRequest):
  85. await async_run(COLLECTION.delete, ids=[request.query.get('id')])
  86. return
  87. @routes.put('/api/v1/item')
  88. async def edit_item(request: web.BaseRequest):
  89. data = await request.json()
  90. await async_run(
  91. COLLECTION.update,
  92. ids=[request.query.get('id')],
  93. metadatas=[data.get('metadata', {})],
  94. documents=[data.get('document', {})]
  95. )
  96. return
  97. # @routes.get('/api/v1/sources')
  98. # async def get_sources(request: web.BaseRequest):
  99. # await async_run(COLLECTION.get, where=)
  100. # return
  101. def run():
  102. global COLLECTION
  103. console = Console(color_system="standard", quiet=QUIET, width=180)
  104. logging.basicConfig(level="NOTSET", format="%(message)s", datefmt="[%X]", handlers=[
  105. RichHandler(console=console, enable_link_path=False)
  106. ])
  107. log.info("Loading collection: " + COLLECTION_NAME)
  108. sentence_transformer_ef = embedding_functions.InstructorEmbeddingFunction(
  109. model_name="hkunlp/instructor-large", device="cuda")
  110. client = chromadb.PersistentClient(
  111. path=DATABASE_DIR,
  112. settings=Settings(
  113. anonymized_telemetry=False,
  114. )
  115. )
  116. COLLECTION = client.get_or_create_collection(name=COLLECTION_NAME, embedding_function=sentence_transformer_ef)
  117. log.info(f"Using web root: {WEBDATA_DIR.absolute()}")
  118. app = web.Application(logger=log)
  119. app.add_routes(routes)
  120. app.router.add_static('/css', WEBDATA_DIR / 'css')
  121. app.router.add_static('/fonts', WEBDATA_DIR / 'fonts')
  122. app.router.add_static('/img', WEBDATA_DIR / 'img')
  123. app.router.add_static('/js', WEBDATA_DIR / 'js')
  124. log.info(f'Serving http at http://{HTTP_HOST}:{HTTP_PORT}')
  125. web.run_app(app, host=HTTP_HOST, port=HTTP_PORT, access_log=log,
  126. access_log_format='%a "%r" %s %b "%{Referer}i" "%{User-Agent}i"', print=False)
  127. if __name__ == '__main__':
  128. run()