main.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import asyncio
  2. import codecs
  3. import functools
  4. import io
  5. import time
  6. import chromadb
  7. from aiohttp import web
  8. from chromadb.config import Settings
  9. from chromadb.utils import embedding_functions
  10. from chromadb.api.models.Collection import Collection
  11. import logging
  12. from pathlib import Path
  13. from config import (
  14. DATABASE_DIR,
  15. HTTP_HOST,
  16. HTTP_PORT,
  17. QUIET,
  18. WEBDATA_DIR,
  19. COLLECTION_NAME,
  20. MODEL_NAME,
  21. TORCH_DEVICE,
  22. )
  23. from rich.logging import RichHandler
  24. from rich.console import Console
  25. from concurrent.futures import ThreadPoolExecutor
  26. import uuid
  27. import csv
  28. log = logging.getLogger(__name__)
  29. COLLECTION: Collection = None
  30. executor = ThreadPoolExecutor(1)
  31. routes = web.RouteTableDef()
  32. WEBDATA_DIR = Path(WEBDATA_DIR)
  33. async def async_run(func, /, *args, **kwargs):
  34. loop = asyncio.get_running_loop()
  35. task = loop.run_in_executor(executor, functools.partial(func, *args, **kwargs))
  36. results = await asyncio.gather(task)
  37. return results[0]
  38. @routes.get('/')
  39. async def serve_index(request):
  40. return web.FileResponse(WEBDATA_DIR / 'index.html')
  41. @routes.get('/favicon.ico')
  42. async def serve_favicon(request):
  43. return web.FileResponse(WEBDATA_DIR / 'favicon.ico')
  44. @routes.post('/api/v1/search')
  45. async def add_items(request):
  46. data = await request.json()
  47. results = await async_run(
  48. COLLECTION.query,
  49. query_texts=[data.get('text', '')],
  50. n_results=10,
  51. )
  52. items = [
  53. [
  54. results['ids'][0][i],
  55. results['distances'][0][i],
  56. results['documents'][0][i],
  57. results['metadatas'][0][i],
  58. ] for i in range(len(results['ids'][0]))
  59. ]
  60. return web.json_response(items)
  61. @routes.post('/api/v1/item')
  62. async def add_items(request):
  63. data = await request.json()
  64. ids = []
  65. metadatas = []
  66. documents = []
  67. body = data.get('body', '')
  68. source = data.get('source')
  69. search_sources = source == '#'
  70. if search_sources:
  71. source = ''
  72. divisions = body.split('\n\n')
  73. for items in divisions:
  74. lines = [line for line in items.split('\n') if line.strip()]
  75. if len(lines) == 0:
  76. continue
  77. if search_sources and lines[0].startswith('#'):
  78. source = lines[0].lstrip('#').strip()
  79. continue
  80. title = ''
  81. if len(lines) >= 2:
  82. title = lines[0]
  83. lines = lines[1:]
  84. else:
  85. pass
  86. doc = ' '.join(lines)
  87. doc = ' '.join(doc.split()).strip() # this just replaces extra whitespaces with one.
  88. if len(doc) == 0:
  89. continue
  90. ids.append(uuid.uuid4().hex)
  91. documents.append(doc)
  92. metadatas.append({'type': 'question', 'source': source, 'added': int(time.time()), 'title': title})
  93. await async_run(COLLECTION.upsert, ids=ids, metadatas=metadatas, documents=documents)
  94. return web.json_response({'count': await async_run(COLLECTION.count)})
  95. @routes.delete('/api/v1/item')
  96. async def delete_item(request: web.BaseRequest):
  97. await async_run(COLLECTION.delete, ids=[request.query.get('id')])
  98. return
  99. @routes.get('/api/v1/database')
  100. async def get_database(request: web.BaseRequest):
  101. results = await async_run(COLLECTION.get)
  102. buffer = io.StringIO()
  103. writer = csv.writer(buffer)
  104. writer.writerow(['ID', 'Source', 'Title', 'Document'])
  105. for i in range(len(results['ids'])):
  106. writer.writerow([
  107. results["ids"][i], results['metadatas'][i]['source'],
  108. results["metadatas"][i]["title"], results['documents'][i]
  109. ])
  110. return web.Response(text=buffer.getvalue(), status=200, content_type='text/csv',
  111. headers={'Content-Disposition': 'attachment; filename=ipt_questions.csv'})
  112. @routes.post('/api/v1/database')
  113. async def get_database(request: web.BaseRequest):
  114. data = await request.post()
  115. ids = []
  116. documents = []
  117. metadatas = []
  118. body = codecs.getreader('utf-8')(data['database'].file)
  119. reader = csv.reader(body)
  120. header = next(reader)
  121. if ','.join(header) != 'ID,Source,Title,Document':
  122. return web.Response(text='Invalid CSV header', status=400)
  123. for row in reader:
  124. ids.append(row[0])
  125. documents.append(row[3])
  126. metadatas.append({"type": "question", "source": row[1], "added": int(time.time()), "title": row[2]})
  127. await async_run(COLLECTION.upsert, ids=ids, metadatas=metadatas, documents=documents)
  128. return web.json_response({'count': await async_run(COLLECTION.count)})
  129. @routes.put('/api/v1/item')
  130. async def edit_item(request: web.BaseRequest):
  131. data = await request.json()
  132. await async_run(
  133. COLLECTION.update,
  134. ids=[request.query.get('id')],
  135. metadatas=[data.get('metadata', {})],
  136. documents=[data.get('document', {})]
  137. )
  138. return
  139. def run():
  140. global COLLECTION
  141. console = Console(color_system="standard", quiet=QUIET, width=180)
  142. logging.basicConfig(level="NOTSET", format="%(message)s", datefmt="[%X]", handlers=[
  143. RichHandler(console=console, enable_link_path=False)
  144. ])
  145. log.info("Loading collection: " + COLLECTION_NAME)
  146. sentence_transformer_ef = embedding_functions.InstructorEmbeddingFunction(
  147. model_name=MODEL_NAME, device=TORCH_DEVICE)
  148. client = chromadb.PersistentClient(
  149. path=DATABASE_DIR,
  150. settings=Settings(
  151. anonymized_telemetry=False,
  152. )
  153. )
  154. COLLECTION = client.get_or_create_collection(name=COLLECTION_NAME, embedding_function=sentence_transformer_ef)
  155. log.info(f"Using web root: {WEBDATA_DIR.absolute()}")
  156. app = web.Application(logger=log)
  157. app.add_routes(routes)
  158. app.router.add_static('/css', WEBDATA_DIR / 'css')
  159. app.router.add_static('/fonts', WEBDATA_DIR / 'fonts')
  160. app.router.add_static('/img', WEBDATA_DIR / 'img')
  161. app.router.add_static('/js', WEBDATA_DIR / 'js')
  162. log.info(f'Serving http at http://{HTTP_HOST}:{HTTP_PORT}')
  163. web.run_app(app, host=HTTP_HOST, port=HTTP_PORT, access_log=log,
  164. access_log_format='%a "%r" %s %b "%{Referer}i" "%{User-Agent}i"', print=False)
  165. if __name__ == '__main__':
  166. run()