Drawing with LLMs [2] - 데이터 및 평가지표 탐색

Min-Kyeong·2025년 4월 23일
0

Drawing with LLMs

목록 보기
2/2
post-thumbnail

과제 설명

해당 kaggle 대회는 텍스트 설명을 바탕으로 SVG(Scalable Vector Graphics) 코드를 생성하는 모델을 개발하는 과제이다. 우선 핵심 내용은 다음과 같다.

텍스트 설명을 입력으로 받아 그 설명에 맞는 이미지를 렌더링하는 SVG 코드를 생성하는 것이다. SVG는 XML 형식을 사용하여 2D 그래픽을 표현하는 벡터 이미지 포맷으로, 크기를 조정해도 품질 손실이 없다는 장점이 있다.

데이터

대회의 테스트 데이터는 약 500개의 일상적인 물체와 다양한 영역의 장면 설명으로 구성되어 있음. 이 데이터의 특성과 제공되는 파일은 다음과 같다.

train.csv :

  • landscapes(풍경), abstract(추상), fashion(패션) 카테고리에 속한 대표적인 설명들의 모음
  • 훈련 데이터로 사용할 수 있는 샘플 텍스트 설명들이 포함되어 있음

kaggle_evaluation/test.csv :

  • 예시용 테스트 데이터 (실제 평가용 테스트 데이터는 아님)
  • 모델 제출 전에 로컬에서 테스트 해 볼 수 있도록 제공하는 데이텅;디/
  • test(Model) 함수로 모델이 제대로 작동하는지 시물레이션 가능

Description 데이터의 특성

  • 일반적인 주제 :
    • 모든 설명은 흔하고 일반적인 주제에 관한 것이다.
    • 브랜드명, 상표, 개인 이름 등은 포함되지 않는다.
    • 일반적인 형태로도 사람에 관한 설명은 없다.
  • 카테고리 구성 :
    • 약 12개의 카테고리로 구성되며 landscape , abstract, fashion 3 카테고리는 훈련, 공개 테스트, 비공개 테스트 세트 모두에서 포함된다.
    • 전체 설명의 절반 이상이 위 3가지 공유 카테고리에 속한다.
  • 텍스트의 길이
    • 어떤 설명도 200자를 초과하지 않는다.
    • 평균 길이는 약 50자이다.

평가 방식

이 대회의 평가 방식은 SVG Image Fidelity Score 을 통해서 이루어진다. 이는 주어진 텍스트 설명과 제출된 SVG 코드 간의 일치도를 측정한다.

  1. SVG 코드 제약 조건 검사
  • 10,000바이트 이하의 크기
  • 허용 목록(allowlist)에 있는 SVG 요소와 속성만 사용 가능
  • CSS style 요소는 사용 불가
  • 래스터화된 이미지 데이터나 외부 소스의 데이터 포함 불가
  1. SVG에서 PNG 변환 처리
  • cairosvg Python 라이브러리를 사용하여 SVG를 PNG로 변환
  • 변환된 PNG에 여러 전처리 필터 적용
  1. VQA(Visual Question Answering) 평가
  • PaliGemma 모델을 사용하여 전처리된 PNG 에 VQA 적용
  • 각 설명마다 4개의 이미지 관련 질문 제시 ( 예/아니요 또는 다중 선택 형식)
  • 이 질문들은 렌더링된 이미지가 설명의 특정 부분들을 얼마나 잘 표현했는지 확인
  • TIFA(Text-to-Iamge Faithfulness Evaluation With Question Answering) 방법론 사용
  1. OCR 텍스트 감지 및 점수 조정
  • PNG 이미지에서 텍스트 감지
  • 4글자 이상 감지되면 해당 이미지의 전체 점수에 지수적 패널티 적용
  • 최종 VQA 점수는 OCR 패널티와 TIFA 태스크 점수의 곱
  1. 최종 점수 계산
  • VQA 점수와 미적 점수의 조화 평균으로 최종 점수 산출
  • VQA 점수에 더 높은 가중치(β=2) 부여

