[fix] lower on query, [add] metadata response, [add] context distance & reference links
This commit is contained in:
parent
05c5546c10
commit
40daf46c03
@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
|
|||||||
To normalize data, you can use Minyma's `normalize` CLI command:
|
To normalize data, you can use Minyma's `normalize` CLI command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma
|
minyma normalize \
|
||||||
|
--filename ./pubmed_manuscripts.jsonl \
|
||||||
|
--normalizer pubmed \
|
||||||
|
--database chroma \
|
||||||
|
--datapath ./chroma
|
||||||
```
|
```
|
||||||
|
|
||||||
The above example does the following:
|
The above example does the following:
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
|
from os import path
|
||||||
import click
|
import click
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from minyma.config import Config
|
|
||||||
from minyma.oai import OpenAIConnector
|
from minyma.oai import OpenAIConnector
|
||||||
from minyma.vdb import ChromaDB
|
from minyma.vdb import ChromaDB
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
@ -15,14 +15,15 @@ def signal_handler(sig, frame):
|
|||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
global oai, cdb
|
global oai, vdb
|
||||||
|
|
||||||
|
from minyma.config import Config
|
||||||
import minyma.api.common as api_common
|
import minyma.api.common as api_common
|
||||||
import minyma.api.v1 as api_v1
|
import minyma.api.v1 as api_v1
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
cdb = ChromaDB(Config.DATA_PATH)
|
vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
|
||||||
oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb)
|
oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)
|
||||||
|
|
||||||
app.register_blueprint(api_common.bp)
|
app.register_blueprint(api_common.bp)
|
||||||
app.register_blueprint(api_v1.bp)
|
app.register_blueprint(api_v1.bp)
|
||||||
|
@ -17,8 +17,28 @@ def get_response():
|
|||||||
if message == "":
|
if message == "":
|
||||||
return {"error": "Empty Message"}
|
return {"error": "Empty Message"}
|
||||||
|
|
||||||
oai_response = minyma.oai.query(message)
|
resp = minyma.oai.query(message)
|
||||||
return oai_response
|
|
||||||
|
# Derive LLM Data
|
||||||
|
llm_resp = resp.get("llm", {})
|
||||||
|
llm_choices = llm_resp.get("choices", [])
|
||||||
|
|
||||||
|
# Derive VDB Data
|
||||||
|
vdb_resp = resp.get("vdb", {})
|
||||||
|
combined_context = [{
|
||||||
|
"id": vdb_resp.get("ids")[i],
|
||||||
|
"distance": vdb_resp.get("distances")[i],
|
||||||
|
"doc": vdb_resp.get("docs")[i],
|
||||||
|
"metadata": vdb_resp.get("metadatas")[i],
|
||||||
|
} for i, _ in enumerate(vdb_resp.get("docs", []))]
|
||||||
|
|
||||||
|
# Return Data
|
||||||
|
return {
|
||||||
|
"response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
|
||||||
|
"context": combined_context,
|
||||||
|
"usage": llm_resp.get("usage"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -34,5 +54,5 @@ def get_related():
|
|||||||
if message == "":
|
if message == "":
|
||||||
return {"error": "Empty Message"}
|
return {"error": "Empty Message"}
|
||||||
|
|
||||||
related_documents = minyma.cdb.get_related(message)
|
related_documents = minyma.vdb.get_related(message)
|
||||||
return related_documents
|
return related_documents
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class DataNormalizer:
|
class DataNormalizer():
|
||||||
def __init__(self, file: TextIOWrapper):
|
def __init__(self, file: TextIOWrapper):
|
||||||
pass
|
self.file = file
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
pass
|
yield None
|
||||||
|
|
||||||
|
|
||||||
class PubMedNormalizer(DataNormalizer):
|
class PubMedNormalizer(DataNormalizer):
|
||||||
"""
|
"""
|
||||||
@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer):
|
|||||||
normalized inside the iterator.
|
normalized inside the iterator.
|
||||||
"""
|
"""
|
||||||
def __init__(self, file: TextIOWrapper):
|
def __init__(self, file: TextIOWrapper):
|
||||||
self.file = file
|
self.file = file
|
||||||
|
self.length = 0
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
last_pos = self.file.tell()
|
||||||
|
self.length = sum(1 for _ in self.file)
|
||||||
|
self.file.seek(last_pos)
|
||||||
|
return self.length
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
count = 0
|
count = 0
|
||||||
@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
|
|||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
# ID = Line Number
|
# ID = Line Number
|
||||||
yield { "doc": norm_text, "id": str(count - 1) }
|
yield {
|
||||||
|
"id": str(count - 1),
|
||||||
|
"doc": norm_text,
|
||||||
|
"metadata": {
|
||||||
|
"file": l.get("file")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
@ -5,8 +5,8 @@ from minyma.vdb import VectorDB
|
|||||||
|
|
||||||
# Stolen LangChain Prompt
|
# Stolen LangChain Prompt
|
||||||
PROMPT_TEMPLATE = """
|
PROMPT_TEMPLATE = """
|
||||||
Use the following pieces of context to answer the question at the end.
|
Use the following pieces of context to answer the question at the end.
|
||||||
If you don't know the answer, just say that you don't know, don't try to
|
If you don't know the answer, just say that you don't know, don't try to
|
||||||
make up an answer.
|
make up an answer.
|
||||||
|
|
||||||
{context}
|
{context}
|
||||||
@ -19,6 +19,7 @@ class OpenAIConnector:
|
|||||||
def __init__(self, api_key: str, vdb: VectorDB):
|
def __init__(self, api_key: str, vdb: VectorDB):
|
||||||
self.vdb = vdb
|
self.vdb = vdb
|
||||||
self.model = "gpt-3.5-turbo"
|
self.model = "gpt-3.5-turbo"
|
||||||
|
self.word_cap = 1000
|
||||||
openai.api_key = api_key
|
openai.api_key = api_key
|
||||||
|
|
||||||
def query(self, question: str) -> Any:
|
def query(self, question: str) -> Any:
|
||||||
@ -30,8 +31,9 @@ class OpenAIConnector:
|
|||||||
if len(all_docs) == 0:
|
if len(all_docs) == 0:
|
||||||
return { "error": "No Context Found" }
|
return { "error": "No Context Found" }
|
||||||
|
|
||||||
# Join on new line, generate main prompt
|
# Join on new line (cap @ word limit), generate main prompt
|
||||||
context = '\n'.join(all_docs)
|
reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
|
||||||
|
context = '\n'.join(reduced_docs)
|
||||||
prompt = PROMPT_TEMPLATE.format(context = context, question = question)
|
prompt = PROMPT_TEMPLATE.format(context = context, question = question)
|
||||||
|
|
||||||
# Query OpenAI ChatCompletion
|
# Query OpenAI ChatCompletion
|
||||||
@ -41,4 +43,4 @@ class OpenAIConnector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Return Response
|
# Return Response
|
||||||
return response
|
return { "llm": response, "vdb": related }
|
||||||
|
@ -72,41 +72,41 @@
|
|||||||
</main>
|
</main>
|
||||||
<script>
|
<script>
|
||||||
const LOADING_SVG = `<svg
|
const LOADING_SVG = `<svg
|
||||||
width="24"
|
width="24"
|
||||||
height="24"
|
height="24"
|
||||||
viewBox="0 0 24 24"
|
viewBox="0 0 24 24"
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
fill="currentColor"
|
fill="currentColor"
|
||||||
>
|
>
|
||||||
<style>
|
<style>
|
||||||
.spinner_qM83 {
|
.spinner_qM83 {
|
||||||
animation: spinner_8HQG 1.05s infinite;
|
animation: spinner_8HQG 1.05s infinite;
|
||||||
}
|
}
|
||||||
.spinner_oXPr {
|
.spinner_oXPr {
|
||||||
animation-delay: 0.1s;
|
animation-delay: 0.1s;
|
||||||
}
|
}
|
||||||
.spinner_ZTLf {
|
.spinner_ZTLf {
|
||||||
animation-delay: 0.2s;
|
animation-delay: 0.2s;
|
||||||
}
|
}
|
||||||
@keyframes spinner_8HQG {
|
@keyframes spinner_8HQG {
|
||||||
0%,
|
0%,
|
||||||
57.14% {
|
57.14% {
|
||||||
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
|
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
|
||||||
transform: translate(0);
|
transform: translate(0);
|
||||||
}
|
}
|
||||||
28.57% {
|
28.57% {
|
||||||
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
|
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
|
||||||
transform: translateY(-6px);
|
transform: translateY(-6px);
|
||||||
}
|
}
|
||||||
100% {
|
100% {
|
||||||
transform: translate(0);
|
transform: translate(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
|
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
|
||||||
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
|
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
|
||||||
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
|
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
|
||||||
</svg>`;
|
</svg>`;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wrapper API Call
|
* Wrapper API Call
|
||||||
@ -125,9 +125,9 @@
|
|||||||
// Wrapping Element
|
// Wrapping Element
|
||||||
let wrapEl = document.createElement("div");
|
let wrapEl = document.createElement("div");
|
||||||
wrapEl.innerHTML = `<div class="flex">
|
wrapEl.innerHTML = `<div class="flex">
|
||||||
<span class="font-bold w-24 grow-0 shrink-0"></span>
|
<span class="font-bold w-24 grow-0 shrink-0"></span>
|
||||||
<span class="whitespace-break-spaces w-full"></span>
|
<span class="whitespace-break-spaces w-full"></span>
|
||||||
</div>`;
|
</div>`;
|
||||||
|
|
||||||
// Get Elements
|
// Get Elements
|
||||||
let nameEl = wrapEl.querySelector("span");
|
let nameEl = wrapEl.querySelector("span");
|
||||||
@ -158,7 +158,63 @@
|
|||||||
})
|
})
|
||||||
.then((data) => {
|
.then((data) => {
|
||||||
console.log("SUCCESS:", data);
|
console.log("SUCCESS:", data);
|
||||||
content.innerText = data.choices[0].message.content;
|
|
||||||
|
// Create Response Element
|
||||||
|
let responseEl = document.createElement("p");
|
||||||
|
responseEl.setAttribute(
|
||||||
|
"class",
|
||||||
|
"whitespace-break-spaces border-b pb-3 mb-3"
|
||||||
|
);
|
||||||
|
responseEl.innerText = data.response;
|
||||||
|
|
||||||
|
// Create Context Element
|
||||||
|
let contextEl = document.createElement("div");
|
||||||
|
contextEl.innerHTML = `
|
||||||
|
<h1 class="font-bold">Context:</h1>
|
||||||
|
<ul class="list-disc ml-6"></ul>`;
|
||||||
|
let ulEl = contextEl.querySelector("ul");
|
||||||
|
|
||||||
|
// Create Context Links
|
||||||
|
data.context
|
||||||
|
|
||||||
|
// Capture PubMed ID & Distance
|
||||||
|
.map((item) => [
|
||||||
|
item.metadata.file.match("\/(.*)\.txt$"),
|
||||||
|
item.distance,
|
||||||
|
])
|
||||||
|
|
||||||
|
// Filter Non-Matches
|
||||||
|
.filter(([match]) => match)
|
||||||
|
|
||||||
|
// Get Match Value & Round Distance (2)
|
||||||
|
.map(([match, distance]) => [
|
||||||
|
match[1],
|
||||||
|
Math.round(distance * 100) / 100,
|
||||||
|
])
|
||||||
|
|
||||||
|
// Create Links
|
||||||
|
.forEach(([pmid, distance]) => {
|
||||||
|
let newEl = document.createElement("li");
|
||||||
|
let linkEl = document.createElement("a");
|
||||||
|
linkEl.setAttribute("target", "_blank");
|
||||||
|
linkEl.setAttribute(
|
||||||
|
"class",
|
||||||
|
"text-blue-500 hover:text-blue-600"
|
||||||
|
);
|
||||||
|
linkEl.setAttribute(
|
||||||
|
"href",
|
||||||
|
"https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
|
||||||
|
);
|
||||||
|
linkEl.textContent = "[" + distance + "] " + pmid;
|
||||||
|
newEl.append(linkEl);
|
||||||
|
ulEl.append(newEl);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add to DOM
|
||||||
|
content.setAttribute("class", "w-full");
|
||||||
|
content.innerHTML = "";
|
||||||
|
content.append(responseEl);
|
||||||
|
content.append(contextEl);
|
||||||
})
|
})
|
||||||
.catch((e) => {
|
.catch((e) => {
|
||||||
console.log("ERROR:", e);
|
console.log("ERROR:", e);
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from chromadb.api import API
|
from chromadb.api import API
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from os import path
|
from tqdm import tqdm
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
@ -29,47 +28,45 @@ class VectorDB:
|
|||||||
ChromaDV VectorDB Type
|
ChromaDV VectorDB Type
|
||||||
"""
|
"""
|
||||||
class ChromaDB(VectorDB):
|
class ChromaDB(VectorDB):
|
||||||
def __init__(self, base_path: str):
|
def __init__(self, path: str):
|
||||||
chroma_path = path.join(base_path, "chroma")
|
self.client: API = chromadb.PersistentClient(path=path)
|
||||||
self.client: API = chromadb.PersistentClient(path=chroma_path)
|
self.word_cap = 2500
|
||||||
self.word_limit = 1000
|
|
||||||
self.collection_name: str = "vdb"
|
self.collection_name: str = "vdb"
|
||||||
self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
|
self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
|
||||||
|
|
||||||
def get_related(self, question) -> Any:
|
def get_related(self, question: str) -> Any:
|
||||||
"""Returns line separated related docs"""
|
"""Returns line separated related docs"""
|
||||||
results = self.collection.query(
|
results = self.collection.query(
|
||||||
query_texts=[question],
|
query_texts=[question.lower()],
|
||||||
n_results=2
|
n_results=2
|
||||||
)
|
)
|
||||||
|
|
||||||
all_docs: list = cast(list, results.get("documents", [[]]))[0]
|
all_docs: list = cast(list, results.get("documents", [[]]))[0]
|
||||||
|
all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
|
||||||
all_distances: list = cast(list, results.get("distances", [[]]))[0]
|
all_distances: list = cast(list, results.get("distances", [[]]))[0]
|
||||||
all_ids: list = cast(list, results.get("ids", [[]]))[0]
|
all_ids: list = cast(list, results.get("ids", [[]]))[0]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"distances":all_distances,
|
"distances": all_distances,
|
||||||
|
"metadatas": all_metadata,
|
||||||
"docs": all_docs,
|
"docs": all_docs,
|
||||||
"ids": all_ids
|
"ids": all_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_documents(self, normalizer: DataNormalizer):
|
def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
|
||||||
# 10 Item Chunking
|
length = len(normalizer) / chunk_size
|
||||||
for items in tqdm(chunk(normalizer, 50)):
|
for items in tqdm(chunk(normalizer, chunk_size), total=length):
|
||||||
ids = []
|
ids = []
|
||||||
documents = []
|
documents = []
|
||||||
|
metadatas = []
|
||||||
|
|
||||||
# Limit words per document to accommodate context token limits
|
|
||||||
for item in items:
|
for item in items:
|
||||||
doc = " ".join(item.get("doc").split()[:self.word_limit])
|
documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
|
||||||
documents.append(doc)
|
|
||||||
ids.append(item.get("id"))
|
ids.append(item.get("id"))
|
||||||
|
metadatas.append(item.get("metadata", {}))
|
||||||
|
|
||||||
# Ideally we parse out metadata from each document
|
|
||||||
# and pass to the metadata kwarg. However, each
|
|
||||||
# document appears to have a slightly different format,
|
|
||||||
# so it's difficult to parse out.
|
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
|
ids=ids,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
ids=ids
|
metadatas=metadatas,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user