[fix] lower on query, [add] metadata response, [add] context distance & reference links

This commit is contained in:
Evan Reichard 2023-10-19 18:56:48 -04:00
parent 05c5546c10
commit 40daf46c03
7 changed files with 174 additions and 77 deletions

View File

@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
To normalize data, you can use Minyma's `normalize` CLI command:
```bash
minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma
minyma normalize \
--filename ./pubmed_manuscripts.jsonl \
--normalizer pubmed \
--database chroma \
--datapath ./chroma
```
The above example does the following:

View File

@ -1,8 +1,8 @@
from os import path
import click
import signal
import sys
from importlib.metadata import version
from minyma.config import Config
from minyma.oai import OpenAIConnector
from minyma.vdb import ChromaDB
from flask import Flask
@ -15,14 +15,15 @@ def signal_handler(sig, frame):
def create_app():
global oai, cdb
global oai, vdb
from minyma.config import Config
import minyma.api.common as api_common
import minyma.api.v1 as api_v1
app = Flask(__name__)
cdb = ChromaDB(Config.DATA_PATH)
oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb)
vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)
app.register_blueprint(api_common.bp)
app.register_blueprint(api_v1.bp)

View File

@ -17,8 +17,28 @@ def get_response():
if message == "":
return {"error": "Empty Message"}
oai_response = minyma.oai.query(message)
return oai_response
resp = minyma.oai.query(message)
# Derive LLM Data
llm_resp = resp.get("llm", {})
llm_choices = llm_resp.get("choices", [])
# Derive VDB Data
vdb_resp = resp.get("vdb", {})
combined_context = [{
"id": vdb_resp.get("ids")[i],
"distance": vdb_resp.get("distances")[i],
"doc": vdb_resp.get("docs")[i],
"metadata": vdb_resp.get("metadatas")[i],
} for i, _ in enumerate(vdb_resp.get("docs", []))]
# Return Data
return {
"response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
"context": combined_context,
"usage": llm_resp.get("usage"),
}
"""
@ -34,5 +54,5 @@ def get_related():
if message == "":
return {"error": "Empty Message"}
related_documents = minyma.cdb.get_related(message)
related_documents = minyma.vdb.get_related(message)
return related_documents

View File

@ -1,12 +1,16 @@
from io import TextIOWrapper
import json
class DataNormalizer:
class DataNormalizer():
def __init__(self, file: TextIOWrapper):
pass
self.file = file
def __len__(self) -> int:
return 0
def __iter__(self):
pass
yield None
class PubMedNormalizer(DataNormalizer):
"""
@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer):
normalized inside the iterator.
"""
def __init__(self, file: TextIOWrapper):
self.file = file
self.file = file
self.length = 0
def __len__(self):
last_pos = self.file.tell()
self.length = sum(1 for _ in self.file)
self.file.seek(last_pos)
return self.length
def __iter__(self):
count = 0
@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
count += 1
# ID = Line Number
yield { "doc": norm_text, "id": str(count - 1) }
yield {
"id": str(count - 1),
"doc": norm_text,
"metadata": {
"file": l.get("file")
},
}

View File

@ -19,6 +19,7 @@ class OpenAIConnector:
def __init__(self, api_key: str, vdb: VectorDB):
self.vdb = vdb
self.model = "gpt-3.5-turbo"
self.word_cap = 1000
openai.api_key = api_key
def query(self, question: str) -> Any:
@ -30,8 +31,9 @@ class OpenAIConnector:
if len(all_docs) == 0:
return { "error": "No Context Found" }
# Join on new line, generate main prompt
context = '\n'.join(all_docs)
# Join on new line (cap @ word limit), generate main prompt
reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
context = '\n'.join(reduced_docs)
prompt = PROMPT_TEMPLATE.format(context = context, question = question)
# Query OpenAI ChatCompletion
@ -41,4 +43,4 @@ class OpenAIConnector:
)
# Return Response
return response
return { "llm": response, "vdb": related }

View File

@ -72,41 +72,41 @@
</main>
<script>
const LOADING_SVG = `<svg
width="24"
height="24"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
fill="currentColor"
>
<style>
.spinner_qM83 {
animation: spinner_8HQG 1.05s infinite;
}
.spinner_oXPr {
animation-delay: 0.1s;
}
.spinner_ZTLf {
animation-delay: 0.2s;
}
@keyframes spinner_8HQG {
0%,
57.14% {
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
transform: translate(0);
}
28.57% {
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
transform: translateY(-6px);
}
100% {
transform: translate(0);
}
}
</style>
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
</svg>`;
width="24"
height="24"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
fill="currentColor"
>
<style>
.spinner_qM83 {
animation: spinner_8HQG 1.05s infinite;
}
.spinner_oXPr {
animation-delay: 0.1s;
}
.spinner_ZTLf {
animation-delay: 0.2s;
}
@keyframes spinner_8HQG {
0%,
57.14% {
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
transform: translate(0);
}
28.57% {
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
transform: translateY(-6px);
}
100% {
transform: translate(0);
}
}
</style>
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
</svg>`;
/**
* Wrapper API Call
@ -125,9 +125,9 @@
// Wrapping Element
let wrapEl = document.createElement("div");
wrapEl.innerHTML = `<div class="flex">
<span class="font-bold w-24 grow-0 shrink-0"></span>
<span class="whitespace-break-spaces w-full"></span>
</div>`;
<span class="font-bold w-24 grow-0 shrink-0"></span>
<span class="whitespace-break-spaces w-full"></span>
</div>`;
// Get Elements
let nameEl = wrapEl.querySelector("span");
@ -158,7 +158,63 @@
})
.then((data) => {
console.log("SUCCESS:", data);
content.innerText = data.choices[0].message.content;
// Create Response Element
let responseEl = document.createElement("p");
responseEl.setAttribute(
"class",
"whitespace-break-spaces border-b pb-3 mb-3"
);
responseEl.innerText = data.response;
// Create Context Element
let contextEl = document.createElement("div");
contextEl.innerHTML = `
<h1 class="font-bold">Context:</h1>
<ul class="list-disc ml-6"></ul>`;
let ulEl = contextEl.querySelector("ul");
// Create Context Links
data.context
// Capture PubMed ID & Distance
.map((item) => [
item.metadata.file.match("\/(.*)\.txt$"),
item.distance,
])
// Filter Non-Matches
.filter(([match]) => match)
// Get Match Value & Round Distance (2)
.map(([match, distance]) => [
match[1],
Math.round(distance * 100) / 100,
])
// Create Links
.forEach(([pmid, distance]) => {
let newEl = document.createElement("li");
let linkEl = document.createElement("a");
linkEl.setAttribute("target", "_blank");
linkEl.setAttribute(
"class",
"text-blue-500 hover:text-blue-600"
);
linkEl.setAttribute(
"href",
"https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
);
linkEl.textContent = "[" + distance + "] " + pmid;
newEl.append(linkEl);
ulEl.append(newEl);
});
// Add to DOM
content.setAttribute("class", "w-full");
content.innerHTML = "";
content.append(responseEl);
content.append(contextEl);
})
.catch((e) => {
console.log("ERROR:", e);

View File

@ -1,7 +1,6 @@
from chromadb.api import API
from itertools import islice
from os import path
from tqdm.auto import tqdm
from tqdm import tqdm
from typing import Any, cast
import chromadb
@ -29,47 +28,45 @@ class VectorDB:
ChromaDV VectorDB Type
"""
class ChromaDB(VectorDB):
def __init__(self, base_path: str):
chroma_path = path.join(base_path, "chroma")
self.client: API = chromadb.PersistentClient(path=chroma_path)
self.word_limit = 1000
def __init__(self, path: str):
self.client: API = chromadb.PersistentClient(path=path)
self.word_cap = 2500
self.collection_name: str = "vdb"
self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
def get_related(self, question) -> Any:
def get_related(self, question: str) -> Any:
"""Returns line separated related docs"""
results = self.collection.query(
query_texts=[question],
query_texts=[question.lower()],
n_results=2
)
all_docs: list = cast(list, results.get("documents", [[]]))[0]
all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
all_distances: list = cast(list, results.get("distances", [[]]))[0]
all_ids: list = cast(list, results.get("ids", [[]]))[0]
return {
"distances":all_distances,
"distances": all_distances,
"metadatas": all_metadata,
"docs": all_docs,
"ids": all_ids
}
def load_documents(self, normalizer: DataNormalizer):
# 10 Item Chunking
for items in tqdm(chunk(normalizer, 50)):
def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
length = len(normalizer) / chunk_size
for items in tqdm(chunk(normalizer, chunk_size), total=length):
ids = []
documents = []
metadatas = []
# Limit words per document to accommodate context token limits
for item in items:
doc = " ".join(item.get("doc").split()[:self.word_limit])
documents.append(doc)
documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
ids.append(item.get("id"))
metadatas.append(item.get("metadata", {}))
# Ideally we parse out metadata from each document
# and pass to the metadata kwarg. However, each
# document appears to have a slightly different format,
# so it's difficult to parse out.
self.collection.add(
ids=ids,
documents=documents,
ids=ids
metadatas=metadatas,
)