추가 제약 조건

  • 설명이 모델에 전달된 후 5분 이내에 SVG 결과 반환 필요
  • 모든 SVG는 9시간 이내에 생성되어야 함.

상위권 노트북 접근 방식 정리

기간이 약 1달 정도 남았기 때문에 이미 대회를 진행하고 있는 사람들의 접근법을 빨리 따라가보도록 한다.

1) Text → Image → SVG

참고한 노트북
https://www.kaggle.com/code/jiazhuang/new-metric-simple-sd-svg-iterative-optimize#Stable-Diffusion--%3E-SVG

  1. model load

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Load with optimized scheduler and half precision
stable_diffusion_path = kagglehub.model_download("stabilityai/stable-diffusion-v2/pytorch/1/1")

scheduler = DDIMScheduler.from_pretrained(stable_diffusion_path, subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_path,
    scheduler=scheduler,
    torch_dtype=torch.float16,  # Use half precision
    safety_checker=None         # Disable safety checker for speed
)

# Move to GPU and apply optimizations
pipe.to(device)
  • GPU 설정
  • Diffusion model load
  • DDIMScheduler (노이즈 제거 단계 제어)
  1. 비트맵 이미지 생성 (Stable Diffusion 사용)
def generate_bitmap(prompt, negative_prompt="", num_inference_steps=20, guidance_scale=15):
        
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps, 
        guidance_scale=guidance_scale,
    ).images[0]
    
    return image
  1. 생성된 이미지 평가

  2. 비트맵 이미지를 SVG 로 변환

2) LLM 활용

참고한 노트북
https://www.kaggle.com/code/pepushi/drawing-with-llm-qwen2-5-32b-instruct-awq

  1. model load

class Model:
    def __init__(self):
        # self.model_path는 변경하지 않음(대회 요구사항)
        self.model_path = kagglehub.model_download('qwen-lm/qwen2.5/Transformers/32b-instruct-awq/1')

        # LLM 인스턴스 초기화
        self.llm = vllm.LLM(
            self.model_path,
            quantization="awq",              # AWQ 양자화 사용
            tensor_parallel_size=2,          # GPU 병렬 처리(오류 시 1로 변경)
            gpu_memory_utilization=0.95,     # GPU 메모리 사용률
            trust_remote_code=True,
            dtype="half",                    # 반정밀도(FP16) 사용
            enforce_eager=True,
            max_model_len=5120,              # 최대 컨텍스트 길이
            disable_log_stats=True
        )
        

Qwen 2.5 32B AWQ 모델을 사용합니다. 이는 대형 언어 모델로, 양자화를 통해 메모리 효율성을 높였습니다.
vLLM 라이브러리를 사용하여 모델을 효율적으로 로드하고 추론합니다.
두 개의 GPU에 병렬로 모델을 분산하여 속도를 높입니다(tensor_parallel_size=2).

  1. 샘플링 파라미터 설정

# 베이스 샘플링 파라미터(RL 등으로 최적화 가능성 있음)
        self.sampling_params = vllm.SamplingParams(
            n=1,
            top_k=30,                  # 상위 30개 토큰만 고려
            top_p=0.85,                # 누적 확률 0.85까지의 토큰만 고려
            temperature=0.75,          # 다양성 조절(높을수록 다양)
            repetition_penalty=1.05,   # 반복 페널티
            skip_special_tokens=False,
            max_tokens=700,            # SVG 코드 생성 최대 토큰 수
        )
        self.tokenizer = self.llm.get_tokenizer()

다양성과 품질 사이의 균형을 조절하는 파라미터를 설정합니다.
temperature=0.75는 적당한 수준의 창의성을 허용합니다.
max_tokens=700은 충분히 상세한 SVG를 생성할 수 있는 길이입니다.

  1. 프롬프트 템플릿
