[fix] lower on query, [add] metadata response, [add] context distance & reference links

Evan Reichard 2023-10-19 18:56:48 -04:00
parent 05c5546c10
commit 40daf46c03
7 changed files with 174 additions and 77 deletions

View File

@@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
 To normalize data, you can use Minyma's `normalize` CLI command:
 ```bash
-minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma
+minyma normalize \
+  --filename ./pubmed_manuscripts.jsonl \
+  --normalizer pubmed \
+  --database chroma \
+  --datapath ./chroma
 ```
 The above example does the following:
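
For reference, the same ingestion can be driven from Python rather than the CLI. A minimal sketch based on the classes changed in this commit; the import path for `PubMedNormalizer` is an assumption, since the normalizer module's filename isn't visible in this diff:

```python
# Sketch only: programmatic equivalent of the `minyma normalize` call above.
from minyma.vdb import ChromaDB
from minyma.normalizer import PubMedNormalizer  # assumed module path

with open("./pubmed_manuscripts.jsonl") as f:
    normalizer = PubMedNormalizer(f)
    vdb = ChromaDB("./chroma")      # new signature: points at the Chroma dir itself
    vdb.load_documents(normalizer)  # batches, embeds, and stores the docs
```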

View File

@@ -1,8 +1,8 @@
+from os import path
 import click
 import signal
 import sys
 from importlib.metadata import version
-from minyma.config import Config
 from minyma.oai import OpenAIConnector
 from minyma.vdb import ChromaDB
 from flask import Flask
@@ -15,14 +15,15 @@ def signal_handler(sig, frame):
 def create_app():
-    global oai, cdb
+    global oai, vdb
+    from minyma.config import Config
     import minyma.api.common as api_common
     import minyma.api.v1 as api_v1
 
     app = Flask(__name__)
-    cdb = ChromaDB(Config.DATA_PATH)
-    oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb)
+    vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
+    oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)
 
     app.register_blueprint(api_common.bp)
     app.register_blueprint(api_v1.bp)
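
With this change `ChromaDB` receives the full Chroma directory instead of deriving it internally, so the path composition now lives in `create_app`. A small sketch of the resulting layout, assuming an illustrative `Config.DATA_PATH` of `./data`:

```python
from os import path

data_path = "./data"                   # stand-in for Config.DATA_PATH
print(path.join(data_path, "chroma"))  # -> ./data/chroma, handed to ChromaDB
```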

View File

@@ -17,8 +17,28 @@ def get_response():
     if message == "":
         return {"error": "Empty Message"}
 
-    oai_response = minyma.oai.query(message)
-    return oai_response
+    resp = minyma.oai.query(message)
+
+    # Derive LLM Data
+    llm_resp = resp.get("llm", {})
+    llm_choices = llm_resp.get("choices", [])
+
+    # Derive VDB Data
+    vdb_resp = resp.get("vdb", {})
+    combined_context = [{
+        "id": vdb_resp.get("ids")[i],
+        "distance": vdb_resp.get("distances")[i],
+        "doc": vdb_resp.get("docs")[i],
+        "metadata": vdb_resp.get("metadatas")[i],
+    } for i, _ in enumerate(vdb_resp.get("docs", []))]
+
+    # Return Data
+    return {
+        "response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
+        "context": combined_context,
+        "usage": llm_resp.get("usage"),
+    }
 """
@@ -34,5 +54,5 @@ def get_related():
     if message == "":
         return {"error": "Empty Message"}
 
-    related_documents = minyma.cdb.get_related(message)
+    related_documents = minyma.vdb.get_related(message)
     return related_documents

View File

@@ -1,12 +1,16 @@
 from io import TextIOWrapper
 import json
 
-class DataNormalizer:
+class DataNormalizer():
     def __init__(self, file: TextIOWrapper):
-        pass
+        self.file = file
+
+    def __len__(self) -> int:
+        return 0
 
     def __iter__(self):
-        pass
+        yield None
 
 class PubMedNormalizer(DataNormalizer):
     """
@@ -15,6 +19,13 @@ class PubMedNormalizer(DataNormalizer):
     """
     def __init__(self, file: TextIOWrapper):
         self.file = file
+        self.length = 0
+
+    def __len__(self):
+        last_pos = self.file.tell()
+        self.length = sum(1 for _ in self.file)
+        self.file.seek(last_pos)
+        return self.length
 
     def __iter__(self):
         count = 0
@@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
             count += 1
 
             # ID = Line Number
-            yield { "doc": norm_text, "id": str(count - 1) }
+            yield {
+                "id": str(count - 1),
+                "doc": norm_text,
+                "metadata": {
+                    "file": l.get("file")
+                },
+            }
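
Each yielded record now pairs the document with its source file, and `__len__` makes the normalizer sized so `tqdm` can show real progress: it counts lines, then seeks back to where it started so iteration is unaffected. A sketch of that contract against an in-memory file; the raw JSONL field names other than `file` are assumptions:

```python
import json
from io import StringIO

raw = StringIO('{"file": "pubmed/PMC1234567.txt", "text": "Some manuscript text."}\n')

length = sum(1 for _ in raw)  # same trick as __len__: consume and count...
raw.seek(0)                   # ...then rewind so __iter__ starts fresh

line = json.loads(raw.readline())
record = {
    "id": "0",                              # line number as string
    "doc": line["text"].lower(),            # normalization step assumed
    "metadata": {"file": line.get("file")},
}
```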

View File

@@ -19,6 +19,7 @@ class OpenAIConnector:
     def __init__(self, api_key: str, vdb: VectorDB):
         self.vdb = vdb
         self.model = "gpt-3.5-turbo"
+        self.word_cap = 1000
         openai.api_key = api_key
 
     def query(self, question: str) -> Any:
@@ -30,8 +31,9 @@ class OpenAIConnector:
         if len(all_docs) == 0:
             return { "error": "No Context Found" }
 
-        # Join on new line, generate main prompt
-        context = '\n'.join(all_docs)
+        # Join on new line (cap @ word limit), generate main prompt
+        reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
+        context = '\n'.join(reduced_docs)
         prompt = PROMPT_TEMPLATE.format(context = context, question = question)
 
         # Query OpenAI ChatCompletion
@@ -41,4 +43,4 @@ class OpenAIConnector:
         )
 
         # Return Response
-        return response
+        return { "llm": response, "vdb": related }
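
The new `word_cap` truncates each retrieved document to its first 1,000 whitespace-separated words before it is joined into the prompt, a rough guard against overflowing the model's context window (it counts words, not tokens, so it is only approximate). The reduction is equivalent to:

```python
word_cap = 1000  # mirrors OpenAIConnector.word_cap

def cap_words(doc: str, cap: int = word_cap) -> str:
    # Keep the first `cap` whitespace-separated words of the document.
    return " ".join(doc.split()[:cap])

assert cap_words("one two three four", 2) == "one two"
```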

View File

@@ -158,7 +158,63 @@
 })
 .then((data) => {
   console.log("SUCCESS:", data);
-  content.innerText = data.choices[0].message.content;
+
+  // Create Response Element
+  let responseEl = document.createElement("p");
+  responseEl.setAttribute(
+    "class",
+    "whitespace-break-spaces border-b pb-3 mb-3"
+  );
+  responseEl.innerText = data.response;
+
+  // Create Context Element
+  let contextEl = document.createElement("div");
+  contextEl.innerHTML = `
+    <h1 class="font-bold">Context:</h1>
+    <ul class="list-disc ml-6"></ul>`;
+  let ulEl = contextEl.querySelector("ul");
+
+  // Create Context Links
+  data.context
+    // Capture PubMed ID & Distance
+    .map((item) => [
+      item.metadata.file.match("\/(.*)\.txt$"),
+      item.distance,
+    ])
+    // Filter Non-Matches
+    .filter(([match]) => match)
+    // Get Match Value & Round Distance (2)
+    .map(([match, distance]) => [
+      match[1],
+      Math.round(distance * 100) / 100,
+    ])
+    // Create Links
+    .forEach(([pmid, distance]) => {
+      let newEl = document.createElement("li");
+      let linkEl = document.createElement("a");
+      linkEl.setAttribute("target", "_blank");
+      linkEl.setAttribute(
+        "class",
+        "text-blue-500 hover:text-blue-600"
+      );
+      linkEl.setAttribute(
+        "href",
+        "https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
+      );
+      linkEl.textContent = "[" + distance + "] " + pmid;
+      newEl.append(linkEl);
+      ulEl.append(newEl);
+    });
+
+  // Add to DOM
+  content.setAttribute("class", "w-full");
+  content.innerHTML = "";
+  content.append(responseEl);
+  content.append(contextEl);
 })
 .catch((e) => {
   console.log("ERROR:", e);
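
The link-building chain extracts a PubMed Central ID from each context item's `metadata.file` path and prefixes it with the rounded distance. The same transformation, sketched in Python to match the other examples (the single-directory path is illustrative):

```python
import re

item = {"metadata": {"file": "pubmed/PMC1234567.txt"}, "distance": 0.3141}

match = re.search(r"/(.*)\.txt$", item["metadata"]["file"])
if match:
    pmid = match.group(1)                  # "PMC1234567"
    distance = round(item["distance"], 2)  # 0.31
    url = "https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
    label = f"[{distance}] {pmid}"         # "[0.31] PMC1234567"
```

Note the greedy `(.*)` starts matching at the first `/`, so paths nested more than one directory deep would leak directory names into the captured ID.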

View File

@@ -1,7 +1,6 @@
 from chromadb.api import API
 from itertools import islice
-from os import path
-from tqdm.auto import tqdm
+from tqdm import tqdm
 from typing import Any, cast
 import chromadb
@@ -29,47 +28,45 @@ class VectorDB:
     ChromaDV VectorDB Type
     """
 class ChromaDB(VectorDB):
-    def __init__(self, base_path: str):
-        chroma_path = path.join(base_path, "chroma")
-        self.client: API = chromadb.PersistentClient(path=chroma_path)
-        self.word_limit = 1000
+    def __init__(self, path: str):
+        self.client: API = chromadb.PersistentClient(path=path)
+        self.word_cap = 2500
         self.collection_name: str = "vdb"
         self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
 
-    def get_related(self, question) -> Any:
+    def get_related(self, question: str) -> Any:
         """Returns line separated related docs"""
         results = self.collection.query(
-            query_texts=[question],
+            query_texts=[question.lower()],
             n_results=2
         )
 
         all_docs: list = cast(list, results.get("documents", [[]]))[0]
+        all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
         all_distances: list = cast(list, results.get("distances", [[]]))[0]
         all_ids: list = cast(list, results.get("ids", [[]]))[0]
 
         return {
-            "distances":all_distances,
+            "distances": all_distances,
+            "metadatas": all_metadata,
             "docs": all_docs,
             "ids": all_ids
         }
 
-    def load_documents(self, normalizer: DataNormalizer):
-        # 10 Item Chunking
-        for items in tqdm(chunk(normalizer, 50)):
+    def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
+        length = len(normalizer) / chunk_size
+        for items in tqdm(chunk(normalizer, chunk_size), total=length):
             ids = []
             documents = []
+            metadatas = []
 
+            # Limit words per document to accommodate context token limits
             for item in items:
-                doc = " ".join(item.get("doc").split()[:self.word_limit])
-                documents.append(doc)
+                documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
                 ids.append(item.get("id"))
+                metadatas.append(item.get("metadata", {}))
 
-            # Ideally we parse out metadata from each document
-            # and pass to the metadata kwarg. However, each
-            # document appears to have a slightly different format,
-            # so it's difficult to parse out.
             self.collection.add(
+                ids=ids,
                 documents=documents,
-                ids=ids
+                metadatas=metadatas,
             )
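
`load_documents` relies on a `chunk` helper that isn't shown in this diff; given the `islice` import, a typical implementation would look like the sketch below. Also worth noting: `len(normalizer) / chunk_size` is float division, which `tqdm` accepts as `total`, though `math.ceil` would give the exact batch count.

```python
from itertools import islice

def chunk(iterable, size: int):
    # Yield successive lists of at most `size` items from any iterable.
    it = iter(iterable)
    while batch := list(islice(it, size)):
        yield batch

assert list(chunk(range(5), 2)) == [[0, 1], [2, 3], [4]]
```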