diff --git a/python/mall/llm.py b/python/mall/llm.py
index efed256..41c3678 100644
--- a/python/mall/llm.py
+++ b/python/mall/llm.py
@@ -3,6 +3,7 @@
 import polars as pl
 import hashlib
 import ollama
+import copy
 import json
 import os
 
@@ -10,8 +11,9 @@
 def llm_use(backend="", model="", _cache="_mall_cache", **kwargs):
     out = dict()
     if isinstance(backend, Chat):
+        chat_copy = copy.deepcopy(backend)
         out.update(dict(backend="chatlas"))
-        out.update(dict(chat=backend))
+        out.update(dict(chat=chat_copy))
         backend = ""
         model = ""
     if isinstance(backend, Client):
@@ -42,6 +44,8 @@
         pl_type = pl.Int8
         data_type = int
 
+    use = llm_init_use(use, msg)
+
     df = df.with_columns(
         pl.col(col)
         .map_elements(
@@ -61,9 +65,10 @@
 def llm_loop(x, msg, use, valid_resps="", convert=None):
-    if isinstance(x, list) == False:
+    if not isinstance(x, list):
         raise TypeError("`x` is not a list object")
     out = list()
+    use = llm_init_use(use, msg)
     for row in x:
         out.append(
             llm_call(x=row, msg=msg, use=use, valid_resps=valid_resps, convert=convert)
@@ -71,8 +76,17 @@
     return out
 
 
-def llm_call(x, msg, use, valid_resps="", convert=None, data_type=None):
+def llm_init_use(use, msg):
+    backend = use.get("backend")
+    if backend == "chatlas":
+        chat = use.get("chat")
+        chat.set_turns(list())
+        chat.system_prompt = msg
+        use.update(chat=chat)
+    return use
+
+def llm_call(x, msg, use, valid_resps="", convert=None, data_type=None):
     backend = use.get("backend")
     model = use.get("model")
     call = dict(
@@ -84,15 +98,13 @@
     out = ""
     cache = ""
     if use.get("_cache") != "":
-
         hash_call = build_hash(call)
         cache = cache_check(hash_call, use)
     if cache == "":
         if backend == "chatlas":
             chat = use.get("chat")
-            ch = chat.chat(msg[0].get("content") + x, echo="none")
+            ch = chat.chat(x, echo="none")
             out = ch.get_content()
-            chat.set_turns(list())
         if backend == "ollama" or backend == "ollama-client":
             if backend == "ollama":
                 chat_fun = ollama.chat
@@ -109,7 +121,7 @@
             if model == "echo":
                 out = x
             if model == "content":
-                out = msg[0]["content"]
+                out = msg
         return out
     else:
         out = cache
@@ -143,10 +155,7 @@
 def build_msg(x, msg):
-    out = []
-    for msgs in msg:
-        out.append({"role": msgs["role"], "content": msgs["content"].format(x)})
-    return out
+    return {'role': 'user', 'content': msg + str(x)}
 
 
 def build_hash(x):
diff --git a/python/mall/llmvec.py b/python/mall/llmvec.py
index 04ec047..88a9ecc 100644
--- a/python/mall/llmvec.py
+++ b/python/mall/llmvec.py
@@ -19,10 +19,11 @@ class LLMVec:
         from mall import LLMVec
 
         chat = ChatOllama(model = "llama3.2")
-
-        llm = LLMVec(chat)
+
+        llm = LLMVec(chat)
         ```
     """
+
     def __init__(self, backend="", model="", _cache="_mall_cache", **kwargs):
         self._use = llm_use(backend=backend, model=model, _cache=_cache, **kwargs)
 
@@ -49,10 +50,10 @@ def sentiment(
         ```{python}
         llm.sentiment(['I am happy', 'I am sad'])
         ```
-        """
+        """
         return llm_loop(
             x=x,
-            msg=sentiment(options, additional=additional),
+            msg=sentiment(options, additional=additional, use=self._use),
             use=self._use,
             valid_resps=options,
         )
@@ -77,10 +78,10 @@ def summarize(self, x, max_words=10, additional="") -> list:
         ```{python}
         llm.summarize(['This has been the best TV Ive ever used. Great screen, and sound.'], max_words = 5)
         ```
-        """
+        """
         return llm_loop(
             x=x,
-            msg=summarize(max_words, additional=additional),
+            msg=summarize(max_words, additional=additional, use=self._use),
             use=self._use,
         )
@@ -106,10 +107,10 @@ def translate(self, x, language="", additional="") -> list:
         llm.translate(['This has been the best TV Ive ever used. Great screen, and sound.'],
             language = 'spanish')
         ```
-        """
+        """
         return llm_loop(
             x=x,
-            msg=translate(language, additional=additional),
+            msg=translate(language, additional=additional, use=self._use),
             use=self._use,
         )
@@ -135,10 +136,10 @@ def classify(self, x, labels="", additional="") -> list:
         ```{python}
         llm.classify(['this is important!', 'there is no rush'], ['urgent', 'not urgent'])
         ```
-        """
+        """
         return llm_loop(
             x=x,
-            msg=classify(labels, additional=additional),
+            msg=classify(labels, additional=additional, use=self._use),
             use=self._use,
             valid_resps=labels,
         )
@@ -164,8 +165,12 @@ def extract(self, x, labels="", additional="") -> list:
         ```{python}
         llm.extract(['bob smith, 123 3rd street'], labels=['name', 'address'])
         ```
-        """
-        return llm_loop(x=x, msg=extract(labels, additional=additional), use=self._use)
+        """
+        return llm_loop(
+            x=x,
+            msg=extract(labels, additional=additional, use=self._use),
+            use=self._use,
+        )
 
     def custom(self, x, prompt="", valid_resps="") -> list:
         """Provide the full prompt that the LLM will process.
@@ -178,7 +183,7 @@
         prompt : str
             The prompt to send to the LLM along with the `col`
 
-        """
+        """
         return llm_loop(x=x, msg=custom(prompt), use=self._use, valid_resps=valid_resps)
 
     def verify(self, x, what="", yes_no=[1, 0], additional="") -> list:
         """Check to see if something is true about the text.
@@ -201,10 +206,10 @@
         additional : str
             Inserts this text into the prompt sent to the LLM
 
-        """
+        """
         return llm_loop(
             x=x,
-            msg=verify(what, additional=additional),
+            msg=verify(what, additional=additional, use=self._use),
             use=self._use,
             valid_resps=yes_no,
             convert=dict(yes=yes_no[0], no=yes_no[1]),
diff --git a/python/mall/polars.py b/python/mall/polars.py
index d7ad452..d4cb725 100644
--- a/python/mall/polars.py
+++ b/python/mall/polars.py
@@ -150,7 +150,7 @@
         df = llm_map(
             df=self._df,
             col=col,
-            msg=sentiment(options, additional=additional),
+            msg=sentiment(options, additional=additional, use=self._use),
             pred_name=pred_name,
             use=self._use,
             valid_resps=options,
@@ -197,7 +197,7 @@
         df = llm_map(
             df=self._df,
             col=col,
-            msg=summarize(max_words, additional=additional),
+            msg=summarize(max_words, additional=additional, use=self._use),
             pred_name=pred_name,
             use=self._use,
         )
@@ -243,7 +243,7 @@
         df = llm_map(
             df=self._df,
             col=col,
-            msg=translate(language, additional=additional),
+            msg=translate(language, additional=additional, use=self._use),
             pred_name=pred_name,
             use=self._use,
         )
@@ -295,7 +295,7 @@
         df = llm_map(
             df=self._df,
             col=col,
-            msg=classify(labels, additional=additional),
+            msg=classify(labels, additional=additional, use=self._use),
             pred_name=pred_name,
             use=self._use,
             valid_resps=labels,
@@ -379,7 +379,7 @@
         df = llm_map(
             df=self._df,
             col=col,
-            msg=extract(lab_vals, additional=additional),
+            msg=extract(lab_vals, additional=additional, use=self._use),
             pred_name=pred_name,
             use=self._use,
         )
@@ -484,7 +484,7 @@
         df = llm_map(
             df=self._df,
             col=col,
-            msg=verify(what, additional=additional),
+            msg=verify(what, additional=additional, use=self._use),
             pred_name=pred_name,
             use=self._use,
             valid_resps=yes_no,
diff --git a/python/mall/prompt.py b/python/mall/prompt.py
index d477813..968f298 100644
--- a/python/mall/prompt.py
+++ b/python/mall/prompt.py
@@ -1,72 +1,54 @@
-def sentiment(options, additional=""):
+def sentiment(options, additional="", use=[]):
     new_options = process_labels(
         options,
         "Return only one of the following answers: {values} ",
         "- If the text is {key}, return {value} ",
     )
-    msg = [
-        {
-            "role": "user",
-            "content": "You are a helpful sentiment engine. "
-            + f"{new_options}. "
-            + "No capitalization. No explanations. "
-            + f"{additional} "
-            + "The answer is based on the following text:\n{}",
-        }
-    ]
-    return msg
-
-
-def summarize(max_words, additional=""):
-    msg = [
-        {
-            "role": "user",
-            "content": "You are a helpful summarization engine. "
-            + "Your answer will contain no no capitalization and no explanations. "
-            + f"Return no more than "
-            + str(max_words)
-            + " words. "
-            + f" {additional} "
-            + "The answer is the summary of the following text:\n{}",
-        }
-    ]
-    return msg
-
-
-def translate(language, additional=""):
-    msg = [
-        {
-            "role": "user",
-            "content": "You are a helpful translation engine. "
-            + "You will return only the translation text, no explanations. "
-            + f"The target language to translate to is: {language}. "
-            + f" {additional} "
-            + "The answer is the translation of the following text:\n{}",
-        }
-    ]
-    return msg
-
-
-def classify(labels, additional=""):
-    new_labels = process_labels(
+    x = (
+        "You are a helpful sentiment engine."
+        f" {new_options}. "
+        " No capitalization. No explanations."
+        f" {additional}"
+    )
+    return prompt_complete(x, use)
+
+
+def summarize(max_words, additional="", use=[]):
+    x = (
+        "You are a helpful summarization engine. "
+        "Your answer will contain no capitalization and no explanations. "
+        f"Return no more than {max_words} words. "
+        f"{additional}"
+    )
+    return prompt_complete(x, use)
+
+
+def translate(language, additional="", use=[]):
+    x = (
+        "You are a helpful translation engine. "
+        "You will return only the translation text, no explanations. "
+        f"The target language to translate to is: {language}. "
+        f"{additional}"
+    )
+    return prompt_complete(x, use)
+
+
+def classify(labels, additional="", use=[]):
+    labels = process_labels(
         labels,
         "Determine if the text refers to one of the following:{values} ",
         "- If the text is {key}, return {value} ",
     )
-    msg = [
-        {
-            "role": "user",
-            "content": "You are a helpful classification engine. "
-            + f"{new_labels}. "
-            + "No capitalization. No explanations. "
-            + f"{additional} "
-            + "The answer is based on the following text:\n{}",
-        }
-    ]
-    return msg
-
-
-def extract(labels, additional=""):
+    x = (
+        "You are a helpful classification engine. "
+        f"{labels}. "
+        "No capitalization. No explanations. "
+        f"{additional}"
+    )
+    return prompt_complete(x, use)
+
+
+def extract(labels, additional="", use=[]):
     col_labels = ""
     if isinstance(labels, list):
         no_labels = len(labels)
@@ -84,39 +66,30 @@
         text_multi = ""
         col_labels = labels
 
-    msg = [
-        {
-            "role": "user",
-            "content": "You are a helpful text extraction engine. "
-            + f"Extract the {col_labels} being referred to on the text. "
-            + f"I expect {no_labels} item{plural} exactly. "
-            + "No capitalization. No explanations. "
-            + f" {text_multi} "
-            + f" {additional} "
-            + "The answer is based on the following text:\n{}",
-        }
-    ]
-    return msg
-
-
-def verify(what, additional=""):
-    msg = [
-        {
-            "role": "user",
-            "content": "You are a helpful text analysis engine. "
-            + "Determine this is true "
-            + f"'{what}'."
-            + "No capitalization. No explanations. "
-            + f"{additional} "
-            + "The answer is based on the following text:\n{}",
-        }
-    ]
-    return msg
+    x = (
+        "You are a helpful text extraction engine. "
+        f"Extract the {col_labels} being referred to in the text. "
+        f"I expect {no_labels} item{plural} exactly. "
+        "No capitalization. No explanations. "
+        f"{text_multi}"
+        f"{additional}"
+    )
+    return prompt_complete(x, use)
+
+
+def verify(what, additional="", use=[]):
+    x = (
+        "You are a helpful text analysis engine."
+        "Determine if this is true "
+        f"'{what}'."
+        "No capitalization. No explanations."
+        f"{additional}"
+    )
+    return prompt_complete(x, use)
 
 
 def custom(prompt):
-    msg = [{"role": "user", "content": f"{prompt}" + ": \n{}"}]
-    return msg
+    return prompt
 
 
 def process_labels(x, if_list="", if_dict=""):
@@ -135,3 +108,16 @@
                 new = new.replace("{value}", str(x.get(i)))
             out += " " + new
     return out
+
+
+def prompt_complete(x, use):
+    backend = use.get("backend")
+    if backend == "chatlas":
+        x = (
+            x
+            + "The answer will be based on each individual prompt."
+            + " Treat each prompt as unique when deciding the answer."
+        )
+    else:
+        x = x + "The answer is based on the following text:\n{}"
+    return x
diff --git a/python/pyproject.toml b/python/pyproject.toml
index f4b4d2d..04e1b48 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -3,7 +3,7 @@ packages = ["mall"]
 
 [project]
 name = "mlverse-mall"
-version = "0.1.0.9003"
+version = "0.1.0.9004"
 description = "Run multiple 'Large Language Model' predictions against a table. The predictions run row-wise over a specified column."
 readme = "README.md"
 authors = [
diff --git a/python/tests/test_custom.py b/python/tests/test_custom.py
index efca829..33ba7f8 100644
--- a/python/tests/test_custom.py
+++ b/python/tests/test_custom.py
@@ -9,7 +9,7 @@ def test_custom_prompt():
     df = pl.DataFrame(dict(x="x"))
     df.llm.use("test", "content", _cache="_test_cache")
     x = df.llm.custom("x", "hello")
-    assert x["custom"][0] == "hello: \n{}"
+    assert x["custom"][0] == "hello"
     shutil.rmtree("_test_cache", ignore_errors=True)
 
 def test_custom_vec():
diff --git a/python/tests/test_extract.py b/python/tests/test_extract.py
index 9db3f86..3d99524 100644
--- a/python/tests/test_extract.py
+++ b/python/tests/test_extract.py
@@ -11,7 +11,7 @@ def test_extract_list():
     x = df.llm.extract("x", ["a", "b"])
     assert (
         x["extract"][0]
-        == "You are a helpful text extraction engine. Extract the a, b being referred to on the text. I expect 2 items exactly. No capitalization. No explanations. Return the response exclusively in a pipe separated list, and no headers. The answer is based on the following text:\n{}"
+        == "You are a helpful text extraction engine. Extract the a, b being referred to in the text. I expect 2 items exactly. No capitalization. No explanations. Return the response exclusively in a pipe separated list, and no headers. The answer is based on the following text:\n{}"
     )
     shutil.rmtree("_test_cache", ignore_errors=True)
 
@@ -22,7 +22,7 @@ def test_extract_dict():
     x = df.llm.extract("x", dict(a="one", b="two"))
     assert (
         x["extract"][0]
-        == "You are a helpful text extraction engine. Extract the one, two being referred to on the text. I expect 2 items exactly. No capitalization. No explanations. Return the response exclusively in a pipe separated list, and no headers. The answer is based on the following text:\n{}"
+        == "You are a helpful text extraction engine. Extract the one, two being referred to in the text. I expect 2 items exactly. No capitalization. No explanations. Return the response exclusively in a pipe separated list, and no headers. The answer is based on the following text:\n{}"
     )
     shutil.rmtree("_test_cache", ignore_errors=True)
 
@@ -33,7 +33,7 @@ def test_extract_one():
     x = df.llm.extract("x", labels="a")
     assert (
         x["extract"][0]
-        == "You are a helpful text extraction engine. Extract the a being referred to on the text. I expect 1 item exactly. No capitalization. No explanations. The answer is based on the following text:\n{}"
+        == "You are a helpful text extraction engine. Extract the a being referred to in the text. I expect 1 item exactly. No capitalization. No explanations. The answer is based on the following text:\n{}"
    )
     shutil.rmtree("_test_cache", ignore_errors=True)
diff --git a/python/tests/test_sentiment.py b/python/tests/test_sentiment.py
index 1982acf..b6bd0fe 100644
--- a/python/tests/test_sentiment.py
+++ b/python/tests/test_sentiment.py
@@ -38,7 +38,7 @@ def test_sentiment_prompt():
     x = df.llm.sentiment("x")
     assert (
         x["sentiment"][0]
-        == "You are a helpful sentiment engine. Return only one of the following answers: positive, negative, neutral . No capitalization. No explanations. The answer is based on the following text:\n{}"
+        == "You are a helpful sentiment engine. Return only one of the following answers: positive, negative, neutral . No capitalization. No explanations. The answer is based on the following text:\n{}"
     )
     shutil.rmtree("_test_cache", ignore_errors=True)
diff --git a/python/tests/test_summarize.py b/python/tests/test_summarize.py
index 5efc3c4..68c72cb 100644
--- a/python/tests/test_summarize.py
+++ b/python/tests/test_summarize.py
@@ -11,7 +11,7 @@ def test_summarize_prompt():
     x = df.llm.summarize("x")
     assert (
         x["summary"][0]
-        == "You are a helpful summarization engine. Your answer will contain no no capitalization and no explanations. Return no more than 10 words. The answer is the summary of the following text:\n{}"
+        == "You are a helpful summarization engine. Your answer will contain no capitalization and no explanations. Return no more than 10 words. The answer is based on the following text:\n{}"
     )
     shutil.rmtree("_test_cache", ignore_errors=True)
 
@@ -22,7 +22,7 @@ def test_summarize_max():
     x = df.llm.summarize("x", max_words=5)
     assert (
         x["summary"][0]
-        == "You are a helpful summarization engine. Your answer will contain no no capitalization and no explanations. Return no more than 5 words. The answer is the summary of the following text:\n{}"
+        == "You are a helpful summarization engine. Your answer will contain no capitalization and no explanations. Return no more than 5 words. The answer is based on the following text:\n{}"
     )
     shutil.rmtree("_test_cache", ignore_errors=True)
diff --git a/python/tests/test_translate.py b/python/tests/test_translate.py
index 88a251c..87b33eb 100644
--- a/python/tests/test_translate.py
+++ b/python/tests/test_translate.py
@@ -11,7 +11,7 @@ def test_translate_prompt():
     x = df.llm.translate("x", language="spanish")
     assert (
         x["translation"][0]
-        == "You are a helpful translation engine. You will return only the translation text, no explanations. The target language to translate to is: spanish. The answer is the translation of the following text:\n{}"
+        == "You are a helpful translation engine. You will return only the translation text, no explanations. The target language to translate to is: spanish. The answer is based on the following text:\n{}"
     )
     shutil.rmtree("_test_cache", ignore_errors=True)
diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index a46f57d..229d593 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -1,7 +1,7 @@
 Package: mall
 Title: Run Multiple Large Language Model Predictions Against a Table, or
     Vectors
-Version: 0.1.0.9003
+Version: 0.1.0.9004
 Authors@R: c(
     person("Edgar", "Ruiz", , "edgar@posit.co", role = c("aut", "cre")),
     person(given = "Posit Software, PBC", role = c("cph", "fnd"))
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 3903ec9..2d2244d 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -10,7 +10,9 @@ S3method(llm_summarize,"tbl_Spark SQL")
S3method(llm_summarize,data.frame)
 S3method(llm_translate,data.frame)
 S3method(llm_verify,data.frame)
+S3method(m_backend_prompt,mall_ellmer)
 S3method(m_backend_prompt,mall_llama3.2)
+S3method(m_backend_prompt,mall_ollama)
 S3method(m_backend_prompt,mall_session)
 S3method(m_backend_submit,mall_ellmer)
 S3method(m_backend_submit,mall_ollama)
diff --git a/r/R/m-backend-prompt.R b/r/R/m-backend-prompt.R
index b0e09d3..d38263f 100644
--- a/r/R/m-backend-prompt.R
+++ b/r/R/m-backend-prompt.R
@@ -4,10 +4,31 @@ m_backend_prompt <- function(backend, additional) {
   UseMethod("m_backend_prompt")
 }
 
+#' @export
+m_backend_prompt.mall_ollama <- function(backend, additional = "") {
+  next_method <- NextMethod()
+  additional <- glue(paste(
+    additional,
+    "The answer is based on the following text:\n{{x}}"
+  ))
+  next_method
+}
+
+#' @export
+m_backend_prompt.mall_ellmer <- function(backend, additional = "") {
+  next_method <- NextMethod()
+  additional <- glue(paste(
+    additional,
+    "The answer will be based on each individual prompt.",
+    "Treat each prompt as unique when deciding the answer."
+  ))
+  next_method
+}
+
 #' @export
 m_backend_prompt.mall_llama3.2 <- function(backend, additional = "") {
-  base_method <- NextMethod()
-  base_method$extract <- function(labels) {
+  next_method <- NextMethod()
+  next_method$extract <- function(labels) {
     no_labels <- length(labels)
     col_labels <- paste0(labels, collapse = ", ")
     plural <- ifelse(no_labels > 1, "s", "")
@@ -16,41 +37,29 @@
       "Return the response exclusively in a pipe separated list, and no headers. ",
       ""
     )
-    list(
-      list(
-        role = "user",
-        content = glue(paste(
-          "You are a helpful text extraction engine.",
-          "Extract the {col_labels} being referred to in the text.",
-          "I expect {no_labels} item{plural} exactly.",
-          "No capitalization. No explanations.",
-          "{text_multi}",
-          "{additional}",
-          "The answer is based on the following text:\n{{x}}"
-        ))
-      )
-    )
+    glue(paste(
+      "You are a helpful text extraction engine.",
+      "Extract the {col_labels} being referred to in the text.",
+      "I expect {no_labels} item{plural} exactly.",
+      "No capitalization. No explanations.",
+      "{text_multi}",
+      "{additional}"
+    ))
   }
-  base_method$classify <- function(labels) {
+  next_method$classify <- function(labels) {
     labels <- process_labels(
       x = labels,
       if_character = "Determine if the text refers to one of the following: {x}",
       if_formula = "If it classifies as {f_lhs(x)} then return {f_rhs(x)}"
     )
-    list(
-      list(
-        role = "user",
-        content = glue(paste(
-          "You are a helpful classification engine.",
-          "{labels}.",
-          "No capitalization. No explanations.",
-          "{additional}",
-          "The answer is based on the following text:\n{{x}}"
-        ))
-      )
-    )
+    glue(paste(
+      "You are a helpful classification engine.",
+      "{labels}.",
+      "No capitalization. No explanations.",
+      "{additional}"
+    ))
   }
-  base_method
+  next_method
 }
 
 #' @export
@@ -62,32 +71,20 @@ m_backend_prompt.mall_session <- function(backend, additional = "") {
         if_character = "Return only one of the following answers: {x}",
         if_formula = "- If the text is {f_lhs(x)}, return {f_rhs(x)}"
       )
-      list(
-        list(
-          role = "user",
-          content = glue(paste(
-            "You are a helpful sentiment engine.",
-            "{options}.",
-            "No capitalization. No explanations.",
-            "{additional}",
-            "The answer is based on the following text:\n{{x}}"
-          ))
-        )
-      )
+      glue(paste(
+        "You are a helpful sentiment engine.",
+        "{options}.",
+        "No capitalization. No explanations.",
+        "{additional}"
+      ))
     },
     summarize = function(max_words) {
-      list(
-        list(
-          role = "user",
-          content = glue(paste(
-            "You are a helpful summarization engine.",
-            "Your answer will contain no capitalization and no explanations.",
-            "Return no more than {max_words} words.",
-            "{additional}",
-            "The answer is the summary of the following text:\n{{x}}"
-          ))
-        )
-      )
+      glue(paste(
+        "You are a helpful summarization engine.",
+        "Your answer will contain no capitalization and no explanations.",
+        "Return no more than {max_words} words.",
+        "{additional}"
+      ))
     },
     classify = function(labels) {
       labels <- process_labels(
@@ -95,18 +92,12 @@
         if_character = "Determine if the text refers to one of the following: {x}",
         if_formula = "- For {f_lhs(x)}, return {f_rhs(x)}"
       )
-      list(
-        list(
-          role = "user",
-          content = glue(paste(
-            "You are a helpful classification engine.",
-            "{labels}.",
-            "No capitalization. No explanations.",
-            "{additional}",
-            "The answer is based on the following text:\n{{x}}"
-          ))
-        )
-      )
+      glue(paste(
+        "You are a helpful classification engine.",
+        "{labels}.",
+        "No capitalization. No explanations.",
+        "{additional}"
+      ))
     },
     extract = function(labels) {
       no_labels <- length(labels)
@@ -117,49 +108,31 @@
         "Return the response in a simple list, pipe separated, and no headers. ",
         ""
       )
-      list(
-        list(
-          role = "user",
-          content = glue(paste(
-            "You are a helpful text extraction engine.",
-            "Extract the {col_labels} being referred to in the text.",
-            "I expect {no_labels} item{plural} exactly.",
-            "No capitalization. No explanations.",
-            "{text_multi}",
-            "{additional}",
-            "The answer is based on the following text:\n{{x}}"
-          ))
-        )
-      )
+      glue(paste(
+        "You are a helpful text extraction engine.",
+        "Extract the {col_labels} being referred to in the text.",
+        "I expect {no_labels} item{plural} exactly.",
+        "No capitalization. No explanations.",
+        "{text_multi}",
+        "{additional}"
+      ))
     },
     translate = function(language) {
-      list(
-        list(
-          role = "user",
-          content = glue(paste(
-            "You are a helpful translation engine.",
-            "You will return only the translation text, no explanations.",
-            "The target language to translate to is: {language}.",
-            "{additional}",
-            "The answer is the summary of the following text:\n{{x}}"
-          ))
-        )
-      )
+      glue(paste(
+        "You are a helpful translation engine.",
+        "You will return only the translation text, no explanations.",
+        "The target language to translate to is: {language}.",
+        "{additional}"
+      ))
    },
     verify = function(what, labels) {
-      list(
-        list(
-          role = "user",
-          content = glue(paste(
-            "You are a helpful text analysis engine.",
-            "Determine if this is true ",
-            "'{what}'.",
-            "No capitalization. No explanations.",
-            "{additional}",
-            "The answer is based on the following text:\n{{x}}"
-          ))
-        )
-      )
+      glue(paste(
+        "You are a helpful text analysis engine.",
+        "Determine if this is true ",
+        "'{what}'.",
+        "No capitalization. No explanations.",
+        "{additional}"
+      ))
     }
   )
 }
diff --git a/r/R/m-backend-submit.R b/r/R/m-backend-submit.R
index 3a0dd00..d132df7 100644
--- a/r/R/m-backend-submit.R
+++ b/r/R/m-backend-submit.R
@@ -19,6 +19,9 @@ m_backend_submit <- function(backend, x, prompt, preview = FALSE) {
 
 #' @export
 m_backend_submit.mall_ollama <- function(backend, x, prompt, preview = FALSE) {
+  prompt <- list(
+    list(role = "user", content = prompt)
+  )
   if (preview) {
     x <- head(x, 1)
     map_here <- map
@@ -89,22 +92,17 @@ m_ollama_tokens <- function() {
 
 #' @export
 m_backend_submit.mall_ellmer <- function(backend, x, prompt, preview = FALSE) {
-  # Treats prompt as a system prompt
-  system_prompt <- prompt[[1]][["content"]]
-  system_prompt <- glue(system_prompt, x = "")
-  # Returns two expressions if on preview: setting the system prompt and the
-  # first chat call
   if (preview) {
     return(
       exprs(
-        ellmer_obj$set_system_prompt(!!system_prompt),
+        ellmer_obj$set_system_prompt(!!prompt),
         ellmer_obj$chat(as.list(!!head(x, 1)))
       )
     )
   }
   ellmer_obj <- backend[["args"]][["ellmer_obj"]]
   if (m_cache_use()) {
-    hashed_x <- map(x, function(x) hash(c(ellmer_obj, system_prompt, x)))
+    hashed_x <- map(x, function(x) hash(c(ellmer_obj, prompt, x)))
     from_cache <- map(hashed_x, m_cache_check)
     null_cache <- map_lgl(from_cache, is.null)
     x <- x[null_cache]
@@ -112,14 +110,14 @@ m_backend_submit.mall_ellmer <- function(backend, x, prompt, preview = FALSE) {
   from_llm <- NULL
   if (length(x) > 0) {
     temp_ellmer <- ellmer_obj$clone()$set_turns(list())
-    temp_ellmer$set_system_prompt(system_prompt)
+    temp_ellmer$set_system_prompt(prompt)
     from_llm <- parallel_chat_text(temp_ellmer, as.list(x))
   }
   if (m_cache_use()) {
     walk(
       seq_along(from_llm),
       function(y) {
-        m_cache_record(list(system_prompt, x[y]), from_llm[y], hashed_x[y])
+        m_cache_record(list(prompt, x[y]), from_llm[y], hashed_x[y])
       }
     )
     res <- rep("", times = length(null_cache))
@@ -127,8 +125,9 @@ m_backend_submit.mall_ellmer <- function(backend, x, prompt, preview = FALSE) {
     res[!null_cache] <- from_cache[!null_cache]
     res
   } else {
-    from_llm
+    res <- from_llm
   }
+  map_chr(res, ~.x)
 }
 
 # Using a function so that it can be mocked in testing
diff --git a/r/R/m-defaults.R b/r/R/m-defaults.R
index 383850f..0749658 100644
--- a/r/R/m-defaults.R
+++ b/r/R/m-defaults.R
@@ -13,9 +13,9 @@ m_defaults_set <- function(...) {
     sub_model <- NULL
   }
   obj_class <- clean_names(c(
+    defaults[["backend"]],
     model,
     sub_model,
-    defaults[["backend"]],
     "session"
  ))
   .env_llm$defaults <- defaults
diff --git a/r/R/m-vec-prompt.R b/r/R/m-vec-prompt.R
index ba69590..03f390e 100644
--- a/r/R/m-vec-prompt.R
+++ b/r/R/m-vec-prompt.R
@@ -18,19 +18,6 @@ m_vec_prompt <- function(x,
     fn <- defaults[[prompt_label]]
     prompt <- fn(...)
   }
-  # If the prompt is a character, it will convert it to
-  # a list so it can be processed
-  if (!inherits(prompt, "list")) {
-    p_split <- strsplit(prompt, "\\{\\{x\\}\\}")[[1]]
-    if (length(p_split) == 1 && p_split == prompt) {
-      content <- glue("{prompt}\n{{x}}")
-    } else {
-      content <- prompt
-    }
-    prompt <- list(
-      list(role = "user", content = content)
-    )
-  }
   # Submits final prompt to the LLM
   resp <- m_backend_submit(
     backend = backend,
diff --git a/r/tests/testthat/_snaps/llm-classify.md b/r/tests/testthat/_snaps/llm-classify.md
index 980521d..b94a21f 100644
--- a/r/tests/testthat/_snaps/llm-classify.md
+++ b/r/tests/testthat/_snaps/llm-classify.md
@@ -12,7 +12,7 @@
     Code
       llm_vec_classify("this is a test", c("a", "b"), preview = TRUE)
     Output
-      ollamar::chat(messages = list(list(role = "user", content = "You are a helpful classification engine. Determine if the text refers to one of the following: a, b. No capitalization. No explanations. The answer is based on the following text:\nthis is a test")),
+      ollamar::chat(messages = list(list(role = "user", content = "You are a helpful classification engine. Determine if the text refers to one of the following: a, b. No capitalization. No explanations. The answer is based on the following text:\nthis is a test")),
           output = "text", model = "llama3.2", seed = 100)
 
 # Classify on Ollama works
diff --git a/r/tests/testthat/_snaps/llm-custom.md b/r/tests/testthat/_snaps/llm-custom.md
index 5e1342a..c658a62 100644
--- a/r/tests/testthat/_snaps/llm-custom.md
+++ b/r/tests/testthat/_snaps/llm-custom.md
@@ -8,7 +8,7 @@
       2 I regret buying this laptop. It is too slow and the keyboard is too noisy
       3 Not sure how to feel about my new washing machine. Great color, but hard to figure
         .pred
-      1   Yes
+      1    No
       2    No
       3    No
 
diff --git a/r/tests/testthat/_snaps/llm-extract.md b/r/tests/testthat/_snaps/llm-extract.md
index 452e6f6..68bdb5f 100644
--- a/r/tests/testthat/_snaps/llm-extract.md
+++ b/r/tests/testthat/_snaps/llm-extract.md
@@ -3,15 +3,7 @@
     Code
       llm_vec_extract("toaster", labels = "product")
     Output
-      [[1]]
-      [[1]]$role
-      [1] "user"
-      
-      [[1]]$content
-      You are a helpful text extraction engine. Extract the product being referred to in the text. I expect 1 item exactly. No capitalization. No explanations. The answer is based on the following text:
-      {x}
-      
-      
+      You are a helpful text extraction engine. Extract the product being referred to in the text. I expect 1 item exactly. No capitalization. No explanations.
 
 
 # Extract on Ollama works
diff --git a/r/tests/testthat/_snaps/llm-summarize.md b/r/tests/testthat/_snaps/llm-summarize.md
index a1b9aeb..86e44eb 100644
--- a/r/tests/testthat/_snaps/llm-summarize.md
+++ b/r/tests/testthat/_snaps/llm-summarize.md
@@ -25,8 +25,8 @@
       1        This has been the best TV I've ever used. Great screen, and sound.
       2 I regret buying this laptop. It is too slow and the keyboard is too noisy
       3 Not sure how to feel about my new washing machine. Great color, but hard to figure
-                             .summary
-      1   great tv with good features
-      2  laptop purchase was a mistake
-      3 having mixed feelings about it
+                            .summary
+      1  this tv is excellent quality
+      2   i regret my laptop purchase
+      3    confused about the purchase
 
diff --git a/r/tests/testthat/_snaps/llm-translate.md b/r/tests/testthat/_snaps/llm-translate.md
index e589ae2..46bc0ff 100644
--- a/r/tests/testthat/_snaps/llm-translate.md
+++ b/r/tests/testthat/_snaps/llm-translate.md
@@ -7,8 +7,8 @@
       1        This has been the best TV I've ever used. Great screen, and sound.
       2 I regret buying this laptop. It is too slow and the keyboard is too noisy
       3 Not sure how to feel about my new washing machine. Great color, but hard to figure
-                                                                              .translation
-      1            Esta ha sido la mejor televisión que he utilizado. Gran pantalla y sonido.
-      2 Me arrepiento de comprar este portátil. Es demasiado lento y la tecla es demasiado ruidosa.
-      3 No estoy seguro de cómo me siento sobre mi nueva lavadora. Un gran color, pero difícil de entender
+                                                                              .translation
+      1 Esta ha sido la mejor televisión que he utilizado hasta ahora. Gran pantalla y sonido.
+      2 Lo lamento comprar este portátil. Es demasiado lento y el teclado es demasiado ruidoso.
+      3 No estoy seguro de cómo sentirme con mi nueva lavadora. Un color grande, pero difícil de entender
 
diff --git a/r/tests/testthat/_snaps/llm-verify.md b/r/tests/testthat/_snaps/llm-verify.md
index 33910e0..9267220 100644
--- a/r/tests/testthat/_snaps/llm-verify.md
+++ b/r/tests/testthat/_snaps/llm-verify.md
@@ -3,7 +3,7 @@
     Code
       llm_vec_verify("this is a test", "a test", preview = TRUE)
     Output
-      ollamar::chat(messages = list(list(role = "user", content = "You are a helpful text analysis engine. Determine if this is true 'a test'. No capitalization. No explanations. The answer is based on the following text:\nthis is a test")),
+      ollamar::chat(messages = list(list(role = "user", content = "You are a helpful text analysis engine. Determine if this is true 'a test'. No capitalization. No explanations. The answer is based on the following text:\nthis is a test")),
           output = "text", model = "llama3.2", seed = 100)
 
 # Verify on Ollama works
diff --git a/r/tests/testthat/_snaps/m-backend-submit.md b/r/tests/testthat/_snaps/m-backend-submit.md
index 00a92bb..4aaa172 100644
--- a/r/tests/testthat/_snaps/m-backend-submit.md
+++ b/r/tests/testthat/_snaps/m-backend-submit.md
@@ -29,7 +29,7 @@
         content = "this is the prompt")), preview = TRUE)
     Output
       [[1]]
-      ellmer_obj$set_system_prompt("this is the prompt")
+      ellmer_obj$set_system_prompt(list(list(content = "this is the prompt")))
      
       [[2]]
       ellmer_obj$chat(as.list("this is x"))
diff --git a/r/tests/testthat/test-m-backend-prompt.R b/r/tests/testthat/test-m-backend-prompt.R
index 402dc1e..60dca1a 100644
--- a/r/tests/testthat/test-m-backend-prompt.R
+++ b/r/tests/testthat/test-m-backend-prompt.R
@@ -9,7 +9,7 @@ test_that("Prompt handles list()", {
   test_text <- "Custom:{prompt}\n{{x}}"
   expect_equal(
     llm_vec_custom(x = "new test", prompt = test_text),
-    list(list(role = "user", content = test_text))
+    test_text
   )
 })
 
@@ -22,6 +22,6 @@ test_that("Prompt handles list()", {
   y <- m_backend_prompt(backend)
   y_extract <- y$extract(labels = c("a", "b"))
   y_classify <- y$classify(labels = c("a" ~ 1, "b" ~ 2))
-  expect_false(x_extract[[1]]$content == y_extract[[1]]$content)
-  expect_false(x_classify[[1]]$content == y_classify[[1]]$content)
+  expect_false(x_extract == y_extract)
+  expect_false(x_classify == y_classify)
 })
diff --git a/r/tests/testthat/test-m-backend-submit.R b/r/tests/testthat/test-m-backend-submit.R
index f57d7ba..6b14347 100644
--- a/r/tests/testthat/test-m-backend-submit.R
+++ b/r/tests/testthat/test-m-backend-submit.R
@@ -52,7 +52,7 @@ test_that("ellmer code is covered", {
       x = test_txt,
       prompt = list(list(content = "test"))
     ),
-    as.list(test_txt)
+    test_txt
   )
   expect_snapshot(
     m_backend_submit(