# 프롬프트: 미적 요소를 강조
        self.prompt_template = """Generate a beautiful, aesthetically pleasing SVG code that visually represents the following text description, while respecting these strict constraints:
<constraints>
* **Allowed Elements:** svg, path, circle, rect, ellipse, line, polyline, polygon, g, linearGradient, radialGradient, stop, defs
* **Allowed Attributes:** viewBox, width, height, fill, stroke, stroke-width, d, cx, cy, r, x, y, rx, ry, x1, y1, x2, y2, points, transform, opacity
</constraints>

<example>
<description>"A red circle with a blue square inside"</description>
svg
<svg viewBox="0 0 256 256" width="256" height="256">
  <circle cx="50" cy="50" r="40" fill="red"/>
  <rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
</example>

Focus on visually appealing composition, color usage, and (optionally) subtle gradients or transforms. 
Always provide **complete** SVG with nothing omitted, no ellipses.

<description>"{}"</description>
svg
<svg viewBox="0 0 256 256" width="256" height="256">
"""

상세한 프롬프트를 사용하여 모델에게 SVG 생성에 대한 정확한 지시를 제공합니다.
허용된 SVG 요소와 속성을 명시하여 대회 제약 조건을 준수하도록 합니다.
예시를 통해 모델이 출력 형식을 이해하도록 돕습니다.
미적 요소(구성, 색상 사용, 그라디언트, 변형)를 강조합니다.

  1. SVG 추출 및 유효성 검사
# <svg> 추출
    def _parse_svg(self, response: str) -> str:
        matchs = re.findall(r"<svg.*?</svg>", response, re.DOTALL)
        return matchs[-1].strip() if matchs else ""

    # cairosvg 체크
    def _check_svg_valid(self, svg_code: str) -> bool:
        try:
            cairosvg.svg2png(bytestring=svg_code.encode("utf-8"))
            return True
        except:
            return False

정규 표현식을 사용하여 모델 응답에서 SVG 코드를 추출합니다.
cairosvg 라이브러리를 사용하여 SVG의 렌더링 가능 여부를 검사합니다.
이는 생성된 SVG가 실제로 유효한지 확인하는 중요한 단계입니다.

  1. 보상 함수 (품질 평가)
# 보상 함수: 미적 요소 + 컴팩트함 + 다양성 + gradient/transform 등
    def _reward_function(self, svg: str) -> float:
        if not self._check_svg_valid(svg):
            return 0.0

        base_score = 1.0

        # 다양한 형태(태그)
        shape_tags = ["circle", "rect", "ellipse", "line", "polygon", "polyline", "path"]
        used_tags = set()
        shape_count = 0
        for tag in shape_tags:
            c = len(re.findall(f"<{tag}", svg, re.IGNORECASE))
            if c > 0:
                used_tags.add(tag)
            shape_count += c
        diversity_score = len(used_tags)           # 최대 ~7
        shape_count_score = min(shape_count, 20)   # 최대 20

        # 색상 사용 => fill= / stroke= 패턴
        color_matches = re.findall(r'(?:fill|stroke)\s*=\s*"[^"]+"', svg)
        color_count = len(color_matches)
        color_score = min(color_count / 4.0, 3.0)  # 4색마다 +1, 최대 3

        # gradient 또는 defs => 미적 점수 큰 증가
        gradient_defs_score = 0.5 if any(g in svg for g in ["<linearGradient", "<radialGradient", "<defs"]) else 0.0

        # transform => +0.3
        transform_score = 0.3 if re.search(r'transform\s*=', svg, re.IGNORECASE) else 0.0

        # stroke+fill => +0.3
        stroke_fill_score = 0.3 if ('stroke=' in svg and 'fill=' in svg) else 0.0

        # 크기 패널티: 3000~4000까지 허용, 그 이상은 급격히 감점
        size = len(svg.encode("utf-8"))
        if size>4000:
            size_score = 0.3
        elif size>2000:
            size_score = 0.6
        else:
            size_score = 1.0

        total = (
            base_score +
            0.5*diversity_score +
            0.4*shape_count_score +
            color_score +
            gradient_defs_score +
            transform_score +
            stroke_fill_score +
            size_score
        )
        return round(total, 4)

생성된 SVG의 품질을 정량적으로 평가하는 복합 점수 시스템입니다.
다양한 형태(태그)의 사용, 색상 다양성, 그라디언트/변형 사용 등을 고려합니다.
파일 크기에 대한 패널티를 부여하여 10KB 제한 내에서 최적화합니다.
이 점수를 기반으로 여러 생성 결과 중 최상의 SVG를 선택합니다.

  1. 제약 조건 적용
# SVG 제약 조건 정제
    def enforce_constraints(self, svg_string: str) -> str:
        logging.info("Sanitizing SVG...")

        try:
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error("SVG ParseError => default. %s", e)
            return self.default_svg

        to_remove = []
        for el in root.iter():
            tag_name = etree.QName(el.tag).localname

            # 불허용 태그
            if tag_name not in self.constraints.allowed_elements:
                to_remove.append(el)
                continue

            # 불허용 속성
            rm_attrs = []
            for attr in el.attrib:
                aname = etree.QName(attr).localname
                if aname not in self.constraints.allowed_elements[tag_name] \
                   and aname not in self.constraints.allowed_elements["common"]:
                    rm_attrs.append(attr)
            for a in rm_attrs:
                del el.attrib[a]

            # href 유효성 검사
            for k, v in list(el.attrib.items()):
                if etree.QName(k).localname=="href" and not v.startswith("#"):
                    del el.attrib[k]

            # path 체크
            if tag_name=="path":
                d_val = el.get("d","")
                if not d_val:
                    to_remove.append(el)
                    continue
                path_regex = re.compile(
                    r'^(?:[MmZzLlHhVvCcSsQqTtAa]\s*'
                    r'(?:-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'
                    r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'
                    r')?\s*)+$'
                )
                if not path_regex.match(d_val.strip()):
                    to_remove.append(el)

        for r in to_remove:
            if r.getparent() is not None:
                r.getparent().remove(r)

        try:
            cleaned = etree.tostring(root, encoding="unicode")
            return cleaned
        except ValueError as e:
            logging.error("Sanitize error => default. %s", e)
            return self.default_svg

XML 파서를 사용하여 SVG를 구문 분석합니다.
대회 규칙에 따라 허용되지 않은 태그와 속성을 제거합니다.
외부 리소스에 대한 참조를 방지합니다.
path 요소의 문법적 유효성을 검사합니다.
정제 과정에서 오류가 발생하면 기본 SVG를 반환합니다.

  1. SVG 개선(Refinement)
# refine: 미적 요소 추가 강화
    def _refine_svg(self, svg_code: str) -> str:
        refine_prompt = (
            "Here is an SVG code:\n"
            f"{svg_code}\n\n"
            "Please refine or beautify it further (e.g., add subtle gradient, transform, interesting shapes) "
            "while preserving the same main composition. Return only the complete <svg> code, no comments or ellipses."
        )
        msgs = [{"role": "user", "content": refine_prompt}]
        refine_txt = self.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

        resp = self.llm.generate([refine_txt], self.sampling_params, use_tqdm=False)
        if not resp:
            return svg_code
        new_svg = self._parse_svg(resp[0].outputs[0].text)
        if not new_svg:
            return svg_code

        new_svg = self.enforce_constraints(new_svg)
        if self._check_svg_valid(new_svg):
            old_sc = self._reward_function(svg_code)
            new_sc = self._reward_function(new_svg)
            if new_sc>old_sc:
                logging.info(f"Refine success: {old_sc} => {new_sc}")
                return new_svg
        return svg_code

생성된 최상의 SVG를 추가로 개선하는 단계입니다.
모델에게 미적 요소(그라디언트, 변형, 흥미로운 형태)를 추가하도록 요청합니다.
개선된 버전이 원본보다 높은 보상 점수를 받는 경우에만 채택합니다.
이 단계는 SVG의 예술적 품질을 높이는 데 중요합니다.

  1. 강화학습을 통한 파라미터 최적화

# RL: 파라미터 탐색
    def train_with_rl(self, sample_prompts: List[str], search_size: int=5):
        best_params = None
        best_score = -9999.0

        cands = []
        for _ in range(search_size):
            temp = random.uniform(0.6,1.0)
            top_p = random.uniform(0.75,0.99)
            top_k = random.randint(20,70)
            cands.append((temp, top_p, top_k))

        for (temp, t_p, t_k) in cands:
            logging.info(f"Trying param => t={temp}, p={t_p}, k={t_k}")
            sp = vllm.SamplingParams(
                n=1,
                top_k=t_k,
                top_p=t_p,
                temperature=temp,
                repetition_penalty=1.05,
                skip_special_tokens=False,
                max_tokens=700
            )
            ssum = 0.0
            cc = 0
            for pr in sample_prompts:
                txt = self._apply_template(pr)
                outp = self.llm.generate([txt]*2, sp, use_tqdm=False)
                for r in outp:
                    svg_raw = self._parse_svg(r.outputs[0].text)
                    cleaned = self.enforce_constraints(svg_raw)
                    ssum += self._reward_function(cleaned)
                    cc += 1
            avg = ssum/cc if cc>0 else 0.0
            logging.info(f"Avg reward => {avg}")
            if avg>best_score:
                best_score = avg
                best_params = (temp, t_p, t_k)

        if best_params is not None:
            logging.info(f"Best found => {best_params}, score={best_score}")
            (temp, p, k) = best_params
            self.sampling_params = vllm.SamplingParams(
                n=1,
                top_k=k,
                top_p=p,
                temperature=temp,
                repetition_penalty=1.05,
                skip_special_tokens=False,
                max_tokens=700
            )

샘플링 파라미터(temperature, top_p, top_k)를 최적화하는 간단한 RL 접근법입니다.
다양한 파라미터 조합을 시도하고 평균 보상 점수를 측정합니다.
가장 높은 평균 점수를 얻은 파라미터 세트를 선택합니다.
이는 모델이 더 높은 품질의 SVG를 생성하도록 조정하는 메타 최적화입니다.


# predict (실제 SVG 생성)
    def predict(self, description: str) -> str:
        def generate_svg():
            try:
                prompt_txt = self._apply_template(description)
                # 동일 프롬프트로 여러 번 생성
                responses = self.llm.generate([prompt_txt]*num_attempt,
                                              self.sampling_params,
                                              use_tqdm=False)
                cands = []
                for r in responses:
                    raw_svg = self._parse_svg(r.outputs[0].text)
                    if not raw_svg:
                        continue
                    cleaned = self.enforce_constraints(raw_svg)
                    if self._check_svg_valid(cleaned):
                        cands.append(cleaned)

                if not cands:
                    logging.warning("No valid => default return")
                    return self.default_svg

                # 최상 선택
                best_svg = None
                best_sc = -9999.0
                for c in cands:
                    sc = self._reward_function(c)
                    if sc>best_sc:
                        best_sc = sc
                        best_svg = c

                # refine
                refined = self._refine_svg(best_svg)
                ref_sc = self._reward_function(refined)
                if ref_sc>best_sc:
                    return refined
                else:
                    return best_svg

            except Exception as e:
                logging.error("Exception in generate_svg: %s", e)
                return self.default_svg

        return generate_svg()

텍스트 설명을 SVG 코드로 변환합니다.
num_attempt(10)번 동일한 프롬프트로 여러 SVG를 생성합니다.
각 생성물에 대해 제약 조건을 적용하고 유효성을 검사합니다.
보상 함수를 사용하여 가장 높은 점수의 SVG를 선택합니다.
선택된 SVG를 개선하는 시도를 하고, 개선된 버전이 더 나은 경우 사용합니다.
오류 발생 시 기본 SVG를 반환하여 안정성을 확보합니다.

다른 노트북들을 살펴 봤을 때도 전체적인 흐름이 이 2가지로 정리되는 듯하다.

  • Text → Image → SVG
    ( Stable Diffusion 활용 )

  • LLM 활용

0개의 댓글